[Feature] Support MTP with expert parallelism (EP) in FD (#3340)

* [Optimize] Add metrics for analysing perf

* Fix bug in MTP
This commit is contained in:
chenjian
2025-08-11 21:49:44 +08:00
committed by GitHub
parent 110f33a530
commit 7573802a88
5 changed files with 24 additions and 8 deletions

View File

@@ -195,7 +195,14 @@ class TokenProcessor:
try:
is_blocking = True
if self.speculative_decoding:
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
if (
self.cfg.parallel_config.enable_expert_parallel
and self.cfg.parallel_config.data_parallel_size > 1
):
speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
else:
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
if self.output_tokens[0] == -2:
continue
@@ -478,6 +485,7 @@ class TokenProcessor:
arrival_time=task.arrival_time,
inference_start_time=task.inference_start_time,
first_token_time=time.time() - task.inference_start_time,
model_execute_time=time.time() - task.inference_start_time,
time_in_queue=task.schedule_start_time - task.preprocess_end_time,
preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
request_start_time=task.arrival_time,
@@ -489,6 +497,7 @@ class TokenProcessor:
metrics = RequestMetrics(
arrival_time=time.time(),
request_start_time=task.arrival_time,
model_execute_time=time.time() - task.inference_start_time,
)
self.number_of_output_tokens += len(token_ids)
self._record_metrics(task, current_time, token_ids)
@@ -506,7 +515,7 @@ class TokenProcessor:
if self.tokens_counter[task_id] == 0:
if task.messages is not None:
result.prompt = task.messages
result.num_cached_tokens = task.num_cached_tokens
result.num_cached_tokens = task.num_cached_tokens
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"