Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-09-26 20:41:53 +08:00

Support DP+TP+EP hybrid parallel deployment strategy
* fix conflict
* add moe_tp_ep function split_allgather_out
* del tp_group in moe_cutlass_backend
* for ci
* fix parallel_config for ci
* del log
136 lines
5.0 KiB
Python
"""
|
|
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
from typing import Dict

import numpy as np
import paddle
from paddle import nn
from paddle.distributed import fleet

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.utils import set_weight_attrs

from .utils import get_tensor


class VocabParallelEmbedding(nn.Layer):
    """
    VocabParallelEmbedding Layer
    """

    def __init__(
        self,
        fd_config: FDConfig,
        num_embeddings: int,
        embedding_dim: int = 768,
        params_dtype: str = "bfloat16",
        prefix="",
    ) -> None:
        """
        Initialize the VocabParallelEmbedding layer for the model.

        Args:
            fd_config (FDConfig): Inference configuration. The layer reads the
                parallel settings (tensor_parallel_size, tensor_parallel_rank,
                tp_group) and the model settings (hidden_dropout_prob,
                initializer_range, max_position_embeddings, tie_word_embeddings).
            num_embeddings (int): Vocabulary size.
            embedding_dim (int): Size of the hidden state.
            params_dtype (str): Data type of the parameters.
            prefix (str): The name of the current layer. Defaults to "".
        """
        super().__init__()
        self.fd_config = fd_config
        hcg = fleet.get_hybrid_communicate_group()
        self.mp_rank: int = hcg.get_model_parallel_rank()
        self.column_cut = False
        self.world_size: int = fd_config.parallel_config.tensor_parallel_size
        self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
        self.tp_group = fd_config.parallel_config.tp_group
        self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
        self.initializer_range: float = fd_config.model_config.initializer_range
        self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
        self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
        self.params_dtype: str = params_dtype

        if not self.column_cut:
            # Vocab-parallel embedding: the vocabulary dimension is sharded
            # across tensor-parallel ranks by fleet's VocabParallelEmbedding.
            self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
                num_embeddings,
                embedding_dim,
                mp_group=self.tp_group,
                weight_attr=paddle.ParamAttr(
                    initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
                ),
            )
            if self.world_size > 1:
                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
        else:
            # Column-cut embedding: each rank holds a slice of the hidden
            # dimension; the slices are gathered again in forward().
            self.embeddings = nn.Embedding(
                num_embeddings,
                embedding_dim // self.world_size,
            )

            self.embeddings.weight.is_distributed = True
            self.embeddings.weight.split_axis = 1
            if self.world_size > 1:
                set_weight_attrs(self.embeddings.weight, {"output_dim": True})

        self.prefix = prefix
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

    def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
        """
        Load the checkpoint state dictionary into the layer.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """
        if self.tie_word_embeddings:
            # Do not pop: with tied word embeddings the weight may still be
            # consumed elsewhere (e.g. by the output projection).
            self.embeddings.weight.set_value(
                get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
            )
        else:
            self.embeddings.weight.set_value(
                get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
            )

    def forward(self, ids_remove_padding=None) -> paddle.Tensor:
        """
        Defines the forward computation of the layer.

        Args:
            ids_remove_padding (Tensor, optional): Tensor of token IDs with padding removed.
                If None, no input is provided.

        Returns:
            Tensor: Embedded tensor representation of the input IDs.
        """
        if self.column_cut:
            # Each rank embeds into its slice of the hidden dimension, then the
            # slices are all-gathered and concatenated along the last axis.
            input_embeddings = self.embeddings(ids_remove_padding)
            inputs_embeds_temp = []
            paddle.distributed.all_gather(
                inputs_embeds_temp,
                input_embeddings,
                group=self.tp_group,
                sync_op=True,
            )
            input_embeddings = paddle.concat(inputs_embeds_temp, -1)
        else:
            input_embeddings = self.embeddings(ids_remove_padding)

        return input_embeddings
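
Below is a minimal usage sketch, not part of the file above. It assumes the hybrid communication group has already been initialized via fleet, that fd_config is an FDConfig whose parallel_config and model_config carry the fields read in __init__, and that state_dict contains a "<prefix>.weight" entry; the prefix, vocabulary size, and hidden size are illustrative values, not values taken from the file.

embed_tokens = VocabParallelEmbedding(
    fd_config=fd_config,              # assumed: prepared by the caller
    num_embeddings=32000,             # illustrative vocabulary size
    embedding_dim=4096,               # illustrative hidden size
    params_dtype="bfloat16",
    prefix="model.embed_tokens",      # must match the checkpoint key prefix
)
embed_tokens.load_state_dict(state_dict)                    # loads "<prefix>.weight"
hidden_states = embed_tokens(ids_remove_padding=token_ids)  # [num_tokens, embedding_dim]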