support mtp in hybrid-dp-tp mode (#4299)
@@ -18,7 +18,7 @@ import paddle
 from paddle import nn
 from paddle.distributed import fleet

-from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs

 from .utils import get_tensor

@@ -53,44 +53,61 @@ class ParallelEHProjection(nn.Layer):
             self.bias_key = prefix + ".bias"
         else:
             self.bias_key = None
-        self.use_ep = fd_config.parallel_config.use_ep
+        self.fd_config = fd_config
+        self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
+        self.nranks = fd_config.parallel_config.tensor_parallel_size

         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear

-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
-                dtype=paddle.get_default_dtype(),
-                is_bias=False,
+        if self.column_cut:
+            need_gather = True
+            self.linear = ColumnParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                gather_output=need_gather,
+                fuse_matmul_bias=False,  # False gives a smaller numerical diff
             )
-        else:
-            if self.column_cut:
-                need_gather = True
-                self.linear = ColumnParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    gather_output=need_gather,
-                    fuse_matmul_bias=False,  # False gives a smaller numerical diff
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": True})
-                if self.bias_key is not None:
-                    set_weight_attrs(self.linear.bias, {"output_dim": True})
-            else:
-                self.linear = RowParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    input_is_parallel=False,
-                    fuse_matmul_bias=False,  # False gives a smaller numerical diff
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": False})
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.bias_key is not None:
+                set_weight_attrs(
+                    self.linear.bias,
+                    {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
+                )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
+        else:
+            self.linear = RowParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                input_is_parallel=False,
+                fuse_matmul_bias=False,  # False gives a smaller numerical diff
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
+            set_weight_attrs(
+                self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
+            )

     def load_state_dict(self, state_dict):
         """
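The constructor above replaces the EP-specific weight with TP-aware ColumnParallelLinear/RowParallelLinear layers whose weights carry their own loader metadata through set_weight_attrs. The snippet below is a minimal sketch of that tagging pattern; toy_set_weight_attrs and toy_weight_loader are illustrative stand-ins, and the real helpers (set_weight_attrs, default_weight_loader in fastdeploy.model_executor.utils) may behave differently in detail.

import paddle
from paddle import nn


def toy_set_weight_attrs(param, attrs):
    # Attach loader metadata directly onto the parameter object, mirroring what
    # set_weight_attrs does for the parallel-linear weights in the hunk above.
    for key, value in attrs.items():
        setattr(param, key, value)


def toy_weight_loader(param, loaded_weight):
    # Stand-in for default_weight_loader(fd_config): copy the checkpoint tensor
    # into the parameter, transposing if it is stored in the other orientation.
    if param.shape != loaded_weight.shape:
        loaded_weight = loaded_weight.transpose([1, 0])
    param.set_value(loaded_weight)


layer = nn.Linear(8, 16)
toy_set_weight_attrs(layer.weight, {"weight_loader": toy_weight_loader, "output_dim": True})

# A generic checkpoint loop can now dispatch on the attached attribute instead
# of hard-coding per-layer loading logic.
checkpoint_weight = paddle.randn([16, 8])  # stored as [out_features, in_features]
loader = getattr(layer.weight, "weight_loader", None)
if loader is not None:
    loader(layer.weight, checkpoint_weight)
print(layer.weight.shape)  # [8, 16]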
@@ -100,17 +117,14 @@ class ParallelEHProjection(nn.Layer):
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """

-        if self.use_ep:
-            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
-        else:
-            weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
-            if self.linear.weight.shape != weight_tensor.shape:
-                weight_tensor = weight_tensor.transpose([1, 0])
-            self.linear.weight.set_value(weight_tensor)
+        weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+        if self.linear.weight.shape != weight_tensor.shape:
+            weight_tensor = weight_tensor.transpose([1, 0])
+        self.linear.weight.set_value(weight_tensor)

        if self.bias_key is not None:
            bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
            self.linear.bias.set_value(bias)

    def forward(self, input):
        """
@@ -123,8 +137,5 @@ class ParallelEHProjection(nn.Layer):
            Tensor: The output tensor after processing through the layer.
        """
        logits = input
-        if self.use_ep:
-            logits = paddle.matmul(logits, self.weight)
-        else:
-            logits = self.linear(logits)
+        logits = self.linear(logits)
        return logits
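load_state_dict and forward above drop the use_ep branch entirely: the weight is always loaded through self.linear, with a transpose when the checkpoint stores the matrix in the opposite orientation. A small sketch of that shape check, with made-up sizes:

import paddle

embedding_dim, num_embeddings = 8, 32
linear_weight = paddle.zeros([embedding_dim, num_embeddings])

for ckpt in (
    paddle.randn([embedding_dim, num_embeddings]),  # already matches the linear weight
    paddle.randn([num_embeddings, embedding_dim]),  # stored in the opposite orientation
):
    weight_tensor = ckpt.astype(paddle.get_default_dtype())
    if linear_weight.shape != weight_tensor.shape:
        weight_tensor = weight_tensor.transpose([1, 0])
    # Either orientation ends up matching the layer's weight shape.
    assert linear_weight.shape == weight_tensor.shape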
@@ -147,6 +147,10 @@ class MTPProposer(Proposer):
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
+        if kv_cache_quant_type == "block_wise_fp8":
+            kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
+
         if not self.parallel_config.do_profile and (
             self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
         ):
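The new local_rank is the global rank taken modulo the tensor-parallel size, so the shared cache names used in the next hunk (key_caches_{i}_rank{local_rank}...) refer to the rank within the current data-parallel replica's TP group rather than the global rank. A sketch of the mapping, assuming TP ranks are contiguous within each DP replica and using made-up sizes:

dp_size, tp_size = 2, 4  # made-up example sizes

for global_rank in range(dp_size * tp_size):
    dp_rank = global_rank // tp_size    # which DP replica the rank belongs to
    local_rank = global_rank % tp_size  # rank within that replica's TP group
    # The cache regions are named with local_rank, so every DP replica addresses
    # its own TP-local cache shards consistently.
    print(f"global={global_rank} dp={dp_rank} tp_local={local_rank}")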
@@ -156,8 +160,8 @@ class MTPProposer(Proposer):
                self.num_main_model_layers + self.model_config.num_hidden_layers,
            ):
                key_cache = paddle.empty(shape=[], dtype=cache_type)
-                key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
-                val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
+                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
                key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
                cache_kvs_list.append(key_cache)
                value_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -177,6 +181,17 @@ class MTPProposer(Proposer):
                    fill_value=0,
                    dtype=cache_type,
                )
+                if kv_cache_quant_type == "block_wise_fp8":
+                    self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
+                    self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
            self.model_inputs["caches"] = list(self.cache_kvs.values())
            for value in self.cache_kvs.values():
                del value
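For block_wise_fp8 KV caches, a scale tensor is now allocated alongside each cache: its shape is the first three axes of kv_cache_shape and it stays in the default dtype instead of the quantized cache dtype. The sketch below only illustrates that shape relationship; the meaning attached to each axis is an assumption.

import paddle

kv_cache_shape = [16, 2, 64, 128]  # assumed meaning: [num_blocks, kv_heads, block_size, head_dim]
kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]

# The head_dim axis is dropped for the scales, which are kept in the default
# (higher-precision) dtype rather than the fp8 cache dtype.
key_cache_scales = paddle.full(shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype())
print(key_cache_scales.shape)  # [16, 2, 64]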
@@ -610,7 +625,7 @@ class MTPProposer(Proposer):
            self.max_model_len,
            self.model_inputs["substep"],
        )
-        if self.role == "prefill":
+        if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
            mtp_save_first_token(
                self.model_inputs["base_model_draft_tokens"],
                self.model_inputs["not_need_stop"],
@@ -697,7 +712,11 @@ class MTPProposer(Proposer):
            )

            if self.parallel_config.tensor_parallel_size > 1:
-                paddle.distributed.broadcast(sampled_token_ids, 0)
+                paddle.distributed.broadcast(
+                    sampled_token_ids,
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )

            self._post_process(sampled_token_ids)
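With hybrid DP + TP, global rank 0 belongs only to the first data-parallel replica, so broadcasting the sampled tokens from rank 0 would be wrong for every other replica. The new call roots the broadcast at data_parallel_rank * tensor_parallel_size, the first global rank of the current replica's TP group, and restricts it to tp_group. A sketch of that rank arithmetic, assuming TP ranks are contiguous within each DP replica and using made-up sizes:

dp_size, tp_size = 2, 4  # made-up example sizes

for dp_rank in range(dp_size):
    tp_group_ranks = [dp_rank * tp_size + r for r in range(tp_size)]
    src = dp_rank * tp_size  # first global rank of this replica's TP group
    # paddle.distributed.broadcast(sampled_token_ids, src, group=tp_group) then
    # fans the tokens out only inside this replica, instead of always using
    # global rank 0 as the old code did.
    print(f"dp_rank={dp_rank} tp_group={tp_group_ranks} broadcast_src={src}")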