mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 17:17:14 +08:00
support mtp in hybird-dp-tp mode (#4299)
This commit is contained in:
@@ -147,6 +147,10 @@ class MTPProposer(Proposer):
|
||||
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
|
||||
max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
|
||||
)
|
||||
if kv_cache_quant_type == "block_wise_fp8":
|
||||
kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
|
||||
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
|
||||
|
||||
if not self.parallel_config.do_profile and (
|
||||
self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
|
||||
):
|
||||
@@ -156,8 +160,8 @@ class MTPProposer(Proposer):
|
||||
self.num_main_model_layers + self.model_config.num_hidden_layers,
|
||||
):
|
||||
key_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
|
||||
val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
|
||||
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
|
||||
cache_kvs_list.append(key_cache)
|
||||
value_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
@@ -177,6 +181,17 @@ class MTPProposer(Proposer):
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
if kv_cache_quant_type == "block_wise_fp8":
|
||||
self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
|
||||
shape=kv_cache_scale_shape,
|
||||
fill_value=0,
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
|
||||
shape=kv_cache_scale_shape,
|
||||
fill_value=0,
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
self.model_inputs["caches"] = list(self.cache_kvs.values())
|
||||
for value in self.cache_kvs.values():
|
||||
del value
|
||||
@@ -610,7 +625,7 @@ class MTPProposer(Proposer):
|
||||
self.max_model_len,
|
||||
self.model_inputs["substep"],
|
||||
)
|
||||
if self.role == "prefill":
|
||||
if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
|
||||
mtp_save_first_token(
|
||||
self.model_inputs["base_model_draft_tokens"],
|
||||
self.model_inputs["not_need_stop"],
|
||||
@@ -697,7 +712,11 @@ class MTPProposer(Proposer):
|
||||
)
|
||||
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
paddle.distributed.broadcast(sampled_token_ids, 0)
|
||||
paddle.distributed.broadcast(
|
||||
sampled_token_ids,
|
||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||
group=self.parallel_config.tp_group,
|
||||
)
|
||||
|
||||
self._post_process(sampled_token_ids)
|
||||
|
||||
|
Reference in New Issue
Block a user