mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Feature] support qwen3-embedding model load (#4202)
* support qwen3-embedding * fix ci bug * fix * fix ci bug * fix ci bug * fix
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
@@ -22,9 +23,73 @@ from paddle import nn
|
||||
from paddle.distributed import fleet
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.utils import set_weight_attrs
|
||||
from fastdeploy.model_executor.utils import set_weight_attrs, slice_fn
|
||||
|
||||
from .utils import get_tensor
|
||||
from .utils import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
get_tensor,
|
||||
pad_vocab_size,
|
||||
vocab_range_from_global_vocab_size,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocabParallelEmbeddingShardIndices:
|
||||
"""Indices for a shard of a vocab parallel embedding."""
|
||||
|
||||
padded_org_vocab_start_index: int
|
||||
padded_org_vocab_end_index: int
|
||||
padded_added_vocab_start_index: int
|
||||
padded_added_vocab_end_index: int
|
||||
|
||||
org_vocab_start_index: int
|
||||
org_vocab_end_index: int
|
||||
added_vocab_start_index: int
|
||||
added_vocab_end_index: int
|
||||
|
||||
@property
|
||||
def num_org_elements(self) -> int:
|
||||
return self.org_vocab_end_index - self.org_vocab_start_index
|
||||
|
||||
@property
|
||||
def num_added_elements(self) -> int:
|
||||
return self.added_vocab_end_index - self.added_vocab_start_index
|
||||
|
||||
@property
|
||||
def num_org_elements_padded(self) -> int:
|
||||
return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index
|
||||
|
||||
@property
|
||||
def num_added_elements_padded(self) -> int:
|
||||
return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index
|
||||
|
||||
@property
|
||||
def num_org_vocab_padding(self) -> int:
|
||||
return self.num_org_elements_padded - self.num_org_elements
|
||||
|
||||
@property
|
||||
def num_added_vocab_padding(self) -> int:
|
||||
return self.num_added_elements_padded - self.num_added_elements
|
||||
|
||||
@property
|
||||
def num_elements_padded(self) -> int:
|
||||
return self.num_org_elements_padded + self.num_added_elements_padded
|
||||
|
||||
def __post_init__(self):
|
||||
# sanity checks
|
||||
assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
|
||||
assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index
|
||||
|
||||
assert self.org_vocab_start_index <= self.org_vocab_end_index
|
||||
assert self.added_vocab_start_index <= self.added_vocab_end_index
|
||||
|
||||
assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
|
||||
assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
|
||||
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
|
||||
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
|
||||
|
||||
assert self.num_org_elements <= self.num_org_elements_padded
|
||||
assert self.num_added_elements <= self.num_added_elements_padded
|
||||
|
||||
|
||||
class VocabParallelEmbedding(nn.Layer):
|
||||
@@ -39,6 +104,7 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
embedding_dim: int = 768,
|
||||
params_dtype: str = "bfloat16",
|
||||
prefix="",
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the VocabParallelEmbedding layer for the model.
|
||||
@@ -65,10 +131,32 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
|
||||
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
|
||||
self.params_dtype: str = params_dtype
|
||||
self.padding_size = padding_size
|
||||
|
||||
self.org_vocab_size = num_embeddings
|
||||
self.num_embeddings = num_embeddings
|
||||
num_added_embeddings = num_embeddings - self.org_vocab_size
|
||||
|
||||
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
|
||||
self.num_embeddings_padded = pad_vocab_size(
|
||||
self.org_vocab_size_padded + num_added_embeddings, self.padding_size
|
||||
)
|
||||
assert self.org_vocab_size_padded <= self.num_embeddings_padded
|
||||
self.shard_indices = self._get_indices(
|
||||
self.num_embeddings_padded,
|
||||
self.org_vocab_size_padded,
|
||||
self.num_embeddings,
|
||||
self.org_vocab_size,
|
||||
self.tensor_parallel_rank,
|
||||
self.world_size,
|
||||
)
|
||||
|
||||
if num_embeddings % self.world_size != 0:
|
||||
self.num_embeddings_padded = pad_vocab_size(num_embeddings, self.padding_size)
|
||||
|
||||
if not self.column_cut:
|
||||
self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
|
||||
num_embeddings,
|
||||
self.num_embeddings_padded,
|
||||
embedding_dim,
|
||||
mp_group=self.tp_group,
|
||||
weight_attr=paddle.ParamAttr(
|
||||
@@ -76,7 +164,7 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
),
|
||||
)
|
||||
if self.world_size > 1:
|
||||
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
|
||||
set_weight_attrs(self.embeddings.weight, {"output_dim": False, "weight_loader": self.weight_loader})
|
||||
else:
|
||||
# column cut embedding
|
||||
self.embeddings = nn.Embedding(
|
||||
@@ -106,6 +194,88 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
|
||||
self.embeddings.weight.set_value(weight_tensor)
|
||||
|
||||
@classmethod
|
||||
def _get_indices(
|
||||
cls,
|
||||
vocab_size_paded: int,
|
||||
org_vocab_size_padded: int,
|
||||
vocab_size: int,
|
||||
org_vocab_size: int,
|
||||
tp_rank: int,
|
||||
tp_size: int,
|
||||
) -> VocabParallelEmbeddingShardIndices:
|
||||
"""Get start and end indices for vocab parallel embedding, following the
|
||||
layout outlined in the class docstring, based on the given tp_rank and
|
||||
tp_size."""
|
||||
|
||||
num_added_embeddings_padded = vocab_size_paded - org_vocab_size_padded
|
||||
padded_org_vocab_start_index, padded_org_vocab_end_index = vocab_range_from_global_vocab_size(
|
||||
org_vocab_size_padded, tp_rank, tp_size
|
||||
)
|
||||
|
||||
padded_added_vocab_start_index, padded_added_vocab_end_index = vocab_range_from_global_vocab_size(
|
||||
num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
|
||||
)
|
||||
# remove padding
|
||||
org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
|
||||
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
|
||||
added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
|
||||
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
|
||||
return VocabParallelEmbeddingShardIndices(
|
||||
padded_org_vocab_start_index,
|
||||
padded_org_vocab_end_index,
|
||||
padded_added_vocab_start_index,
|
||||
padded_added_vocab_end_index,
|
||||
org_vocab_start_index,
|
||||
org_vocab_end_index,
|
||||
added_vocab_start_index,
|
||||
added_vocab_end_index,
|
||||
)
|
||||
|
||||
def weight_loader(self, param, loaded_weight, shard_id=None):
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
if param.dtype != loaded_weight.dtype:
|
||||
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
|
||||
loaded_weight = loaded_weight.cast(param.dtype)
|
||||
else:
|
||||
loaded_weight = loaded_weight.cast(param.dtype)
|
||||
|
||||
if output_dim is None:
|
||||
assert (
|
||||
param.shape == loaded_weight.shape
|
||||
), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}"
|
||||
param.set_value(loaded_weight)
|
||||
return
|
||||
|
||||
start_idx = self.shard_indices.org_vocab_start_index
|
||||
end_idx = self.shard_indices.org_vocab_end_index
|
||||
shard_size = self.shard_indices.org_vocab_end_index - start_idx
|
||||
|
||||
# If param packed on the same dim we are sharding on, then
|
||||
# need to adjust offsets of loaded weight by pack_factor.
|
||||
if packed_dim is not None and packed_dim == output_dim:
|
||||
packed_factor = getattr(param, "packed_factor", getattr(param, "pack_factor", 1))
|
||||
assert loaded_weight.shape[output_dim] == (self.org_vocab_size // packed_factor)
|
||||
start_idx = start_idx // packed_factor
|
||||
shard_size = shard_size // packed_factor
|
||||
else:
|
||||
assert loaded_weight.shape[output_dim] == self.org_vocab_size, (
|
||||
f"Loaded weight dim {output_dim} size {loaded_weight.shape[output_dim]} "
|
||||
f"!= org_vocab_size {self.org_vocab_size}"
|
||||
)
|
||||
|
||||
shard_weight = slice_fn(loaded_weight, output_dim, start_idx, end_idx)
|
||||
|
||||
if output_dim == 0:
|
||||
param[: shard_weight.shape[0]].copy_(shard_weight, False)
|
||||
param[shard_weight.shape[0] :].fill_(0)
|
||||
else:
|
||||
param[:, : shard_weight.shape[1]].copy_(shard_weight, False)
|
||||
param[:, shard_weight.shape[1] :].fill_(0)
|
||||
|
||||
def forward(self, ids_remove_padding=None) -> paddle.Tensor:
|
||||
"""
|
||||
Defines the forward computation of the layer.
|
||||
|
Reference in New Issue
Block a user