FastDeploy/fastdeploy/model_executor/layers/embeddings.py

"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from dataclasses import dataclass
from typing import Dict

import numpy as np
import paddle
from paddle import nn
from paddle.distributed import fleet

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.utils import set_weight_attrs, slice_fn

from .utils import (
    DEFAULT_VOCAB_PADDING_SIZE,
    get_tensor,
    pad_vocab_size,
    vocab_range_from_global_vocab_size,
)


@dataclass
class VocabParallelEmbeddingShardIndices:
    """Indices for a shard of a vocab parallel embedding."""

    padded_org_vocab_start_index: int
    padded_org_vocab_end_index: int
    padded_added_vocab_start_index: int
    padded_added_vocab_end_index: int

    org_vocab_start_index: int
    org_vocab_end_index: int
    added_vocab_start_index: int
    added_vocab_end_index: int

    @property
    def num_org_elements(self) -> int:
        return self.org_vocab_end_index - self.org_vocab_start_index

    @property
    def num_added_elements(self) -> int:
        return self.added_vocab_end_index - self.added_vocab_start_index

    @property
    def num_org_elements_padded(self) -> int:
        return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index

    @property
    def num_added_elements_padded(self) -> int:
        return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index

    @property
    def num_org_vocab_padding(self) -> int:
        return self.num_org_elements_padded - self.num_org_elements

    @property
    def num_added_vocab_padding(self) -> int:
        return self.num_added_elements_padded - self.num_added_elements

    @property
    def num_elements_padded(self) -> int:
        return self.num_org_elements_padded + self.num_added_elements_padded

    def __post_init__(self):
        # sanity checks
        assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
        assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index

        assert self.org_vocab_start_index <= self.org_vocab_end_index
        assert self.added_vocab_start_index <= self.added_vocab_end_index

        assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
        assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
        assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
        assert self.added_vocab_end_index <= self.padded_added_vocab_end_index

        assert self.num_org_elements <= self.num_org_elements_padded
        assert self.num_added_elements <= self.num_added_elements_padded
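
# Worked sketch of the shard layout above (illustrative numbers, not taken
# from any real model): with org_vocab_size=1000, no added tokens,
# padding_size=64 and tp_size=2, the padded vocab is 1024 and each rank owns
# 512 padded rows; rank 1 covers padded rows [512, 1024), of which only
# [512, 1000) are real tokens, so on that rank num_org_elements == 488,
# num_org_elements_padded == 512 and num_org_vocab_padding == 24.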


class VocabParallelEmbedding(nn.Layer):
    """
    VocabParallelEmbedding Layer
    """

    def __init__(
        self,
        fd_config: FDConfig,
        num_embeddings: int,
        embedding_dim: int = 768,
        params_dtype: str = "bfloat16",
        prefix="",
        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
    ) -> None:
        """
        Initialize the VocabParallelEmbedding layer for the model.

        Args:
            fd_config (FDConfig): Inference configuration, providing the
                parallel settings (tensor_parallel_size, tensor_parallel_rank,
                tp_group) and model settings (hidden_dropout_prob,
                initializer_range, max_position_embeddings,
                tie_word_embeddings) used by this layer.
            num_embeddings (int): vocabulary size.
            embedding_dim (int): size of the hidden state.
            params_dtype (str): data type of the parameters.
            prefix (str): the name of the current layer. Defaults to "".
            padding_size (int): multiple to which the vocabulary is padded.
        """
        super().__init__()
        self.fd_config = fd_config
        hcg = fleet.get_hybrid_communicate_group()
        self.mp_rank: int = hcg.get_model_parallel_rank()
        self.column_cut = False
        self.world_size: int = fd_config.parallel_config.tensor_parallel_size
        self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
        self.tp_group = fd_config.parallel_config.tp_group
        self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
        self.initializer_range: float = fd_config.model_config.initializer_range
        self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
        self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
        self.params_dtype: str = params_dtype
        self.padding_size = padding_size

        self.org_vocab_size = num_embeddings
        self.num_embeddings = num_embeddings
        # Currently always 0, since org_vocab_size is set to num_embeddings above.
        num_added_embeddings = num_embeddings - self.org_vocab_size
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + num_added_embeddings, self.padding_size
        )
        assert self.org_vocab_size_padded <= self.num_embeddings_padded

        self.shard_indices = self._get_indices(
            self.num_embeddings_padded,
            self.org_vocab_size_padded,
            self.num_embeddings,
            self.org_vocab_size,
            self.tensor_parallel_rank,
            self.world_size,
        )
        if num_embeddings % self.world_size != 0:
            self.num_embeddings_padded = pad_vocab_size(num_embeddings, self.padding_size)

        if not self.column_cut:
            # row-parallel embedding: each rank holds a contiguous slice of rows
            self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
                self.num_embeddings_padded,
                embedding_dim,
                mp_group=self.tp_group,
                weight_attr=paddle.ParamAttr(
                    initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
                ),
            )
            if self.world_size > 1:
                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
            if num_embeddings % self.world_size != 0:
                set_weight_attrs(self.embeddings.weight, {"weight_loader": self.weight_loader})
        else:
            # column cut embedding: each rank holds a slice of the hidden dim
            self.embeddings = nn.Embedding(
                num_embeddings,
                embedding_dim // self.world_size,
            )
            self.embeddings.weight.is_distributed = True
            self.embeddings.weight.split_axis = 1
            if self.world_size > 1:
                set_weight_attrs(self.embeddings.weight, {"output_dim": True})

        self.prefix = prefix
        self.dropout = nn.Dropout(self.hidden_dropout_prob)
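
    # Padding arithmetic sketch (hypothetical numbers, assuming pad_vocab_size
    # rounds up to the next multiple of padding_size): num_embeddings=100 with
    # padding_size=64 gives org_vocab_size_padded == num_embeddings_padded == 128,
    # so with world_size=2 each rank's row-parallel table holds 64 rows.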

    def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
        """
        Load the checkpoint state dictionary into the layer.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """
        if self.tie_word_embeddings:
            weight_tensor = get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
        else:
            weight_tensor = get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
        self.embeddings.weight.set_value(weight_tensor)
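
    # Loading sketch (hypothetical prefix): with prefix="model.embed_tokens",
    # this reads state_dict["model.embed_tokens.weight"]. When
    # tie_word_embeddings is True the entry is read without pop so the tied
    # lm_head can reuse the same tensor; otherwise it is popped from the dict.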

    @classmethod
    def _get_indices(
        cls,
        vocab_size_padded: int,
        org_vocab_size_padded: int,
        vocab_size: int,
        org_vocab_size: int,
        tp_rank: int,
        tp_size: int,
    ) -> VocabParallelEmbeddingShardIndices:
        """Get start and end indices for vocab parallel embedding, following the
        layout outlined in the class docstring, based on the given tp_rank and
        tp_size."""
        num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
        padded_org_vocab_start_index, padded_org_vocab_end_index = vocab_range_from_global_vocab_size(
            org_vocab_size_padded, tp_rank, tp_size
        )
        padded_added_vocab_start_index, padded_added_vocab_end_index = vocab_range_from_global_vocab_size(
            num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
        )
        # remove padding
        org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
        org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
        added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
        added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
        return VocabParallelEmbeddingShardIndices(
            padded_org_vocab_start_index,
            padded_org_vocab_end_index,
            padded_added_vocab_start_index,
            padded_added_vocab_end_index,
            org_vocab_start_index,
            org_vocab_end_index,
            added_vocab_start_index,
            added_vocab_end_index,
        )
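
    # Usage sketch (hypothetical values, assuming vocab_range_from_global_vocab_size
    # splits [0, n) evenly across ranks):
    #   idx = VocabParallelEmbedding._get_indices(1024, 1024, 1000, 1000, 1, 2)
    #   (idx.org_vocab_start_index, idx.org_vocab_end_index)  # -> (512, 1000)
    #   idx.num_org_vocab_padding                              # -> 24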

    def weight_loader(self, param, loaded_weight, shard_id=None):
        """Load and shard a checkpoint embedding weight into this rank's param,
        zero-filling the rows (or columns) that correspond to vocab padding.
        shard_id is accepted for loader-interface compatibility and ignored."""
        output_dim = getattr(param, "output_dim", None)
        packed_dim = getattr(param, "packed_dim", None)

        if not param._is_initialized():
            param.initialize()
        loaded_weight = get_tensor(loaded_weight)
        if param.dtype != loaded_weight.dtype:
            loaded_weight = loaded_weight.cast(param.dtype)

        if output_dim is None:
            assert (
                param.shape == loaded_weight.shape
            ), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}"
            param.copy_(loaded_weight, False)
            return

        start_idx = self.shard_indices.org_vocab_start_index
        end_idx = self.shard_indices.org_vocab_end_index
        shard_size = end_idx - start_idx

        # If the param is packed on the same dim we are sharding on, the
        # offsets into the loaded weight must be adjusted by the pack factor.
        if packed_dim is not None and packed_dim == output_dim:
            packed_factor = getattr(param, "packed_factor", getattr(param, "pack_factor", 1))
            assert loaded_weight.shape[output_dim] == (self.org_vocab_size // packed_factor)
            start_idx = start_idx // packed_factor
            shard_size = shard_size // packed_factor
            end_idx = start_idx + shard_size
        else:
            assert loaded_weight.shape[output_dim] == self.org_vocab_size, (
                f"Loaded weight dim {output_dim} size {loaded_weight.shape[output_dim]} "
                f"!= org_vocab_size {self.org_vocab_size}"
            )

        # Copy this rank's shard, then zero-fill the trailing padding region.
        shard_weight = slice_fn(loaded_weight, output_dim, start_idx, end_idx)
        if output_dim == 0:
            param[: shard_weight.shape[0]].copy_(shard_weight, False)
            param[shard_weight.shape[0] :].fill_(0)
        else:
            param[:, : shard_weight.shape[1]].copy_(shard_weight, False)
            param[:, shard_weight.shape[1] :].fill_(0)
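
    # Effect sketch (hypothetical shapes): for a checkpoint weight of shape
    # [1000, H] with tp_size=2 and a per-rank padded param of shape [512, H],
    # rank 1 copies checkpoint rows 512..999 into param rows 0..487 and
    # zero-fills param rows 488..511, which are pure vocab padding.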

    def forward(self, ids_remove_padding=None) -> paddle.Tensor:
        """
        Defines the forward computation of the layer.

        Args:
            ids_remove_padding (Tensor, optional): Tensor of token IDs, with padding removed.
                If None, no input is provided.

        Returns:
            Tensor: Embedded tensor representation of the input IDs.
        """
        if self.column_cut:
            # Each rank embeds with its slice of the hidden dimension, then the
            # slices are gathered and concatenated to restore the full width.
            input_embeddings = self.embeddings(ids_remove_padding)
            inputs_embeds_temp = []
            paddle.distributed.all_gather(
                inputs_embeds_temp,
                input_embeddings,
                group=self.tp_group,
                sync_op=True,
            )
            input_embeddings = paddle.concat(inputs_embeds_temp, -1)
        else:
            input_embeddings = self.embeddings(ids_remove_padding)
        return input_embeddings
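
    # Column-cut sketch (hypothetical sizes): with embedding_dim=4096 and
    # world_size=2, each rank's nn.Embedding holds 2048 columns. all_gather
    # collects one [num_tokens, 2048] slice per rank, and the concat on the
    # last axis restores the full [num_tokens, 4096] embedding.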