Support MTP in hybrid DP-TP mode (#4299)

freeliuzc
2025-09-28 15:58:45 +08:00
committed by GitHub
parent 076c30cb0f
commit c8985727a6
2 changed files with 79 additions and 49 deletions

View File

@@ -18,7 +18,7 @@ import paddle
 from paddle import nn
 from paddle.distributed import fleet
-from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
 from .utils import get_tensor
@@ -53,44 +53,61 @@ class ParallelEHProjection(nn.Layer):
             self.bias_key = prefix + ".bias"
         else:
             self.bias_key = None
-        self.use_ep = fd_config.parallel_config.use_ep
+        self.fd_config = fd_config
+        self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
+        self.nranks = fd_config.parallel_config.tensor_parallel_size

         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear

-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
-                dtype=paddle.get_default_dtype(),
-                is_bias=False,
-            )
-        else:
-            if self.column_cut:
-                need_gather = True
-                self.linear = ColumnParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    gather_output=need_gather,
-                    fuse_matmul_bias=False,  # False diff更小 (smaller diff with False)
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": True})
-                if self.bias_key is not None:
-                    set_weight_attrs(self.linear.bias, {"output_dim": True})
-            else:
-                self.linear = RowParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    input_is_parallel=False,
-                    fuse_matmul_bias=False,  # False diff更小 (smaller diff with False)
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": False})
+        if self.column_cut:
+            need_gather = True
+            self.linear = ColumnParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                gather_output=need_gather,
+                fuse_matmul_bias=False,  # False diff更小 (smaller diff with False)
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.bias_key is not None:
+                set_weight_attrs(
+                    self.linear.bias,
+                    {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
+                )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
+        else:
+            self.linear = RowParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                input_is_parallel=False,
+                fuse_matmul_bias=False,  # False diff更小 (smaller diff with False)
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
+        set_weight_attrs(
+            self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
+        )

     def load_state_dict(self, state_dict):
         """
@@ -100,9 +117,6 @@ class ParallelEHProjection(nn.Layer):
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
-        if self.use_ep:
-            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
-        else:
-            weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
-            if self.linear.weight.shape != weight_tensor.shape:
-                weight_tensor = weight_tensor.transpose([1, 0])
+        weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+        if self.linear.weight.shape != weight_tensor.shape:
+            weight_tensor = weight_tensor.transpose([1, 0])
@@ -123,8 +137,5 @@ class ParallelEHProjection(nn.Layer):
             Tensor: The output tensor after processing through the layer.
         """
         logits = input
-        if self.use_ep:
-            logits = paddle.matmul(logits, self.weight)
-        else:
-            logits = self.linear(logits)
+        logits = self.linear(logits)
         return logits
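Replacing fleet.get_hybrid_communicate_group().get_model_parallel_group() with parallel_config.tp_group is what makes this layer work in hybrid DP-TP mode: each data-parallel replica must shard the projection over its own TP group rather than one global model-parallel group. A minimal sketch of building such per-replica groups under a contiguous rank layout (illustrative; FastDeploy constructs tp_group in its parallel config):

import paddle.distributed as dist

def build_tp_groups_sketch(world_size, tp_size):
    # Ranks [d * tp_size, (d + 1) * tp_size) form the TP group of
    # data-parallel replica d. new_group must be called by every rank for
    # every group; each rank then uses the group that contains it.
    groups = []
    for dp_rank in range(world_size // tp_size):
        ranks = list(range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
        groups.append(dist.new_group(ranks))
    return groups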

View File

@@ -147,6 +147,10 @@ class MTPProposer(Proposer):
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
+        if kv_cache_quant_type == "block_wise_fp8":
+            kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
         if not self.parallel_config.do_profile and (
             self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
         ):
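self.local_rank % tensor_parallel_size folds the node-wide rank into a rank inside this replica's TP group, so the shared KV-cache tensor names below are indexed by the TP-local rank. For example, with tensor_parallel_size=4 and two DP replicas on one node:

# Illustrative mapping for a contiguous hybrid DP x TP layout.
tp_size = 4
for global_local_rank in range(8):
    tp_local_rank = global_local_rank % tp_size  # rank inside the TP group
    dp_replica = global_local_rank // tp_size    # which DP replica on this node
    print(f"local_rank {global_local_rank} -> dp {dp_replica}, tp {tp_local_rank}")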
@@ -156,8 +160,8 @@ class MTPProposer(Proposer):
                 self.num_main_model_layers + self.model_config.num_hidden_layers,
             ):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
-                key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
-                val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
+                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
                 key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
                 cache_kvs_list.append(key_cache)
                 value_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -177,6 +181,17 @@ class MTPProposer(Proposer):
                     fill_value=0,
                     dtype=cache_type,
                 )
+                if kv_cache_quant_type == "block_wise_fp8":
+                    self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
+                    self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
             self.model_inputs["caches"] = list(self.cache_kvs.values())
             for value in self.cache_kvs.values():
                 del value
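With block_wise_fp8, a scale tensor is allocated next to each key/value cache; kv_cache_scale_shape keeps the first three axes of kv_cache_shape and drops the head-dim axis, i.e. one scale per cached slot. A minimal sketch of how such scales would be applied at dequantization time, assuming a [num_blocks, kv_num_heads, block_size, head_dim] cache layout (hypothetical helper, not the FastDeploy kernel):

def apply_block_scales_sketch(kv_values, kv_scales):
    # kv_values: [num_blocks, kv_num_heads, block_size, head_dim], FP8 values
    #            already converted to a floating-point dtype.
    # kv_scales: [num_blocks, kv_num_heads, block_size], matching
    #            kv_cache_scale_shape above.
    return kv_values * kv_scales.unsqueeze(-1)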
@@ -610,7 +625,7 @@ class MTPProposer(Proposer):
             self.max_model_len,
             self.model_inputs["substep"],
         )
-        if self.role == "prefill":
+        if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
             mtp_save_first_token(
                 self.model_inputs["base_model_draft_tokens"],
                 self.model_inputs["not_need_stop"],
@@ -697,7 +712,11 @@ class MTPProposer(Proposer):
             )
             if self.parallel_config.tensor_parallel_size > 1:
-                paddle.distributed.broadcast(sampled_token_ids, 0)
+                paddle.distributed.broadcast(
+                    sampled_token_ids,
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
             self._post_process(sampled_token_ids)
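The broadcast root moves from global rank 0 to data_parallel_rank * tensor_parallel_size, the first global rank of this replica's TP group, and the call is scoped to tp_group; without this, every DP replica would source its sampled tokens from rank 0 of replica 0. The rank arithmetic, assuming a contiguous TP-major layout:

def tp_broadcast_root(data_parallel_rank: int, tensor_parallel_size: int) -> int:
    # Global ranks of DP replica d are [d * tp, (d + 1) * tp), so the
    # replica's TP rank 0 sits at global rank d * tp.
    return data_parallel_rank * tensor_parallel_size

assert tp_broadcast_root(0, 4) == 0  # replica 0 broadcasts from global rank 0
assert tp_broadcast_root(1, 4) == 4  # replica 1 broadcasts from global rank 4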