mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-18 06:31:17 +08:00
[Feature] Support mtp ep in fd (#3340)
* [Optimize] Add metrics for analysing perf
* Fix bug in mtp
This commit is contained in:
@@ -76,6 +76,7 @@ class MTPProposer(Proposer):
|
||||
self.model_config.num_hidden_layers = 1
|
||||
self.model_config.model = self.speculative_config.model
|
||||
self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
|
||||
self.model_config.is_quantized = False
|
||||
if self.speculative_config.quantization != "":
|
||||
self.model_config.quantization = self.speculative_config.quantization
|
||||
self.model_config.start_layer_index = self.num_main_model_layers
|
||||
@@ -142,6 +143,7 @@ class MTPProposer(Proposer):
|
||||
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
|
||||
max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
|
||||
)
|
||||
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
|
||||
if not self.parallel_config.do_profile and self.parallel_config.splitwise_role != "mixed":
|
||||
cache_kvs_list = []
|
||||
for i in range(
|
||||
@@ -149,8 +151,8 @@ class MTPProposer(Proposer):
|
||||
self.num_main_model_layers + self.model_config.num_hidden_layers,
|
||||
):
|
||||
key_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
|
||||
val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
|
||||
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
|
||||
cache_kvs_list.append(key_cache)
|
||||
value_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
@@ -176,11 +178,11 @@ class MTPProposer(Proposer):
|
||||
if self.cache_config.enable_prefix_caching:
|
||||
set_data_ipc(
|
||||
self.cache_kvs[f"key_caches_{i}"],
|
||||
f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}",
|
||||
f"key_caches_{i}_rank{local_rank}.device{self.device_id}",
|
||||
)
|
||||
set_data_ipc(
|
||||
self.cache_kvs[f"value_caches_{i}"],
|
||||
f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}",
|
||||
f"value_caches_{i}_rank{local_rank}.device{self.device_id}",
|
||||
)
|
||||
self.model_inputs["caches"] = list(self.cache_kvs.values())
|
||||
for value in self.cache_kvs.values():
|
||||
|
Reference in New Issue
Block a user