[Feature] support qwen3-embedding model load (#4202)

* support qwen3-embedding

* fix ci bug

* fix

* fix ci bug

* fix ci bug

* fix
lizexu123
2025-09-23 15:14:35 +08:00
committed by GitHub
parent 9082f625ba
commit c96a535a5d
5 changed files with 315 additions and 63 deletions


@@ -14,6 +14,7 @@
# limitations under the License.
"""
from dataclasses import dataclass
from typing import Dict
import numpy as np
@@ -22,9 +23,73 @@ from paddle import nn
from paddle.distributed import fleet
from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.model_executor.utils import set_weight_attrs, slice_fn

-from .utils import get_tensor
+from .utils import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    get_tensor,
+    pad_vocab_size,
+    vocab_range_from_global_vocab_size,
+)
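
# Assumed helper semantics (illustrative note, not from this diff): pad_vocab_size is
# taken to round a vocab size up to the next multiple of padding_size, e.g.
# pad_vocab_size(100, 64) == 128, and vocab_range_from_global_vocab_size to split that
# padded range evenly across tensor-parallel ranks; the sketches further down rely on
# this reading.
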
@dataclass
class VocabParallelEmbeddingShardIndices:
    """Indices for a shard of a vocab parallel embedding."""

    padded_org_vocab_start_index: int
    padded_org_vocab_end_index: int
    padded_added_vocab_start_index: int
    padded_added_vocab_end_index: int

    org_vocab_start_index: int
    org_vocab_end_index: int
    added_vocab_start_index: int
    added_vocab_end_index: int

    @property
    def num_org_elements(self) -> int:
        return self.org_vocab_end_index - self.org_vocab_start_index

    @property
    def num_added_elements(self) -> int:
        return self.added_vocab_end_index - self.added_vocab_start_index

    @property
    def num_org_elements_padded(self) -> int:
        return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index

    @property
    def num_added_elements_padded(self) -> int:
        return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index

    @property
    def num_org_vocab_padding(self) -> int:
        return self.num_org_elements_padded - self.num_org_elements

    @property
    def num_added_vocab_padding(self) -> int:
        return self.num_added_elements_padded - self.num_added_elements

    @property
    def num_elements_padded(self) -> int:
        return self.num_org_elements_padded + self.num_added_elements_padded

    def __post_init__(self):
        # sanity checks
        assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
        assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index

        assert self.org_vocab_start_index <= self.org_vocab_end_index
        assert self.added_vocab_start_index <= self.added_vocab_end_index

        assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
        assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
        assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
        assert self.added_vocab_end_index <= self.padded_added_vocab_end_index

        assert self.num_org_elements <= self.num_org_elements_padded
        assert self.num_added_elements <= self.num_added_elements_padded
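
# Illustrative sketch (hypothetical numbers, not part of this diff): for a 100-token
# vocabulary padded to a multiple of 64 and split across tp_size = 2 ranks, the rank-1
# shard carries padded rows [64, 128) but only real rows [64, 100):
#
#   idx = VocabParallelEmbeddingShardIndices(64, 128, 100, 100, 64, 100, 100, 100)
#   idx.num_org_elements         # 36 real embedding rows on this rank
#   idx.num_org_elements_padded  # 64 rows including padding
#   idx.num_org_vocab_padding    # 28 trailing rows that hold no real token
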
class VocabParallelEmbedding(nn.Layer):
@@ -39,6 +104,7 @@ class VocabParallelEmbedding(nn.Layer):
        embedding_dim: int = 768,
        params_dtype: str = "bfloat16",
        prefix="",
        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
    ) -> None:
        """
        Initialize the VocabParallelEmbedding layer for the model.
@@ -65,10 +131,32 @@ class VocabParallelEmbedding(nn.Layer):
        self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
        self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
        self.params_dtype: str = params_dtype

        self.padding_size = padding_size
        self.org_vocab_size = num_embeddings
        self.num_embeddings = num_embeddings
        num_added_embeddings = num_embeddings - self.org_vocab_size
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + num_added_embeddings, self.padding_size
        )
        assert self.org_vocab_size_padded <= self.num_embeddings_padded

        self.shard_indices = self._get_indices(
            self.num_embeddings_padded,
            self.org_vocab_size_padded,
            self.num_embeddings,
            self.org_vocab_size,
            self.tensor_parallel_rank,
            self.world_size,
        )
        if num_embeddings % self.world_size != 0:
            self.num_embeddings_padded = pad_vocab_size(num_embeddings, self.padding_size)

        if not self.column_cut:
            self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
-               num_embeddings,
+               self.num_embeddings_padded,
                embedding_dim,
                mp_group=self.tp_group,
                weight_attr=paddle.ParamAttr(
@@ -76,7 +164,7 @@ class VocabParallelEmbedding(nn.Layer):
                ),
            )
            if self.world_size > 1:
-               set_weight_attrs(self.embeddings.weight, {"output_dim": False})
+               set_weight_attrs(self.embeddings.weight, {"output_dim": False, "weight_loader": self.weight_loader})
        else:
            # column cut embedding
            self.embeddings = nn.Embedding(
@@ -106,6 +194,88 @@ class VocabParallelEmbedding(nn.Layer):
        self.embeddings.weight.set_value(weight_tensor)

    @classmethod
    def _get_indices(
        cls,
        vocab_size_padded: int,
        org_vocab_size_padded: int,
        vocab_size: int,
        org_vocab_size: int,
        tp_rank: int,
        tp_size: int,
    ) -> VocabParallelEmbeddingShardIndices:
        """Get start and end indices for vocab parallel embedding, following the
        layout outlined in the class docstring, based on the given tp_rank and
        tp_size."""
        num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
        padded_org_vocab_start_index, padded_org_vocab_end_index = vocab_range_from_global_vocab_size(
            org_vocab_size_padded, tp_rank, tp_size
        )
        padded_added_vocab_start_index, padded_added_vocab_end_index = vocab_range_from_global_vocab_size(
            num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
        )
        # remove padding
        org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
        org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
        added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
        added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
        return VocabParallelEmbeddingShardIndices(
            padded_org_vocab_start_index,
            padded_org_vocab_end_index,
            padded_added_vocab_start_index,
            padded_added_vocab_end_index,
            org_vocab_start_index,
            org_vocab_end_index,
            added_vocab_start_index,
            added_vocab_end_index,
        )
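
    # Minimal sketch (hypothetical numbers, not part of this diff) of what _get_indices
    # returns for the 100-token, padding-64, tp_size = 2 example, assuming
    # vocab_range_from_global_vocab_size splits the padded range evenly across ranks:
    #
    #   VocabParallelEmbedding._get_indices(128, 128, 100, 100, tp_rank=1, tp_size=2)
    #   # padded org rows [64, 128), real org rows [64, 100); the added-vocab ranges
    #   # collapse to the empty interval at offset 100 since no tokens were added.
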
    def weight_loader(self, param, loaded_weight, shard_id=None):
        output_dim = getattr(param, "output_dim", None)
        packed_dim = getattr(param, "packed_dim", None)

        loaded_weight = get_tensor(loaded_weight)
        if param.dtype != loaded_weight.dtype:
            loaded_weight = loaded_weight.cast(param.dtype)

        if output_dim is None:
            assert (
                param.shape == loaded_weight.shape
            ), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}"
            param.set_value(loaded_weight)
            return

        start_idx = self.shard_indices.org_vocab_start_index
        end_idx = self.shard_indices.org_vocab_end_index
        shard_size = end_idx - start_idx

        # If the param is packed on the same dim we are sharding on, the offsets
        # into the loaded weight must be adjusted by the pack factor.
        if packed_dim is not None and packed_dim == output_dim:
            packed_factor = getattr(param, "packed_factor", getattr(param, "pack_factor", 1))
            assert loaded_weight.shape[output_dim] == (self.org_vocab_size // packed_factor)
            start_idx = start_idx // packed_factor
            shard_size = shard_size // packed_factor
        else:
            assert loaded_weight.shape[output_dim] == self.org_vocab_size, (
                f"Loaded weight dim {output_dim} size {loaded_weight.shape[output_dim]} "
                f"!= org_vocab_size {self.org_vocab_size}"
            )

        shard_weight = slice_fn(loaded_weight, output_dim, start_idx, end_idx)
        if output_dim == 0:
            param[: shard_weight.shape[0]].copy_(shard_weight, False)
            param[shard_weight.shape[0] :].fill_(0)
        else:
            param[:, : shard_weight.shape[1]].copy_(shard_weight, False)
            param[:, shard_weight.shape[1] :].fill_(0)
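
    # Sketch of the resulting load for rank 1 in the running example (hypothetical,
    # not part of this diff), assuming output_dim == 0, no packing, and a full
    # checkpoint tensor of shape [100, embedding_dim]:
    #
    #   start_idx, end_idx = 64, 100
    #   shard = slice_fn(full_weight, 0, 64, 100)  # the 36 real rows for this rank
    #   param[:36]  <- shard                       # copied into the local weight
    #   param[36:]  <- 0                           # 28 padding rows are zero-filled
    #
    # so ids that fall into the padded tail never contribute stale values.
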
    def forward(self, ids_remove_padding=None) -> paddle.Tensor:
        """
        Defines the forward computation of the layer.