diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index 43fbd76a8..7b97c53e5 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -14,6 +14,7 @@ # limitations under the License. """ +from dataclasses import dataclass from typing import Dict import numpy as np @@ -22,9 +23,73 @@ from paddle import nn from paddle.distributed import fleet from fastdeploy.config import FDConfig -from fastdeploy.model_executor.utils import set_weight_attrs +from fastdeploy.model_executor.utils import set_weight_attrs, slice_fn -from .utils import get_tensor +from .utils import ( + DEFAULT_VOCAB_PADDING_SIZE, + get_tensor, + pad_vocab_size, + vocab_range_from_global_vocab_size, +) + + +@dataclass +class VocabParallelEmbeddingShardIndices: + """Indices for a shard of a vocab parallel embedding.""" + + padded_org_vocab_start_index: int + padded_org_vocab_end_index: int + padded_added_vocab_start_index: int + padded_added_vocab_end_index: int + + org_vocab_start_index: int + org_vocab_end_index: int + added_vocab_start_index: int + added_vocab_end_index: int + + @property + def num_org_elements(self) -> int: + return self.org_vocab_end_index - self.org_vocab_start_index + + @property + def num_added_elements(self) -> int: + return self.added_vocab_end_index - self.added_vocab_start_index + + @property + def num_org_elements_padded(self) -> int: + return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index + + @property + def num_added_elements_padded(self) -> int: + return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index + + @property + def num_org_vocab_padding(self) -> int: + return self.num_org_elements_padded - self.num_org_elements + + @property + def num_added_vocab_padding(self) -> int: + return self.num_added_elements_padded - self.num_added_elements + + @property + def num_elements_padded(self) -> int: + return self.num_org_elements_padded + self.num_added_elements_padded + + def __post_init__(self): + # sanity checks + assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index + assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index + + assert self.org_vocab_start_index <= self.org_vocab_end_index + assert self.added_vocab_start_index <= self.added_vocab_end_index + + assert self.org_vocab_start_index <= self.padded_org_vocab_start_index + assert self.added_vocab_start_index <= self.padded_added_vocab_start_index + assert self.org_vocab_end_index <= self.padded_org_vocab_end_index + assert self.added_vocab_end_index <= self.padded_added_vocab_end_index + + assert self.num_org_elements <= self.num_org_elements_padded + assert self.num_added_elements <= self.num_added_elements_padded class VocabParallelEmbedding(nn.Layer): @@ -39,6 +104,7 @@ class VocabParallelEmbedding(nn.Layer): embedding_dim: int = 768, params_dtype: str = "bfloat16", prefix="", + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, ) -> None: """ Initialize the VocabParallelEmbedding layer for the model. 
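Aside (not part of the patch): the padding and shard-range arithmetic introduced above can be sanity-checked in isolation. The sketch below re-implements `pad_vocab_size` and `vocab_range_from_global_vocab_size` as they are added to `fastdeploy/model_executor/layers/utils.py` later in this diff, and clamps the padded ranges back to the real vocab the way `_get_indices` does; the vocab size and tensor-parallel degree are made-up example values.

```python
# Standalone sketch of the vocab padding / sharding math used by
# VocabParallelEmbedding; mirrors the helpers added in this diff.
DEFAULT_VOCAB_PADDING_SIZE = 64


def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    # Round up to the next multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int, offset: int = 0):
    assert global_vocab_size % world_size == 0  # `divide` in utils.py enforces this
    per_partition = global_vocab_size // world_size
    start = rank * per_partition + offset
    return start, start + per_partition


vocab_size = 151_851      # example vocab that is not a multiple of 64
world_size = 2            # example tensor-parallel degree
padded = pad_vocab_size(vocab_size)  # -> 151_872, evenly divisible across ranks
assert padded % world_size == 0

for rank in range(world_size):
    padded_start, padded_end = vocab_range_from_global_vocab_size(padded, rank, world_size)
    # Clamp to the real vocab, as _get_indices() does: trailing rows on the
    # last rank are pure padding and hold no real tokens.
    start, end = min(padded_start, vocab_size), min(padded_end, vocab_size)
    print(f"rank {rank}: padded [{padded_start}, {padded_end}), "
          f"real [{start}, {end}), padding rows = {(padded_end - padded_start) - (end - start)}")
```

Running it prints 0 padding rows for rank 0 and 21 for rank 1, which is what `num_org_vocab_padding` in `VocabParallelEmbeddingShardIndices` would report for these example sizes.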
@@ -65,10 +131,32 @@ class VocabParallelEmbedding(nn.Layer): self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings self.params_dtype: str = params_dtype + self.padding_size = padding_size + + self.org_vocab_size = num_embeddings + self.num_embeddings = num_embeddings + num_added_embeddings = num_embeddings - self.org_vocab_size + + self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size) + self.num_embeddings_padded = pad_vocab_size( + self.org_vocab_size_padded + num_added_embeddings, self.padding_size + ) + assert self.org_vocab_size_padded <= self.num_embeddings_padded + self.shard_indices = self._get_indices( + self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, + self.tensor_parallel_rank, + self.world_size, + ) + + if num_embeddings % self.world_size != 0: + self.num_embeddings_padded = pad_vocab_size(num_embeddings, self.padding_size) if not self.column_cut: self.embeddings = fleet.meta_parallel.VocabParallelEmbedding( - num_embeddings, + self.num_embeddings_padded, embedding_dim, mp_group=self.tp_group, weight_attr=paddle.ParamAttr( @@ -76,7 +164,7 @@ ), ) if self.world_size > 1: - set_weight_attrs(self.embeddings.weight, {"output_dim": False}) + set_weight_attrs(self.embeddings.weight, {"output_dim": False, "weight_loader": self.weight_loader}) else: # column cut embedding self.embeddings = nn.Embedding( @@ -106,6 +194,88 @@ class VocabParallelEmbedding(nn.Layer): self.embeddings.weight.set_value(weight_tensor) + @classmethod + def _get_indices( + cls, + vocab_size_padded: int, + org_vocab_size_padded: int, + vocab_size: int, + org_vocab_size: int, + tp_rank: int, + tp_size: int, + ) -> VocabParallelEmbeddingShardIndices: + """Get start and end indices for vocab parallel embedding, following the + layout outlined in the class docstring, based on the given tp_rank and + tp_size.""" + + num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded + padded_org_vocab_start_index, padded_org_vocab_end_index = vocab_range_from_global_vocab_size( + org_vocab_size_padded, tp_rank, tp_size + ) + + padded_added_vocab_start_index, padded_added_vocab_end_index = vocab_range_from_global_vocab_size( + num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size + ) + # remove padding + org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size) + org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) + added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size) + added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size) + return VocabParallelEmbeddingShardIndices( + padded_org_vocab_start_index, + padded_org_vocab_end_index, + padded_added_vocab_start_index, + padded_added_vocab_end_index, + org_vocab_start_index, + org_vocab_end_index, + added_vocab_start_index, + added_vocab_end_index, + ) + + def weight_loader(self, param, loaded_weight, shard_id=None): + output_dim = getattr(param, "output_dim", None) + packed_dim = getattr(param, "packed_dim", None) + + loaded_weight = get_tensor(loaded_weight) + if param.dtype != loaded_weight.dtype: + if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn: + loaded_weight = loaded_weight.cast(param.dtype) + else: + loaded_weight = loaded_weight.cast(param.dtype) + + if output_dim is None: + assert ( + param.shape == loaded_weight.shape + ), 
f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}" + param.set_value(loaded_weight) + return + + start_idx = self.shard_indices.org_vocab_start_index + end_idx = self.shard_indices.org_vocab_end_index + shard_size = self.shard_indices.org_vocab_end_index - start_idx + + # If param packed on the same dim we are sharding on, then + # need to adjust offsets of loaded weight by pack_factor. + if packed_dim is not None and packed_dim == output_dim: + packed_factor = getattr(param, "packed_factor", getattr(param, "pack_factor", 1)) + assert loaded_weight.shape[output_dim] == (self.org_vocab_size // packed_factor) + start_idx = start_idx // packed_factor + shard_size = shard_size // packed_factor + else: + assert loaded_weight.shape[output_dim] == self.org_vocab_size, ( + f"Loaded weight dim {output_dim} size {loaded_weight.shape[output_dim]} " + f"!= org_vocab_size {self.org_vocab_size}" + ) + + shard_weight = slice_fn(loaded_weight, output_dim, start_idx, end_idx) + + if output_dim == 0: + param[: shard_weight.shape[0]].copy_(shard_weight, False) + param[shard_weight.shape[0] :].fill_(0) + else: + param[:, : shard_weight.shape[1]].copy_(shard_weight, False) + param[:, shard_weight.shape[1] :].fill_(0) + def forward(self, ids_remove_padding=None) -> paddle.Tensor: """ Defines the forward computation of the layer. diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py index 57131b00a..ff1bdaa92 100644 --- a/fastdeploy/model_executor/layers/lm_head.py +++ b/fastdeploy/model_executor/layers/lm_head.py @@ -22,6 +22,10 @@ from paddle import nn from paddle.distributed import fleet from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.utils import ( + DEFAULT_VOCAB_PADDING_SIZE, + pad_vocab_size, +) from fastdeploy.model_executor.utils import ( default_weight_loader, set_weight_attrs, @@ -44,6 +48,7 @@ class ParallelLMHead(nn.Layer): prefix: str = "", with_bias: bool = False, dtype: str = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, ) -> None: """ Parallelized LMhead. @@ -68,6 +73,10 @@ class ParallelLMHead(nn.Layer): self.column_cut = True self.nranks = fd_config.parallel_config.tensor_parallel_size self.fd_config = fd_config + self.padding_size = padding_size + + if num_embeddings % self.nranks != 0: + num_embeddings = pad_vocab_size(num_embeddings, self.padding_size) ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear RowParallelLinear = fleet.meta_parallel.RowParallelLinear diff --git a/fastdeploy/model_executor/layers/pool/__init__.py b/fastdeploy/model_executor/layers/pool/__init__.py new file mode 100644 index 000000000..f4ede9062 --- /dev/null +++ b/fastdeploy/model_executor/layers/pool/__init__.py @@ -0,0 +1,15 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py index 85de8ec4c..27bc770e8 100644 --- a/fastdeploy/model_executor/layers/utils.py +++ b/fastdeploy/model_executor/layers/utils.py @@ -45,6 +45,14 @@ if cache_params != "none": c8_state_dict = paddle.load(cache_params, return_numpy=True) +DEFAULT_VOCAB_PADDING_SIZE = 64 + + +def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: + """Pad the vocab size to the given value.""" + return ((vocab_size + pad_to - 1) // pad_to) * pad_to + + def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]: """ Only used in deep_gemm block wise quant weight. @@ -372,3 +380,14 @@ def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str]) paddle.Tensor: An empty tensor with the specified shape and data type. """ return paddle.empty(list(shape), dtype=dtype) + + +def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, rank: int, offset: int = 0): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f + offset, index_l + offset + + +def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int, offset: int = 0): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, offset=offset) diff --git a/tests/pooling/test_embedding.py b/tests/pooling/test_embedding.py index d609726e2..a548494dc 100644 --- a/tests/pooling/test_embedding.py +++ b/tests/pooling/test_embedding.py @@ -27,7 +27,9 @@ from fastdeploy.config import ( ModelConfig, ParallelConfig, ) +from fastdeploy.model_executor.models.adapters import as_embedding_model from fastdeploy.model_executor.models.model_base import ModelRegistry +from fastdeploy.scheduler import SchedulerConfig current_dir = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.abspath(os.path.join(current_dir, "..")) @@ -36,58 +38,103 @@ if project_root not in sys.path: from tests.model_loader.utils import get_torch_model_path +test_model_configs = { + "Qwen3-0.6B": { + "tensor_parallel_size": 2, + "max_model_len": 8192, + "baseline_suffix": "standard", + }, + "Qwen3-Embedding-0.6B": { + "tensor_parallel_size": 2, + "max_model_len": 8192, + "baseline_suffix": "embedding", + }, +} + class TestModelLoader: @pytest.fixture(scope="session", autouse=True) def setup_paddle(self): if not paddle.is_compiled_with_cuda(): - print("CUDA not available, using CPU") - paddle.set_device("cpu") - else: - print("Using CUDA device") - paddle.set_device("gpu") + raise AssertionError("CUDA not available") + paddle.set_device("gpu") yield - @pytest.fixture(scope="session") - def model_path(self): + @pytest.fixture(scope="session", params=list(test_model_configs.keys())) + def model_info(self, request): + model_name = request.param try: - torch_model_path = get_torch_model_path("Qwen3-0.6B") - if os.path.exists(torch_model_path): - return torch_model_path + torch_model_path = get_torch_model_path(model_name) + if not os.path.exists(torch_model_path): + raise AssertionError(f"Model path does not exist: {torch_model_path}") + return {"name": model_name, "path": torch_model_path, "config": test_model_configs[model_name]} except Exception as e: - print(f"Could not get torch model path: {e}") + raise AssertionError(f"Could not get torch model path for {model_name}: {e}") @pytest.fixture - def model_config(self, 
model_path): + def model_config(self, model_info): + if model_info is None: + raise AssertionError("model_info is None") + model_args = { - "model": model_path, + "model": model_info["path"], "dtype": "bfloat16", - "max_model_len": 8192, - "tensor_parallel_size": 1, + "max_model_len": model_info["config"]["max_model_len"], + "tensor_parallel_size": model_info["config"]["tensor_parallel_size"], "runner": "auto", "convert": "auto", } try: - return ModelConfig(model_args) + config = ModelConfig(model_args) + return config except Exception as e: - print(f"Could not create ModelConfig: {e}") + raise AssertionError(f"Could not create ModelConfig: {e}") @pytest.fixture - def fd_config(self, model_config): + def scheduler_config(self): + scheduler_args = { + "name": "local", + "max_num_seqs": 256, + "max_num_batched_tokens": 8192, + "splitwise_role": "mixed", + "max_size": -1, + "ttl": 900, + "max_model_len": 8192, + "enable_chunked_prefill": False, + "max_num_partial_prefills": 1, + "max_long_partial_prefills": 1, + "long_prefill_token_threshold": 0, + } + try: + config = SchedulerConfig(scheduler_args) + return config + except Exception as e: + raise AssertionError(f"Could not create SchedulerConfig: {e}") + + @pytest.fixture + def fd_config(self, model_info, model_config, scheduler_config): + if model_config is None: + raise AssertionError("ModelConfig is None") + if scheduler_config is None: + raise AssertionError("SchedulerConfig is None") + + try: + tensor_parallel_size = model_info["config"]["tensor_parallel_size"] + cache_args = { "block_size": 64, "gpu_memory_utilization": 0.9, "cache_dtype": "bfloat16", "model_cfg": model_config, - "tensor_parallel_size": 1, + "tensor_parallel_size": tensor_parallel_size, } cache_config = CacheConfig(cache_args) parallel_args = { - "tensor_parallel_size": 1, + "tensor_parallel_size": tensor_parallel_size, "data_parallel_size": 1, } parallel_config = ParallelConfig(parallel_args) @@ -95,88 +142,80 @@ class TestModelLoader: load_args = {} load_config = LoadConfig(load_args) - graph_opt_args = { - "enable_cudagraph": False, - "cudagraph_capture_sizes": None, - } + graph_opt_args = {} graph_opt_config = GraphOptimizationConfig(graph_opt_args) - return FDConfig( + fd_config = FDConfig( model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, + scheduler_config=scheduler_config, load_config=load_config, graph_opt_config=graph_opt_config, test_mode=True, ) + return fd_config + except Exception as e: - print(f"Could not create FDConfig: {e}") + raise AssertionError(f"Could not create FDConfig: {e}") @pytest.fixture - def model_json_config(self, model_path): - config_path = os.path.join(model_path, "config.json") - if os.path.exists(config_path): - with open(config_path, "r", encoding="utf-8") as f: - return json.load(f) - return None + def model_json_config(self, model_info): + if model_info is None: + raise AssertionError("model_info is None") - def test_embedding_with_none_convert_type(self, fd_config, model_json_config): - if model_json_config is None: - pytest.skip("Model config not available") + config_path = os.path.join(model_info["path"], "config.json") + if not os.path.exists(config_path): + raise AssertionError(f"Config file does not exist: {config_path}") - if fd_config is None: - pytest.skip("FDConfig not available") + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) - print("=" * 60) - print("Testing initialize_model with convert_type='none'") - print("=" * 60) + def 
test_embedding_with_none_convert_type(self, model_info, fd_config, model_json_config): + if any(x is None for x in [model_info, fd_config, model_json_config]): + raise AssertionError("Required configs not available") architectures = model_json_config.get("architectures", []) if not architectures: - pytest.skip("No architectures found in model config") + raise AssertionError("No architectures found in model config") fd_config.model_config.convert_type = "none" try: - model_cls = ModelRegistry.get_class(architectures) + model_cls = ModelRegistry.get_class(architectures[0]) if hasattr(model_cls, "__name__"): assert ( "ForEmbedding" not in model_cls.__name__ ), f"Standard model should not have 'ForEmbedding' in name, but got: {model_cls.__name__}" - print(f"Confirmed standard model type (no ForEmbedding): {model_cls.__name__}") standard_methods = set(dir(model_cls)) assert "_init_pooler" not in standard_methods, "Standard model should not have _init_pooler method" except Exception as e: - print(f"Error in none: {e}") + raise AssertionError(f"Error in none convert type test: {e}") - def test_embedding_with_embed_convert_type(self, fd_config, model_json_config): - if model_json_config is None: - pytest.skip("Model config not available") - - if fd_config is None: - pytest.skip("FDConfig not available") - - print("=" * 60) - print("Testing embedding with convert_type='embed'") - print("=" * 60) + def test_embedding_with_embed_convert_type(self, model_info, fd_config, model_json_config): + if any(x is None for x in [model_info, fd_config, model_json_config]): + raise AssertionError("Required configs not available") architectures = model_json_config.get("architectures", []) if not architectures: - pytest.skip("No architectures found in model config") + raise AssertionError("No architectures found in model config") fd_config.model_config.convert_type = "embed" try: - model_cls = ModelRegistry.get_class(architectures) + model_cls = ModelRegistry.get_class(architectures[0]) + model_cls = as_embedding_model(model_cls) + if hasattr(model_cls, "__name__"): - assert "ForEmbedding" in model_cls.__name__, "Embedding model should have 'ForEmbedding' in name" - print(f"Confirmed embedding model type: {model_cls.__name__}") + assert ( + "ForEmbedding" in model_cls.__name__ + ), f"Embedding model should have 'ForEmbedding' in name, but got: {model_cls.__name__}" embedding_methods = set(dir(model_cls)) assert "_init_pooler" in embedding_methods, "Embedding model should have _init_pooler method" except Exception as e: - print(f"Error in convert embed: {e}") + raise AssertionError(f"Error in embed convert type test: {e}")
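To make the behavior of the new `VocabParallelEmbedding.weight_loader` concrete, here is a small self-contained numpy sketch of its `output_dim == 0` path: each rank copies its slice of the original (unpadded) vocab rows into its local parameter and zero-fills the rows that only exist because of padding. The sizes are toy values and numpy stands in for paddle tensors; the real loader additionally handles `packed_dim` adjustment and dtype casting, which are omitted here.

```python
import numpy as np

# Toy sizes: a real vocab of 10 rows, padded to a multiple of 4, sharded over 2 ranks.
org_vocab_size, padding, world_size, hidden = 10, 4, 2, 3
padded = ((org_vocab_size + padding - 1) // padding) * padding  # 12 rows after padding
per_rank = padded // world_size                                 # 6 rows per rank

full_weight = np.arange(org_vocab_size * hidden, dtype=np.float32).reshape(org_vocab_size, hidden)

for rank in range(world_size):
    padded_start, padded_end = rank * per_rank, (rank + 1) * per_rank
    start, end = min(padded_start, org_vocab_size), min(padded_end, org_vocab_size)

    # What weight_loader does for output_dim == 0: copy the rank's real rows,
    # then zero-fill the tail rows that exist only because of vocab padding.
    param = np.empty((per_rank, hidden), dtype=np.float32)
    shard = full_weight[start:end]
    param[: shard.shape[0]] = shard
    param[shard.shape[0] :] = 0.0
    print(f"rank {rank}: {shard.shape[0]} real rows, {per_rank - shard.shape[0]} zero rows")
```

With these toy sizes, rank 0 ends up with 6 real rows, while rank 1 gets 4 real rows plus 2 zeroed padding rows, mirroring how the padded checkpoint slice is laid out on each tensor-parallel rank.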