FastDeploy/fastdeploy/model_executor/layers/embeddings.py
Commit a30a5b4216: [Model] tp+ep support v1_loader (#5600) (Longzhi Wang, 2025-12-18)

"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from dataclasses import dataclass
from typing import Dict
import numpy as np
import paddle
from paddle import nn
from paddle.distributed import fleet
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.utils import h2d_copy, set_weight_attrs, slice_fn
from fastdeploy.platforms import current_platform
from .utils import (
DEFAULT_VOCAB_PADDING_SIZE,
get_tensor,
pad_vocab_size,
vocab_range_from_global_vocab_size,
)
@dataclass
class VocabParallelEmbeddingShardIndices:
"""Indices for a shard of a vocab parallel embedding."""
padded_org_vocab_start_index: int
padded_org_vocab_end_index: int
padded_added_vocab_start_index: int
padded_added_vocab_end_index: int
org_vocab_start_index: int
org_vocab_end_index: int
added_vocab_start_index: int
added_vocab_end_index: int
@property
def num_org_elements(self) -> int:
return self.org_vocab_end_index - self.org_vocab_start_index
@property
def num_added_elements(self) -> int:
return self.added_vocab_end_index - self.added_vocab_start_index
@property
def num_org_elements_padded(self) -> int:
return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index
@property
def num_added_elements_padded(self) -> int:
return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index
@property
def num_org_vocab_padding(self) -> int:
return self.num_org_elements_padded - self.num_org_elements
@property
def num_added_vocab_padding(self) -> int:
return self.num_added_elements_padded - self.num_added_elements
@property
def num_elements_padded(self) -> int:
return self.num_org_elements_padded + self.num_added_elements_padded
def __post_init__(self):
# sanity checks
assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index
assert self.org_vocab_start_index <= self.org_vocab_end_index
assert self.added_vocab_start_index <= self.added_vocab_end_index
assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
assert self.num_org_elements <= self.num_org_elements_padded
assert self.num_added_elements <= self.num_added_elements_padded
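# Sketch of the bookkeeping above (hypothetical numbers, not from any real checkpoint): a rank
# whose padded original-vocab range is [0, 16000) and padded added-vocab range is [32000, 32032),
# with unpadded ranges [0, 16000) and [32000, 32008), reports num_org_elements = 16000,
# num_added_elements = 8, num_added_vocab_padding = 24, and num_elements_padded = 16032.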
class VocabParallelEmbedding(nn.Layer):
"""
VocabParallelEmbedding Layer
"""
def __init__(
self,
fd_config: FDConfig,
num_embeddings: int,
embedding_dim: int = 768,
params_dtype: str = "bfloat16",
prefix="",
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
org_num_embeddings: int | None = None,
general=False,
) -> None:
"""
Initialize the VocabParallelEmbedding layer for the model.
Args:
fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
num_embeddings (int) : vocabulary size.
embedding_dim (int) : size of hidden state.
params_dtype (str) : data type of parameters.
prefix (str): The name of current layer. Defaults to "".
"""
super().__init__()
self.fd_config = fd_config
hcg = fleet.get_hybrid_communicate_group()
self.mp_rank: int = hcg.get_model_parallel_rank()
self.column_cut = False
self.world_size: int = fd_config.parallel_config.tensor_parallel_size
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group
self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
self.initializer_range: float = fd_config.model_config.initializer_range
self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
self.params_dtype: str = params_dtype
self.embedding_dim = embedding_dim
self.general = general  # when True, skip vocabulary padding (the vocab is used as-is)
self.num_embeddings = num_embeddings
self.padding_size = padding_size
if self.general:
self.org_vocab_size = num_embeddings
self.num_embeddings_padded = num_embeddings
self.org_vocab_size_padded = num_embeddings
else:
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings, self.padding_size
)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(
self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size,
self.tensor_parallel_rank,
self.world_size,
)
if not self.column_cut:
self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
self.num_embeddings_padded,
embedding_dim,
mp_group=self.tp_group,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
),
)
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
set_weight_attrs(self.embeddings.weight, {"weight_loader": self.weight_loader})
else:
# column cut embedding
self.embeddings = nn.Embedding(
num_embeddings,
embedding_dim // self.world_size,
)
self.embeddings.weight.is_distributed = True
self.embeddings.weight.split_axis = 1
set_weight_attrs(self.embeddings.weight, {"output_dim": True})
self.prefix = prefix
self.dropout = nn.Dropout(self.hidden_dropout_prob)
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""
Load the checkpoint state dictionary into the layer.
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
if self.tie_word_embeddings and not self.general:
weight_tensor = get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
else:
weight_tensor = get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
self.embeddings.weight.set_value(weight_tensor)
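# Usage sketch (hypothetical prefix and weight, not a real checkpoint): with
# prefix="model.embed_tokens", this reads state_dict["model.embed_tokens.weight"], casts it to
# the default dtype, and copies it into self.embeddings.weight; when tie_word_embeddings is set
# (and general is False) the key is not popped, so a tied output head can reuse the same tensor.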
@classmethod
def _get_indices(
cls,
vocab_size_padded: int,
org_vocab_size_padded: int,
vocab_size: int,
org_vocab_size: int,
tp_rank: int,
tp_size: int,
) -> VocabParallelEmbeddingShardIndices:
"""Get start and end indices for vocab parallel embedding, following the
layout outlined in the class docstring, based on the given tp_rank and
tp_size."""
num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
padded_org_vocab_start_index, padded_org_vocab_end_index = vocab_range_from_global_vocab_size(
org_vocab_size_padded, tp_rank, tp_size
)
padded_added_vocab_start_index, padded_added_vocab_end_index = vocab_range_from_global_vocab_size(
num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
)
# remove padding
org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
return VocabParallelEmbeddingShardIndices(
padded_org_vocab_start_index,
padded_org_vocab_end_index,
padded_added_vocab_start_index,
padded_added_vocab_end_index,
org_vocab_start_index,
org_vocab_end_index,
added_vocab_start_index,
added_vocab_end_index,
)
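# Worked example (hypothetical sizes; assumes pad_vocab_size rounds up to a multiple of the
# padding size and vocab_range_from_global_vocab_size splits a padded range evenly per rank):
# with org_vocab_size=32000, 8 added tokens, padding_size=64, and tp_size=2, the padded sizes
# are org_vocab_size_padded=32000 and vocab_size_padded=32064, and _get_indices yields
#     rank 0: padded org [0, 16000), padded added [32000, 32032), unpadded added [32000, 32008)
#     rank 1: padded org [16000, 32000), padded added [32032, 32064), unpadded added [32008, 32008)
# i.e. the 8 real added tokens land on rank 0 while rank 1 only holds padding for them.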
def weight_loader(self, param, loaded_weight, shard_id=None):
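"""
Load a checkpoint weight into this rank's shard of the embedding parameter.
The loaded weight is cast to the parameter dtype if needed. If the parameter has no
output_dim attribute, the full tensor is copied directly; otherwise it is sliced along
output_dim to this rank's original-vocab range (adjusted by the pack factor for packed
parameters), copied into the front of the parameter, and the remaining padding entries
are zero-filled.
"""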
output_dim = getattr(param, "output_dim", None)
packed_dim = getattr(param, "packed_dim", None)
if not param._is_initialized():
param.initialize()
loaded_weight = get_tensor(loaded_weight)
if param.dtype != loaded_weight.dtype:
loaded_weight = loaded_weight.cast(param.dtype)
if output_dim is None:
assert (
param.shape == loaded_weight.shape
), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}"
param.copy_(loaded_weight, False)
return
start_idx = self.shard_indices.org_vocab_start_index
end_idx = self.shard_indices.org_vocab_end_index
shard_size = self.shard_indices.org_vocab_end_index - start_idx
# If param packed on the same dim we are sharding on, then
# need to adjust offsets of loaded weight by pack_factor.
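# (e.g. the shard's start/end indices and size are divided by the pack factor, so a
# hypothetical 2x-packed weight uses half-sized offsets).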
if packed_dim is not None and packed_dim == output_dim:
packed_factor = getattr(param, "packed_factor", getattr(param, "pack_factor", 1))
assert loaded_weight.shape[output_dim] == (self.org_vocab_size // packed_factor)
start_idx = start_idx // packed_factor
end_idx = end_idx // packed_factor
shard_size = shard_size // packed_factor
else:
assert loaded_weight.shape[output_dim] == self.org_vocab_size, (
f"Loaded weight dim {output_dim} size {loaded_weight.shape[output_dim]} "
f"!= org_vocab_size {self.org_vocab_size}"
)
shard_weight = slice_fn(loaded_weight, output_dim, start_idx, end_idx)
if output_dim == 0:
h2d_copy(param[: shard_weight.shape[0]], shard_weight)
if not current_platform.is_maca():
if param.shape[0] != shard_weight.shape[0]:
param[shard_weight.shape[0] :].fill_(0)
else:
h2d_copy(param[:, : shard_weight.shape[1]], shard_weight)
if param.shape[1] != shard_weight.shape[1]:
param[:, shard_weight.shape[1] :].fill_(0)
def forward(self, ids_remove_padding=None) -> paddle.Tensor:
"""
Defines the forward computation of the layer.
Args:
ids_remove_padding (Tensor): Tensor of token IDs with padding removed.
An empty tensor yields an empty embedding tensor.
Returns:
Tensor: Embedded tensor representation of the input IDs.
"""
if ids_remove_padding.shape[0] == 0:
return paddle.empty([0, self.embedding_dim], dtype=self.embeddings.weight.dtype)
if self.column_cut:
input_embeddings = self.embeddings(ids_remove_padding)
inputs_embeds_temp = []
paddle.distributed.all_gather(
inputs_embeds_temp,
input_embeddings,
group=self.tp_group,
sync_op=True,
)
input_embeddings = paddle.concat(inputs_embeds_temp, -1)
else:
input_embeddings = self.embeddings(ids_remove_padding)
return input_embeddings