Files
FastDeploy/fastdeploy/model_executor/layers/embeddings.py
bukejiyu 77514e3e1e
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
[V1 Loader] support weight_only (#3413)
* support wint4/wint8

* delete smoe case

* update ci

* print log
2025-08-23 13:13:41 +08:00

145 lines
5.4 KiB
Python

"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Dict
import numpy as np
import paddle
from paddle import nn
from paddle.distributed import fleet
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.utils import set_weight_attrs
from .utils import get_tensor
class VocabParallelEmbedding(nn.Layer):
"""
VocabParallelEmbedding Layer
"""
def __init__(
self,
fd_config: FDConfig,
num_embeddings: int,
embedding_dim: int = 768,
params_dtype: str = "bfloat16",
prefix="",
) -> None:
"""
Initialize the VocabParallelEmbedding layer for the model.
Args:
fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
num_embeddings (int) : vocabulary size.
embedding_dim (int) : size of hidden state.
params_dtype (str) : data type of parameters.
prefix (str): The name of current layer. Defaults to "".
"""
super().__init__()
self.fd_config = fd_config
hcg = fleet.get_hybrid_communicate_group()
self.mp_rank: int = hcg.get_model_parallel_rank()
self.column_cut = False
self.world_size: int = hcg.get_model_parallel_world_size()
self.ring_id: int = hcg.get_model_parallel_group().id
self.use_ep: bool = fd_config.parallel_config.use_ep
self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
self.initializer_range: float = fd_config.model_config.initializer_range
self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
self.params_dtype: str = params_dtype
if self.use_ep:
self.embeddings = nn.Embedding(
num_embeddings,
embedding_dim,
)
else:
if not self.column_cut:
self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
num_embeddings,
embedding_dim,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
),
)
if self.world_size > 1:
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
else:
# column cut embedding
self.embeddings = nn.Embedding(
num_embeddings,
embedding_dim // self.world_size,
)
self.embeddings.weight.is_distributed = True
self.embeddings.weight.split_axis = 1
if self.world_size > 1:
set_weight_attrs(self.embeddings.weight, {"output_dim": True})
self.prefix = prefix
self.dropout = nn.Dropout(self.hidden_dropout_prob)
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""
Load the checkpoint state dictionary into the layer.
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
if self.tie_word_embeddings:
self.embeddings.weight.set_value(
get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
)
else:
self.embeddings.weight.set_value(
get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
)
def forward(self, ids_remove_padding=None) -> paddle.Tensor:
"""
Defines the forward computation of the layer.
Args:
ids_remove_padding (Tensor, optional): Tensor of token IDs, with padding removed.
If None, no input is provided.
Returns:
Tensor: Embedded tensor representation of the input IDs.
"""
if self.use_ep:
input_embedings = self.embeddings(ids_remove_padding)
else:
if self.column_cut:
input_embedings = self.embeddings(ids_remove_padding)
inputs_embeds_temp = []
paddle.distributed.all_gather(
inputs_embeds_temp,
input_embedings,
group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
sync_op=True,
)
input_embedings = paddle.concat(inputs_embeds_temp, -1)
else:
input_embedings = self.embeddings(ids_remove_padding)
return input_embedings