Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-17 14:11:14 +08:00
[Feature] Support mtp ep in fd (#3340)
* [Optimize] Add metrics for analysing perf
* Fix bug in mtp
@@ -866,10 +866,10 @@ class LLMEngine:
                 is_prefill = True
                 self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len

+        for task in tasks:
+            task.inference_start_time = time.time()
         if not is_decode:
             llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
-            for task in tasks:
-                task.inference_start_time = time.time()
         if not self.cfg.enable_mm:
             self.update_requests_chunk_size(tasks)
         else:
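The hunk above moves the per-task inference_start_time stamp ahead of the is_decode branch, so the timestamp is recorded before the dispatch path diverges and is available to the elapsed-time metrics added later in this diff. A minimal sketch of the idea, with simplified stand-in classes rather than FastDeploy's actual task type:

```python
import time


class TaskSketch:
    """Simplified stand-in for an engine task, carrying only the timing field."""

    def __init__(self, request_id):
        self.request_id = request_id
        self.inference_start_time = None


def dispatch(tasks):
    # Stamp every task before it is handed to the engine, regardless of role,
    # so downstream metrics always have a reference point.
    for task in tasks:
        task.inference_start_time = time.time()
    # ... send tasks to the worker processes ...


def model_execute_time(task):
    # Mirrors the model_execute_time metric introduced in the TokenProcessor hunks below.
    return time.time() - task.inference_start_time
```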
@@ -16,6 +16,7 @@

 from __future__ import annotations

+import copy
 import os
 import signal
 import threading
@@ -293,6 +294,9 @@ class ExpertService:
                 cur_task_idx = self.resource_manager.req_dict[task.request_id]
                 del self.resource_manager.req_dict[task.request_id]
                 cur_task = self.resource_manager.tasks_list[cur_task_idx]
+                cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
+                if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode":
+                    cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
                 if task.error_code != 200:
                     self.resource_manager.stop_flags[cur_task_idx] = True
                     self.resource_manager.tasks_list[cur_task_idx] = None
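A hedged sketch of what the added lines do: when a prefill result reaches the decode instance, the first generated token replaces the cached prompt token, and under MTP the prefill side's draft tokens are copied over so the decode instance can continue speculative decoding. The function and argument names below are illustrative stand-ins, not FastDeploy APIs:

```python
import copy


def seed_decode_task(cur_task, prefill_output, speculative_method, splitwise_role):
    # The first token produced by the prefill instance becomes the token the
    # decode instance starts from.
    cur_task.prompt_token_ids[0] = prefill_output.token_ids[0]
    # Under MTP, the decode instance also inherits the prefill's draft tokens
    # so speculative verification can resume immediately.
    if speculative_method == "mtp" and splitwise_role == "decode":
        cur_task.draft_token_ids = copy.deepcopy(prefill_output.draft_token_ids)
    return cur_task
```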
@@ -195,7 +195,14 @@ class TokenProcessor:
             try:
                 is_blocking = True
                 if self.speculative_decoding:
-                    speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
+                    if (
+                        self.cfg.parallel_config.enable_expert_parallel
+                        and self.cfg.parallel_config.data_parallel_size > 1
+                    ):
+                        speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
+                    else:
+                        speculate_get_output(self.output_tokens, rank_id, is_blocking, False)

                 if self.output_tokens[0] == -2:
                     continue
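The branch above selects between two output paths based on the parallel configuration. A small sketch of the selection logic; the assumption here, inferred from the call sites and not confirmed by the source, is that the trailing boolean tells the custom speculate_get_output op to read from the expert-parallel output queue:

```python
def fetch_speculative_output(output_tokens, rank_id, is_blocking, parallel_config, speculate_get_output):
    # Assumed semantics: the last argument switches the op to the expert-parallel
    # output path when EP is enabled together with data parallelism.
    use_ep_queue = (
        parallel_config.enable_expert_parallel
        and parallel_config.data_parallel_size > 1
    )
    speculate_get_output(output_tokens, rank_id, is_blocking, use_ep_queue)
```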
@@ -478,6 +485,7 @@ class TokenProcessor:
                         arrival_time=task.arrival_time,
                         inference_start_time=task.inference_start_time,
                         first_token_time=time.time() - task.inference_start_time,
+                        model_execute_time=time.time() - task.inference_start_time,
                         time_in_queue=task.schedule_start_time - task.preprocess_end_time,
                         preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
                         request_start_time=task.arrival_time,
@@ -489,6 +497,7 @@ class TokenProcessor:
                     metrics = RequestMetrics(
                         arrival_time=time.time(),
                         request_start_time=task.arrival_time,
+                        model_execute_time=time.time() - task.inference_start_time,
                     )
                 self.number_of_output_tokens += len(token_ids)
                 self._record_metrics(task, current_time, token_ids)
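Both hunks above add a model_execute_time field to the first-token metrics, measured from the inference_start_time stamped in the LLMEngine hunk at the top of this diff. An illustrative stand-in for the fields involved (not the real RequestMetrics class):

```python
import time
from dataclasses import dataclass


@dataclass
class FirstTokenMetricsSketch:
    arrival_time: float
    inference_start_time: float
    first_token_time: float
    model_execute_time: float


def build_first_token_metrics(task):
    now = time.time()
    return FirstTokenMetricsSketch(
        arrival_time=task.arrival_time,
        inference_start_time=task.inference_start_time,
        # Both deltas are measured from the same dispatch timestamp,
        # matching the constructor call in the diff above.
        first_token_time=now - task.inference_start_time,
        model_execute_time=now - task.inference_start_time,
    )
```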
@@ -506,7 +515,7 @@ class TokenProcessor:
                 if self.tokens_counter[task_id] == 0:
                     if task.messages is not None:
                         result.prompt = task.messages
                     result.num_cached_tokens = task.num_cached_tokens

                 is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"

@@ -76,6 +76,7 @@ class MTPProposer(Proposer):
         self.model_config.num_hidden_layers = 1
         self.model_config.model = self.speculative_config.model
         self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
+        self.model_config.is_quantized = False
         if self.speculative_config.quantization != "":
             self.model_config.quantization = self.speculative_config.quantization
         self.model_config.start_layer_index = self.num_main_model_layers
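The MTP proposer builds its draft-model config by patching a copy of the main model's config, so the added line clears the inherited quantized flag before the draft model's own quantization, if any, is applied. A minimal sketch of that ordering, with simplified config objects:

```python
def configure_mtp_quantization(model_config, speculative_config):
    # Clear any quantization state inherited from the main model's config...
    model_config.is_quantized = False
    # ...and only re-apply quantization if the draft (MTP) model declares its own.
    if speculative_config.quantization != "":
        model_config.quantization = speculative_config.quantization
    return model_config
```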
@@ -142,6 +143,7 @@ class MTPProposer(Proposer):
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
         if not self.parallel_config.do_profile and self.parallel_config.splitwise_role != "mixed":
             cache_kvs_list = []
             for i in range(
@@ -149,8 +151,8 @@ class MTPProposer(Proposer):
                 self.num_main_model_layers + self.model_config.num_hidden_layers,
             ):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
-                key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
-                val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
+                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
                 key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
                 cache_kvs_list.append(key_cache)
                 value_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -176,11 +178,11 @@ class MTPProposer(Proposer):
             if self.cache_config.enable_prefix_caching:
                 set_data_ipc(
                     self.cache_kvs[f"key_caches_{i}"],
-                    f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}",
+                    f"key_caches_{i}_rank{local_rank}.device{self.device_id}",
                 )
                 set_data_ipc(
                     self.cache_kvs[f"value_caches_{i}"],
-                    f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}",
+                    f"value_caches_{i}_rank{local_rank}.device{self.device_id}",
                 )
             self.model_inputs["caches"] = list(self.cache_kvs.values())
             for value in self.cache_kvs.values():
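The three MTPProposer hunks above replace self.local_rank with local_rank = self.local_rank % tensor_parallel_size in the KV-cache IPC names. The apparent motivation, an inference rather than something stated in the diff, is that with expert/data parallelism a worker's process-wide local rank can exceed the tensor-parallel group size, while the shared cache buffers are named per TP rank. A hedged sketch of the naming rule:

```python
def kv_cache_ipc_name(kind, layer_idx, local_rank, tensor_parallel_size, device_id):
    # Key the shared-memory name by the rank *within* the tensor-parallel group,
    # not by the process-wide local rank, so EP/DP workers map to the right buffer.
    tp_rank = local_rank % tensor_parallel_size
    return f"{kind}_caches_{layer_idx}_rank{tp_rank}.device{device_id}"


# Example: a worker with process-wide local_rank=5 in a TP-4 group on device 5
# resolves to "key_caches_0_rank1.device5".
assert kv_cache_ipc_name("key", 0, 5, 4, 5) == "key_caches_0_rank1.device5"
```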
@@ -503,6 +503,7 @@ class SplitwiseConnector:
                     index=task["outputs"]["index"],
                     send_idx=0,
                     token_ids=task["outputs"]["token_ids"],
+                    draft_token_ids=task["outputs"]["draft_token_ids"],
                 ),
                 finished=True,
             )
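Finally, the splitwise payload sent from prefill to decode now includes the MTP draft tokens, which is what the ExpertService hunk earlier in this diff consumes on the decode side. A simplified sketch of the payload shape, with plain dicts standing in for the real output objects:

```python
def build_prefill_result_payload(task):
    # Fields mirror the constructor call in the hunk above; everything else about
    # the real message format is omitted here.
    return {
        "outputs": {
            "index": task["outputs"]["index"],
            "send_idx": 0,
            "token_ids": task["outputs"]["token_ids"],
            # New: forward the draft tokens so the decode instance can seed MTP.
            "draft_token_ids": task["outputs"]["draft_token_ids"],
        },
        "finished": True,
    }
```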