mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
support mtp in hybird-dp-tp mode (#4299)
This commit is contained in:
@@ -18,7 +18,7 @@ import paddle
|
||||
from paddle import nn
|
||||
from paddle.distributed import fleet
|
||||
|
||||
from fastdeploy.model_executor.utils import set_weight_attrs
|
||||
from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
|
||||
|
||||
from .utils import get_tensor
|
||||
|
||||
@@ -53,44 +53,61 @@ class ParallelEHProjection(nn.Layer):
|
||||
self.bias_key = prefix + ".bias"
|
||||
else:
|
||||
self.bias_key = None
|
||||
self.use_ep = fd_config.parallel_config.use_ep
|
||||
self.fd_config = fd_config
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
self.column_cut = True
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_size
|
||||
|
||||
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
|
||||
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
|
||||
|
||||
if self.use_ep:
|
||||
self.weight = self.create_parameter(
|
||||
shape=[embedding_dim, num_embeddings],
|
||||
dtype=paddle.get_default_dtype(),
|
||||
is_bias=False,
|
||||
)
|
||||
else:
|
||||
if self.column_cut:
|
||||
need_gather = True
|
||||
self.linear = ColumnParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
mp_group=self.tp_group,
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
gather_output=need_gather,
|
||||
fuse_matmul_bias=False, # False diff更小
|
||||
)
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": True})
|
||||
set_weight_attrs(
|
||||
self.linear.weight,
|
||||
{
|
||||
"weight_loader": default_weight_loader(self.fd_config),
|
||||
"model_format": self.fd_config.model_config.model_format,
|
||||
},
|
||||
)
|
||||
if self.bias_key is not None:
|
||||
set_weight_attrs(self.linear.bias, {"output_dim": True})
|
||||
set_weight_attrs(
|
||||
self.linear.bias,
|
||||
{"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
|
||||
)
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": True})
|
||||
else:
|
||||
self.linear = RowParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
mp_group=self.tp_group,
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
input_is_parallel=False,
|
||||
fuse_matmul_bias=False, # False diff更小
|
||||
)
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": False})
|
||||
set_weight_attrs(
|
||||
self.linear.weight,
|
||||
{
|
||||
"weight_loader": default_weight_loader(self.fd_config),
|
||||
"model_format": self.fd_config.model_config.model_format,
|
||||
},
|
||||
)
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": True})
|
||||
set_weight_attrs(
|
||||
self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
@@ -100,9 +117,6 @@ class ParallelEHProjection(nn.Layer):
|
||||
state_dict (dict): A dictionary containing the checkpoint weights and biases.
|
||||
"""
|
||||
|
||||
if self.use_ep:
|
||||
self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
|
||||
else:
|
||||
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
|
||||
if self.linear.weight.shape != weight_tensor.shape:
|
||||
weight_tensor = weight_tensor.transpose([1, 0])
|
||||
@@ -123,8 +137,5 @@ class ParallelEHProjection(nn.Layer):
|
||||
Tensor: The output tensor after processing through the layer.
|
||||
"""
|
||||
logits = input
|
||||
if self.use_ep:
|
||||
logits = paddle.matmul(logits, self.weight)
|
||||
else:
|
||||
logits = self.linear(logits)
|
||||
return logits
|
||||
|
@@ -147,6 +147,10 @@ class MTPProposer(Proposer):
|
||||
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
|
||||
max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
|
||||
)
|
||||
if kv_cache_quant_type == "block_wise_fp8":
|
||||
kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
|
||||
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
|
||||
|
||||
if not self.parallel_config.do_profile and (
|
||||
self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
|
||||
):
|
||||
@@ -156,8 +160,8 @@ class MTPProposer(Proposer):
|
||||
self.num_main_model_layers + self.model_config.num_hidden_layers,
|
||||
):
|
||||
key_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
|
||||
val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
|
||||
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
|
||||
cache_kvs_list.append(key_cache)
|
||||
value_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
@@ -177,6 +181,17 @@ class MTPProposer(Proposer):
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
if kv_cache_quant_type == "block_wise_fp8":
|
||||
self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
|
||||
shape=kv_cache_scale_shape,
|
||||
fill_value=0,
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
|
||||
shape=kv_cache_scale_shape,
|
||||
fill_value=0,
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
self.model_inputs["caches"] = list(self.cache_kvs.values())
|
||||
for value in self.cache_kvs.values():
|
||||
del value
|
||||
@@ -610,7 +625,7 @@ class MTPProposer(Proposer):
|
||||
self.max_model_len,
|
||||
self.model_inputs["substep"],
|
||||
)
|
||||
if self.role == "prefill":
|
||||
if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
|
||||
mtp_save_first_token(
|
||||
self.model_inputs["base_model_draft_tokens"],
|
||||
self.model_inputs["not_need_stop"],
|
||||
@@ -697,7 +712,11 @@ class MTPProposer(Proposer):
|
||||
)
|
||||
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
paddle.distributed.broadcast(sampled_token_ids, 0)
|
||||
paddle.distributed.broadcast(
|
||||
sampled_token_ids,
|
||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||
group=self.parallel_config.tp_group,
|
||||
)
|
||||
|
||||
self._post_process(sampled_token_ids)
|
||||
|
||||
|
Reference in New Issue
Block a user