[Feature] support qwen3-embedding model load (#4202)

* support qwen3-embedding

* fix ci bug

* fix

* fix ci bug

* fix ci bug

* fix
This commit is contained in:
lizexu123
2025-09-23 15:14:35 +08:00
committed by GitHub
parent 9082f625ba
commit c96a535a5d
5 changed files with 315 additions and 63 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""
from dataclasses import dataclass
from typing import Dict
import numpy as np
@@ -22,9 +23,73 @@ from paddle import nn
from paddle.distributed import fleet
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.utils import set_weight_attrs
from fastdeploy.model_executor.utils import set_weight_attrs, slice_fn
from .utils import get_tensor
from .utils import (
DEFAULT_VOCAB_PADDING_SIZE,
get_tensor,
pad_vocab_size,
vocab_range_from_global_vocab_size,
)
@dataclass
class VocabParallelEmbeddingShardIndices:
"""Indices for a shard of a vocab parallel embedding."""
padded_org_vocab_start_index: int
padded_org_vocab_end_index: int
padded_added_vocab_start_index: int
padded_added_vocab_end_index: int
org_vocab_start_index: int
org_vocab_end_index: int
added_vocab_start_index: int
added_vocab_end_index: int
@property
def num_org_elements(self) -> int:
return self.org_vocab_end_index - self.org_vocab_start_index
@property
def num_added_elements(self) -> int:
return self.added_vocab_end_index - self.added_vocab_start_index
@property
def num_org_elements_padded(self) -> int:
return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index
@property
def num_added_elements_padded(self) -> int:
return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index
@property
def num_org_vocab_padding(self) -> int:
return self.num_org_elements_padded - self.num_org_elements
@property
def num_added_vocab_padding(self) -> int:
return self.num_added_elements_padded - self.num_added_elements
@property
def num_elements_padded(self) -> int:
return self.num_org_elements_padded + self.num_added_elements_padded
def __post_init__(self):
# sanity checks
assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index
assert self.org_vocab_start_index <= self.org_vocab_end_index
assert self.added_vocab_start_index <= self.added_vocab_end_index
assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
assert self.num_org_elements <= self.num_org_elements_padded
assert self.num_added_elements <= self.num_added_elements_padded
class VocabParallelEmbedding(nn.Layer):
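The dataclass above records both the padded and the clamped ("real") per-rank boundaries; the arithmetic that produces them lives in `_get_indices` further down in this file. A minimal standalone sketch of that arithmetic, with small hypothetical sizes, no added vocabulary, and no FastDeploy imports:

```python
# Standalone sketch of the shard-index arithmetic (hypothetical sizes, no added vocab).
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # Round up to the next multiple of pad_to, as in layers/utils.py.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

org_vocab_size = 100                     # pretend vocabulary size
world_size = 2                           # tensor_parallel_size
padded = pad_vocab_size(org_vocab_size)  # -> 128
per_rank = padded // world_size          # -> 64 rows per rank

for rank in range(world_size):
    padded_start, padded_end = rank * per_rank, (rank + 1) * per_rank
    # "remove padding": clamp the range to the real vocabulary
    start, end = min(padded_start, org_vocab_size), min(padded_end, org_vocab_size)
    print(rank, (padded_start, padded_end), (start, end))
# rank 0: padded (0, 64),   real (0, 64)
# rank 1: padded (64, 128), real (64, 100) -> 28 padding rows on the last rank
```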
@@ -39,6 +104,7 @@ class VocabParallelEmbedding(nn.Layer):
embedding_dim: int = 768,
params_dtype: str = "bfloat16",
prefix="",
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
) -> None:
"""
Initialize the VocabParallelEmbedding layer for the model.
@@ -65,10 +131,32 @@ class VocabParallelEmbedding(nn.Layer):
self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
self.params_dtype: str = params_dtype
self.padding_size = padding_size
self.org_vocab_size = num_embeddings
self.num_embeddings = num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings, self.padding_size
)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(
self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size,
self.tensor_parallel_rank,
self.world_size,
)
if num_embeddings % self.world_size != 0:
self.num_embeddings_padded = pad_vocab_size(num_embeddings, self.padding_size)
if not self.column_cut:
self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
num_embeddings,
self.num_embeddings_padded,
embedding_dim,
mp_group=self.tp_group,
weight_attr=paddle.ParamAttr(
@@ -76,7 +164,7 @@ class VocabParallelEmbedding(nn.Layer):
),
)
if self.world_size > 1:
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
set_weight_attrs(self.embeddings.weight, {"output_dim": False, "weight_loader": self.weight_loader})
else:
# column cut embedding
self.embeddings = nn.Embedding(
@@ -106,6 +194,88 @@ class VocabParallelEmbedding(nn.Layer):
self.embeddings.weight.set_value(weight_tensor)
@classmethod
def _get_indices(
cls,
vocab_size_padded: int,
org_vocab_size_padded: int,
vocab_size: int,
org_vocab_size: int,
tp_rank: int,
tp_size: int,
) -> VocabParallelEmbeddingShardIndices:
"""Get start and end indices for vocab parallel embedding, following the
layout outlined in the class docstring, based on the given tp_rank and
tp_size."""
num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
padded_org_vocab_start_index, padded_org_vocab_end_index = vocab_range_from_global_vocab_size(
org_vocab_size_padded, tp_rank, tp_size
)
padded_added_vocab_start_index, padded_added_vocab_end_index = vocab_range_from_global_vocab_size(
num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
)
# remove padding
org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
return VocabParallelEmbeddingShardIndices(
padded_org_vocab_start_index,
padded_org_vocab_end_index,
padded_added_vocab_start_index,
padded_added_vocab_end_index,
org_vocab_start_index,
org_vocab_end_index,
added_vocab_start_index,
added_vocab_end_index,
)
def weight_loader(self, param, loaded_weight, shard_id=None):
output_dim = getattr(param, "output_dim", None)
packed_dim = getattr(param, "packed_dim", None)
loaded_weight = get_tensor(loaded_weight)
if param.dtype != loaded_weight.dtype:
loaded_weight = loaded_weight.cast(param.dtype)
if output_dim is None:
assert (
param.shape == loaded_weight.shape
), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}"
param.set_value(loaded_weight)
return
start_idx = self.shard_indices.org_vocab_start_index
end_idx = self.shard_indices.org_vocab_end_index
shard_size = self.shard_indices.org_vocab_end_index - start_idx
# If param packed on the same dim we are sharding on, then
# need to adjust offsets of loaded weight by pack_factor.
if packed_dim is not None and packed_dim == output_dim:
packed_factor = getattr(param, "packed_factor", getattr(param, "pack_factor", 1))
assert loaded_weight.shape[output_dim] == (self.org_vocab_size // packed_factor)
start_idx = start_idx // packed_factor
shard_size = shard_size // packed_factor
else:
assert loaded_weight.shape[output_dim] == self.org_vocab_size, (
f"Loaded weight dim {output_dim} size {loaded_weight.shape[output_dim]} "
f"!= org_vocab_size {self.org_vocab_size}"
)
shard_weight = slice_fn(loaded_weight, output_dim, start_idx, end_idx)
if output_dim == 0:
param[: shard_weight.shape[0]].copy_(shard_weight, False)
param[shard_weight.shape[0] :].fill_(0)
else:
param[:, : shard_weight.shape[1]].copy_(shard_weight, False)
param[:, shard_weight.shape[1] :].fill_(0)
def forward(self, ids_remove_padding=None) -> paddle.Tensor:
"""
Defines the forward computation of the layer.
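The `weight_loader` added above slices the checkpoint tensor along `output_dim`, copies the rank-local shard into the (possibly larger, padded) parameter, and zero-fills the tail. A rough NumPy-only illustration of that copy-and-zero pattern with toy shapes (not the FastDeploy API):

```python
import numpy as np

# Toy version of the row-wise shard copy in weight_loader (output_dim == 0).
org_vocab, hidden, world_size = 100, 8, 2
full_weight = np.arange(org_vocab * hidden, dtype="float32").reshape(org_vocab, hidden)

padded_vocab = 128                       # pad_vocab_size(100, 64)
per_rank = padded_vocab // world_size    # 64 rows held by each rank

for rank in range(world_size):
    start = rank * per_rank
    end = min(start + per_rank, org_vocab)   # clamp off the padding rows
    shard = full_weight[start:end]           # like slice_fn(loaded_weight, 0, start, end)
    param = np.empty((per_rank, hidden), dtype="float32")
    param[: shard.shape[0]] = shard          # param[:shard.shape[0]].copy_(shard, False)
    param[shard.shape[0]:] = 0.0             # param[shard.shape[0]:].fill_(0)
    print(rank, shard.shape, param.shape)    # (64, 8) / (36, 8) shards, (64, 8) params
```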

View File

@@ -22,6 +22,10 @@ from paddle import nn
from paddle.distributed import fleet
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.utils import (
DEFAULT_VOCAB_PADDING_SIZE,
pad_vocab_size,
)
from fastdeploy.model_executor.utils import (
default_weight_loader,
set_weight_attrs,
@@ -44,6 +48,7 @@ class ParallelLMHead(nn.Layer):
prefix: str = "",
with_bias: bool = False,
dtype: str = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
) -> None:
"""
Parallelized LM head.
@@ -68,6 +73,10 @@ class ParallelLMHead(nn.Layer):
self.column_cut = True
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.fd_config = fd_config
self.padding_size = padding_size
if num_embeddings % self.nranks != 0:
num_embeddings = pad_vocab_size(num_embeddings, self.padding_size)
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
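The only functional change to `ParallelLMHead` is the new padding branch: when the vocabulary does not divide evenly across ranks, it is rounded up before the column/row-parallel linears are built. A quick sketch of that branch with an assumed odd vocab size:

```python
# Hypothetical numbers; padding keeps the per-rank column count integral.
nranks, num_embeddings, padding_size = 2, 151_669, 64
if num_embeddings % nranks != 0:
    num_embeddings = ((num_embeddings + padding_size - 1) // padding_size) * padding_size
print(num_embeddings, num_embeddings // nranks)  # 151680 75840 -> 75840 columns per rank
```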

View File

@@ -0,0 +1,15 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

View File

@@ -45,6 +45,14 @@ if cache_params != "none":
c8_state_dict = paddle.load(cache_params, return_numpy=True)
DEFAULT_VOCAB_PADDING_SIZE = 64
def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]:
"""
Only used in deep_gemm block wise quant weight.
@@ -372,3 +380,14 @@ def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str])
paddle.Tensor: An empty tensor with the specified shape and data type.
"""
return paddle.empty(list(shape), dtype=dtype)
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, rank: int, offset: int = 0):
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f + offset, index_l + offset
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int, offset: int = 0):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, offset=offset)
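With this commit applied, the new helpers can be exercised directly; the `offset` argument is what `_get_indices` uses to place the added-vocab range after the original vocabulary. A small usage sketch (import path taken from the diff, numbers illustrative):

```python
from fastdeploy.model_executor.layers.utils import (
    pad_vocab_size,
    vocab_range_from_global_vocab_size,
)

padded = pad_vocab_size(151_669)  # -> 151680 with the default padding of 64
# Per-rank ranges over the padded original vocabulary (tp_size = 2):
print(vocab_range_from_global_vocab_size(padded, rank=0, world_size=2))  # (0, 75840)
print(vocab_range_from_global_vocab_size(padded, rank=1, world_size=2))  # (75840, 151680)
# Added-vocab ranges are shifted past the original vocab via offset:
print(vocab_range_from_global_vocab_size(128, rank=0, world_size=2, offset=151_669))  # (151669, 151733)
```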

View File

@@ -27,7 +27,9 @@ from fastdeploy.config import (
ModelConfig,
ParallelConfig,
)
from fastdeploy.model_executor.models.adapters import as_embedding_model
from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.scheduler import SchedulerConfig
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, ".."))
@@ -36,58 +38,103 @@ if project_root not in sys.path:
from tests.model_loader.utils import get_torch_model_path
test_model_configs = {
"Qwen3-0.6B": {
"tensor_parallel_size": 2,
"max_model_len": 8192,
"baseline_suffix": "standard",
},
"Qwen3-Embedding-0.6B": {
"tensor_parallel_size": 2,
"max_model_len": 8192,
"baseline_suffix": "embedding",
},
}
class TestModelLoader:
@pytest.fixture(scope="session", autouse=True)
def setup_paddle(self):
if not paddle.is_compiled_with_cuda():
print("CUDA not available, using CPU")
paddle.set_device("cpu")
else:
print("Using CUDA device")
raise AssertionError("CUDA not available")
paddle.set_device("gpu")
yield
@pytest.fixture(scope="session")
def model_path(self):
@pytest.fixture(scope="session", params=list(test_model_configs.keys()))
def model_info(self, request):
model_name = request.param
try:
torch_model_path = get_torch_model_path("Qwen3-0.6B")
if os.path.exists(torch_model_path):
return torch_model_path
torch_model_path = get_torch_model_path(model_name)
if not os.path.exists(torch_model_path):
raise AssertionError(f"Model path does not exist: {torch_model_path}")
return {"name": model_name, "path": torch_model_path, "config": test_model_configs[model_name]}
except Exception as e:
print(f"Could not get torch model path: {e}")
raise AssertionError(f"Could not get torch model path for {model_name}: {e}")
@pytest.fixture
def model_config(self, model_path):
def model_config(self, model_info):
if model_info is None:
raise AssertionError("model_info is None")
model_args = {
"model": model_path,
"model": model_info["path"],
"dtype": "bfloat16",
"max_model_len": 8192,
"tensor_parallel_size": 1,
"max_model_len": model_info["config"]["max_model_len"],
"tensor_parallel_size": model_info["config"]["tensor_parallel_size"],
"runner": "auto",
"convert": "auto",
}
try:
return ModelConfig(model_args)
config = ModelConfig(model_args)
return config
except Exception as e:
print(f"Could not create ModelConfig: {e}")
raise AssertionError(f"Could not create ModelConfig: {e}")
@pytest.fixture
def fd_config(self, model_config):
def scheduler_config(self):
scheduler_args = {
"name": "local",
"max_num_seqs": 256,
"max_num_batched_tokens": 8192,
"splitwise_role": "mixed",
"max_size": -1,
"ttl": 900,
"max_model_len": 8192,
"enable_chunked_prefill": False,
"max_num_partial_prefills": 1,
"max_long_partial_prefills": 1,
"long_prefill_token_threshold": 0,
}
try:
config = SchedulerConfig(scheduler_args)
return config
except Exception as e:
raise AssertionError(f"Could not create SchedulerConfig: {e}")
@pytest.fixture
def fd_config(self, model_info, model_config, scheduler_config):
if model_config is None:
raise AssertionError("ModelConfig is None")
if scheduler_config is None:
raise AssertionError("SchedulerConfig is None")
try:
tensor_parallel_size = model_info["config"]["tensor_parallel_size"]
cache_args = {
"block_size": 64,
"gpu_memory_utilization": 0.9,
"cache_dtype": "bfloat16",
"model_cfg": model_config,
"tensor_parallel_size": 1,
"tensor_parallel_size": tensor_parallel_size,
}
cache_config = CacheConfig(cache_args)
parallel_args = {
"tensor_parallel_size": 1,
"tensor_parallel_size": tensor_parallel_size,
"data_parallel_size": 1,
}
parallel_config = ParallelConfig(parallel_args)
@@ -95,88 +142,80 @@ class TestModelLoader:
load_args = {}
load_config = LoadConfig(load_args)
graph_opt_args = {
"enable_cudagraph": False,
"cudagraph_capture_sizes": None,
}
graph_opt_args = {}
graph_opt_config = GraphOptimizationConfig(graph_opt_args)
return FDConfig(
fd_config = FDConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
load_config=load_config,
graph_opt_config=graph_opt_config,
test_mode=True,
)
return fd_config
except Exception as e:
print(f"Could not create FDConfig: {e}")
raise AssertionError(f"Could not create FDConfig: {e}")
@pytest.fixture
def model_json_config(self, model_path):
config_path = os.path.join(model_path, "config.json")
if os.path.exists(config_path):
def model_json_config(self, model_info):
if model_info is None:
raise AssertionError("model_info is None")
config_path = os.path.join(model_info["path"], "config.json")
if not os.path.exists(config_path):
raise AssertionError(f"Config file does not exist: {config_path}")
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
return None
def test_embedding_with_none_convert_type(self, fd_config, model_json_config):
if model_json_config is None:
pytest.skip("Model config not available")
if fd_config is None:
pytest.skip("FDConfig not available")
print("=" * 60)
print("Testing initialize_model with convert_type='none'")
print("=" * 60)
def test_embedding_with_none_convert_type(self, model_info, fd_config, model_json_config):
if any(x is None for x in [model_info, fd_config, model_json_config]):
raise AssertionError("Required configs not available")
architectures = model_json_config.get("architectures", [])
if not architectures:
pytest.skip("No architectures found in model config")
raise AssertionError("No architectures found in model config")
fd_config.model_config.convert_type = "none"
try:
model_cls = ModelRegistry.get_class(architectures)
model_cls = ModelRegistry.get_class(architectures[0])
if hasattr(model_cls, "__name__"):
assert (
"ForEmbedding" not in model_cls.__name__
), f"Standard model should not have 'ForEmbedding' in name, but got: {model_cls.__name__}"
print(f"Confirmed standard model type (no ForEmbedding): {model_cls.__name__}")
standard_methods = set(dir(model_cls))
assert "_init_pooler" not in standard_methods, "Standard model should not have _init_pooler method"
except Exception as e:
print(f"Error in none: {e}")
raise AssertionError(f"Error in none convert type test: {e}")
def test_embedding_with_embed_convert_type(self, fd_config, model_json_config):
if model_json_config is None:
pytest.skip("Model config not available")
if fd_config is None:
pytest.skip("FDConfig not available")
print("=" * 60)
print("Testing embedding with convert_type='embed'")
print("=" * 60)
def test_embedding_with_embed_convert_type(self, model_info, fd_config, model_json_config):
if any(x is None for x in [model_info, fd_config, model_json_config]):
raise AssertionError("Required configs not available")
architectures = model_json_config.get("architectures", [])
if not architectures:
pytest.skip("No architectures found in model config")
raise AssertionError("No architectures found in model config")
fd_config.model_config.convert_type = "embed"
try:
model_cls = ModelRegistry.get_class(architectures)
model_cls = ModelRegistry.get_class(architectures[0])
model_cls = as_embedding_model(model_cls)
if hasattr(model_cls, "__name__"):
assert "ForEmbedding" in model_cls.__name__, "Embedding model should have 'ForEmbedding' in name"
print(f"Confirmed embedding model type: {model_cls.__name__}")
assert (
"ForEmbedding" in model_cls.__name__
), f"Embedding model should have 'ForEmbedding' in name, but got: {model_cls.__name__}"
embedding_methods = set(dir(model_cls))
assert "_init_pooler" in embedding_methods, "Embedding model should have _init_pooler method"
except Exception as e:
print(f"Error in convert embed: {e}")
raise AssertionError(f"Error in embed convert type test: {e}")