[Feature] support pool (#3827)

* support pool

* update pooling

* add pooler_config and check

* update

* support AutoWeightsLoader load weight

* fix

* update

* delete print

* update pre-commit

* fix

* fix xpu

* fix ModelRegistry->model_registry

* fix Copilot review

* fix pooler.py

* delete StepPooler

* fix abstract

* fix default_loader_v1

* fix Pre Commit

* support torch qwen3 dense

* add test and fix torch-qwen

* fix

* fix

* adapt CI

* fix review

* fix pooling_params.py

* fix

* fix tasks.py 2025

* fix print and logger

* Modify ModelRegistry and delete AutoWeightsLoader

* fix logger

* fix test_embedding

* fix ci bug

* ernie4_5 model_registry

* fix test

* support Qwen3-Embedding-0.6B tp=1 load

* fix extra code

* fix

* delete fix vocab_size

* delete prepare_params_dict

* fix:
This commit is contained in:
lizexu123
2025-09-22 14:09:09 +08:00
committed by GitHub
parent da74a5f0b3
commit c86945ef49
36 changed files with 2371 additions and 51 deletions

View File

@@ -18,7 +18,7 @@ Assuming you have a custom model class `MyModelForCasualLM` and a pretrained cla
```python
# File: fd_add_dummy_model/__init__.py or fd_add_dummy_model/register.py
from fastdeploy.model_registry import ModelRegistry
from fastdeploy.model_executor.models.model_base import ModelRegistry
from my_custom_model import MyModelForCasualLM, MyPretrainedModel
from fastdeploy.config import ErnieArchitectures
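def register():
    # Hypothetical completion of the entry point this doc describes; the exact
    # registration calls are assumptions based on how ModelRegistry is used
    # elsewhere in this PR, not a documented API.
    ModelRegistry.register_model_class(MyModelForCasualLM)
    ModelRegistry.register_pretrained_model(MyPretrainedModel)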

View File

@@ -18,7 +18,7 @@ FastDeploy uses Python's `entry_points` mechanism to discover and load plugins.
```python
# File: fd_add_dummy_model/__init__.py
from fastdeploy.model_registry import ModelRegistry
from fastdeploy.model_executor.models.model_base import ModelRegistry
from my_custom_model import MyModelForCasualLM, MyPretrainedModel
def register():

View File

@@ -18,12 +18,14 @@ from __future__ import annotations
import json
import os
from dataclasses import field
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
import paddle
import paddle.distributed as dist
from paddleformers.transformers.configuration_utils import PretrainedConfig
from typing_extensions import assert_never
import fastdeploy
from fastdeploy import envs
@@ -31,11 +33,68 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfig
from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.transformer_utils.config import get_pooling_config
from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger
logger = get_logger("config", "config.log")
TaskOption = Literal["generate"]
TaskOption = Literal["auto", "generate", "embedding", "embed"]
RunnerType = Literal["generate", "pooling"]
RunnerOption = Literal["auto", "generate", "pooling"]
ConvertOption = Literal["auto", "none", "embed"]
ConvertType = Literal["none", "embed"]
_ResolvedTask = Literal["generate", "encode", "embed"]
_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
"generate": [],
"pooling": ["embed"],
}
# Some model suffixes are based on auto classes from Transformers:
# https://huggingface.co/docs/transformers/en/model_doc/auto
# NOTE: Items higher on this list take priority over lower ones
_SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
("ForCausalLM", ("generate", "none")),
("ForConditionalGeneration", ("generate", "none")),
("ChatModel", ("generate", "none")),
("LMHeadModel", ("generate", "none")),
("ForTextEncoding", ("pooling", "embed")),
("EmbeddingModel", ("pooling", "embed")),
("ForSequenceClassification", ("pooling", "classify")),
("ForAudioClassification", ("pooling", "classify")),
("ForImageClassification", ("pooling", "classify")),
("ForVideoClassification", ("pooling", "classify")),
("ClassificationModel", ("pooling", "classify")),
("ForRewardModeling", ("pooling", "reward")),
("RewardModel", ("pooling", "reward")),
# Let other `*Model`s take priority
("Model", ("pooling", "embed")),
]
def iter_architecture_defaults():
yield from _SUFFIX_TO_DEFAULTS
def try_match_architecture_defaults(
architecture: str,
*,
runner_type: Optional[RunnerType] = None,
convert_type: Optional[ConvertType] = None,
):
for suffix, (default_runner_type, default_convert_type) in iter_architecture_defaults():
if (
(runner_type is None or runner_type == default_runner_type)
and (convert_type is None or convert_type == default_convert_type)
and architecture.endswith(suffix)
):
return suffix, (default_runner_type, default_convert_type)
return None
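As a quick illustration, the suffix table above resolves architecture names as follows (the model names here are only examples):

```python
# Illustrative only: suffix matching against _SUFFIX_TO_DEFAULTS.
match = try_match_architecture_defaults("Qwen3ForCausalLM")
# -> ("ForCausalLM", ("generate", "none"))

match = try_match_architecture_defaults("Qwen3Model", runner_type="pooling")
# "Model" is the lowest-priority suffix -> ("Model", ("pooling", "embed"))
```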
class MoEPhase:
@@ -133,6 +192,12 @@ class ModelConfig:
self.eos_tokens_lens: int = 2
self.lm_head_fp32: bool = False
self.model_format = "auto"
self.runner = "auto"
self.convert = "auto"
self.pooler_config: Optional["PoolerConfig"] = field(init=False)
self.override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
self.revision = None
self.partial_rotary_factor: float = 1.0
self.num_nextn_predict_layers = 0
for key, value in args.items():
@@ -161,6 +226,7 @@ class ModelConfig:
self.ori_vocab_size = args.get("ori_vocab_size", self.vocab_size)
architectures = self.architectures[0]
if MultimodalRegistry.contains_model(architectures):
self.enable_mm = True
else:
@@ -171,6 +237,43 @@ class ModelConfig:
self.override_name_from_config()
self.read_from_env()
self.read_model_config()
self.runner_type = self._get_runner_type(self.architectures, self.runner)
self.convert_type = self._get_convert_type(self.architectures, self.runner_type, self.convert)
registry = self.registry
is_generative_model = registry.is_text_generation_model(self.architectures, self)
is_pooling_model = registry.is_pooling_model(self.architectures, self)
is_multimodal_model = registry.is_multimodal_model(self.architectures, self)
if self.runner_type == "generate" and not is_generative_model:
if is_multimodal_model:
pass
else:
generate_converts = _RUNNER_CONVERTS["generate"]
if self.convert_type not in generate_converts:
raise ValueError("This model does not support '--runner generate.")
if self.runner_type == "pooling" and not is_pooling_model:
pooling_converts = _RUNNER_CONVERTS["pooling"]
if self.convert_type not in pooling_converts:
convert_option = "<" + "|".join(pooling_converts) + ">"
raise ValueError(
"This model does not support `--runner pooling`. "
f"You can pass `--convert {convert_option} to adapt "
"it into a pooling model."
)
self.supported_tasks = self._get_supported_tasks(self.architectures, self.runner_type, self.convert_type)
model_info, arch = registry.inspect_model_cls(self.architectures, self)
self._model_info = model_info
self._architecture = arch
self.pooler_config = self._init_pooler_config()
@property
def registry(self):
from fastdeploy.model_executor.models.model_base import ModelRegistry
return ModelRegistry()
def override_name_from_config(self):
"""
@@ -194,7 +297,6 @@ class ModelConfig:
def read_from_env(self):
"""
Read configuration information from environment variables and update the object's attributes.
If an attribute is not present or is an empty string in the environment variables, use the default value.
"""
self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
@@ -235,6 +337,165 @@ class ModelConfig:
f"Config file path: {config_path}"
)
def _get_default_runner_type(
self,
architectures: list[str],
) -> RunnerType:
registry = self.registry
if get_pooling_config(self.model, self.revision):
return "pooling"
for arch in architectures:
if arch in registry.get_supported_archs():
if registry.is_pooling_model(architectures, self):
return "pooling"
if registry.is_text_generation_model(architectures, self):
return "generate"
match = try_match_architecture_defaults(arch)
if match:
_, (runner_type, _) = match
return runner_type
return "generate"
def _get_default_convert_type(
self,
architectures: list[str],
runner_type: RunnerType,
) -> ConvertType:
registry = self.registry
for arch in architectures:
if arch in registry.get_supported_archs():
if runner_type == "generate" and registry.is_text_generation_model(architectures, self):
return "none"
if runner_type == "pooling" and registry.is_pooling_model(architectures, self):
return "none"
match = try_match_architecture_defaults(arch, runner_type=runner_type)
if match:
_, (_, convert_type) = match
return convert_type
# This is to handle Sentence Transformers models that use *ForCausalLM
# and also multi-modal pooling models which are not defined as
# Sentence Transformers models
if runner_type == "pooling":
return "embed"
return "none"
def _get_runner_type(
self,
architectures: list[str],
runner: RunnerOption,
) -> RunnerType:
if runner != "auto":
return runner
runner_type = self._get_default_runner_type(architectures)
if runner_type != "generate":
logger.info(
"Resolved `--runner auto` to `--runner %s`. " "Pass the value explicitly to silence this message.",
runner_type,
)
return runner_type
def _get_convert_type(
self,
architectures: list[str],
runner_type: RunnerType,
convert: ConvertOption,
) -> ConvertType:
if convert != "auto":
return convert
convert_type = self._get_default_convert_type(architectures, runner_type)
if convert_type != "none":
logger.info(
"Resolved `--convert auto` to `--convert %s`. " "Pass the value explicitly to silence this message.",
convert_type,
)
return convert_type
def _get_supported_generation_tasks(
self,
architectures: list[str],
convert_type: ConvertType,
) -> list[_ResolvedTask]:
registry = self.registry
supported_tasks = list[_ResolvedTask]()
if registry.is_text_generation_model(architectures, self) or convert_type in _RUNNER_CONVERTS["generate"]:
supported_tasks.append("generate")
# TODO: Temporarily does not support transcription.
return supported_tasks
def _get_default_pooling_task(
self,
architectures: list[str],
) -> Literal["embed"]:
# Temporarily does not support classification and reward.
for arch in architectures:
match = try_match_architecture_defaults(arch, runner_type="pooling")
if match:
_, (_, convert_type) = match
assert convert_type != "none"
return convert_type
return "embed"
def _get_supported_pooling_tasks(
self,
architectures: list[str],
convert_type: ConvertType,
) -> list[_ResolvedTask]:
registry = self.registry
supported_tasks = list[_ResolvedTask]()
if registry.is_pooling_model(architectures, self) or convert_type in _RUNNER_CONVERTS["pooling"]:
supported_tasks.append("encode")
extra_task = self._get_default_pooling_task(architectures) if convert_type == "none" else convert_type
supported_tasks.append(extra_task)
return supported_tasks
def _get_supported_tasks(
self,
architectures: list[str],
runner_type: RunnerType,
convert_type: ConvertType,
) -> list[_ResolvedTask]:
if runner_type == "generate":
return self._get_supported_generation_tasks(architectures, convert_type)
if runner_type == "pooling":
return self._get_supported_pooling_tasks(architectures, convert_type)
assert_never(runner_type)
def _init_pooler_config(self) -> Optional["PoolerConfig"]:
if self.runner_type == "pooling":
if isinstance(self.override_pooler_config, dict):
self.override_pooler_config = PoolerConfig(**self.override_pooler_config)
pooler_config = self.override_pooler_config or PoolerConfig()
base_config = get_pooling_config(self.model, self.revision)
if base_config is not None:
for k, v in base_config.items():
if getattr(pooler_config, k) is None:
setattr(pooler_config, k, v)
default_pooling_type = self._model_info.default_pooling_type
if pooler_config.pooling_type is None:
pooler_config.pooling_type = default_pooling_type
return pooler_config
return None
def _get_download_model(self, model_name, model_type="default"):
# TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
pass
@@ -846,6 +1107,41 @@ class LoadConfig:
setattr(self, key, value)
class PoolerConfig:
"""Controls the behavior of output pooling in pooling models."""
pooling_type: Optional[str] = None
"""
The pooling method of the pooling model.
"""
# for embeddings models
normalize: Optional[bool] = None
"""
Whether to normalize the embedding outputs. Defaults to True.
"""
dimensions: Optional[int] = None
"""
Reduce the dimensions of embeddings if the model
supports matryoshka representation. Defaults to None.
"""
enable_chunked_processing: Optional[bool] = None
"""
Whether to enable chunked processing for long inputs that exceed the model's
maximum position embeddings. When enabled, long inputs will be split into
chunks, processed separately, and then aggregated using weighted averaging.
This allows embedding models to handle arbitrarily long text without CUDA
errors. Defaults to False.
"""
max_embed_len: Optional[int] = None
"""
Maximum input length allowed for embedding generation. When set, allows
inputs longer than max_embed_len to be accepted for embedding models.
When an input exceeds max_embed_len, it will be handled according to
the original max_model_len validation logic.
Defaults to None (i.e. set to max_model_len).
"""
class LoRAConfig:
"""LoRA Config"""

View File

@@ -18,13 +18,14 @@ import argparse
import json
from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
import paddle
from fastdeploy import envs
from fastdeploy.config import (
CacheConfig,
ConvertOption,
EarlyStopConfig,
FDConfig,
GraphOptimizationConfig,
@@ -32,6 +33,8 @@ from fastdeploy.config import (
MobaAttentionConfig,
ModelConfig,
ParallelConfig,
PoolerConfig,
RunnerOption,
SpeculativeConfig,
TaskOption,
)
@@ -95,6 +98,20 @@ class EngineArgs:
"""
The task to be executed by the model.
"""
runner: RunnerOption = "auto"
"""
The type of model runner to use. Each FD instance only supports one model runner,
even if the same model can be used for multiple types.
"""
convert: ConvertOption = "auto"
"""
Convert the model using adapters. The most common use case is to
adapt a text generation model to be used for pooling tasks.
"""
override_pooler_config: Optional[Union[dict, PoolerConfig]] = None
"""
Override configuration for the pooler.
"""
max_num_seqs: int = 8
"""
Maximum number of sequences per iteration.
@@ -473,6 +490,21 @@ class EngineArgs:
default=EngineArgs.task,
help="Task to be executed by the model.",
)
model_group.add_argument(
"--runner",
type=str,
default=EngineArgs.runner,
help="The type of model runner to use",
)
model_group.add_argument(
"--convert", type=str, default=EngineArgs.convert, help="Convert the model using adapters"
)
model_group.add_argument(
"--override-pooler-config",
type=json.loads,
default=EngineArgs.override_pooler_config,
help="Override the pooler configuration with a JSON string.",
)
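A hedged usage sketch for the three new options (the model name and values are illustrative; the equivalent CLI flags are shown in the comment):

```python
# Equivalent CLI flags:
#   --runner pooling --convert embed \
#   --override-pooler-config '{"pooling_type": "MEAN", "normalize": true}'
args = EngineArgs(
    model="Qwen/Qwen3-Embedding-0.6B",   # example embedding model from this PR's tests
    runner="pooling",                     # force the pooling runner
    convert="embed",                      # adapt a generative architecture for embeddings
    override_pooler_config={"pooling_type": "MEAN", "normalize": True},
)
```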
model_group.add_argument(
"--use-warmup",
type=int,

View File

@@ -498,6 +498,9 @@ class LLMEngine:
f" --load_choices {self.cfg.load_config.load_choices}"
f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
f" --ips {ips}"
f" --runner {self.cfg.model_config.runner}"
f" --convert {self.cfg.model_config.convert}"
f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
)
worker_append_flag = {

View File

@@ -0,0 +1,170 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from copy import deepcopy
from typing import TYPE_CHECKING, Annotated, Any, Optional
import msgspec
from fastdeploy.engine.sampling_params import RequestOutputKind
from fastdeploy.engine.tasks import PoolingTask
if TYPE_CHECKING:
from fastdeploy.config import ModelConfig
class PoolingParams:
"""API parameters for pooling models.
Attributes:
normalize: Whether to normalize the embedding outputs.
dimensions: Reduce the dimensions of embeddings
if the model supports matryoshka representation.
activation: Whether to apply activation function to
the classification outputs.
softmax: Whether to apply softmax to the reward outputs.
step_tag_id: Step tag ID for process reward models to identify
specific steps in multi-step reasoning tasks.
returned_token_ids: List of token IDs to return rewards for,
used for fine-grained reward calculation.
task: Internal use only. Specifies the pooling task type
("embed" for embeddings, "encode" for reward models).
requires_token_ids: Internal use only. Whether token ID information
is required for processing.
extra_kwargs: Internal use only. Dictionary for storing additional
custom parameters for extended functionality.
output_kind: Output type specification, fixed to FINAL_ONLY
(only final outputs are returned).
"""
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None
"""If set to -1, will use the truncation size supported by the model. If
set to an integer k, will use only the last k tokens from the prompt
(i.e., left truncation). If set to `None`, truncation is disabled."""
# for embeddings models
dimensions: Optional[int] = None
normalize: Optional[bool] = None
# for reward models
softmax: Optional[bool] = None
step_tag_id: Optional[int] = None
returned_token_ids: Optional[list[int]] = None
task: Optional[PoolingTask] = None
"""Internal use only."""
requires_token_ids: bool = False
"""Internal use only."""
extra_kwargs: Optional[dict[str, Any]] = None
"""Internal use only."""
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
@property
def _all_parameters(self) -> list[str]:
return ["dimensions", "normalize", "softmax", "step_tag_id", "returned_token_ids"]
@property
def valid_parameters(self):
return {
"embed": ["dimensions", "normalize"],
"encode": ["softmax", "step_tag_id", "returned_token_ids"],
}
def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance."""
return deepcopy(self)
def verify(self, task: PoolingTask, model_config: Optional["ModelConfig"] = None) -> None:
if self.task is None:
self.task = task
elif self.task != task:
msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
raise ValueError(msg)
# NOTE: Task validation needs to be done against the model instance,
# which is not available in model config. So, it's not included
# in this method
self._merge_default_parameters(model_config)
self._set_default_parameters(model_config)
self._verify_valid_parameters()
def _merge_default_parameters(self, model_config: Optional["ModelConfig"] = None) -> None:
if model_config is None:
return
pooler_config = model_config.pooler_config
if pooler_config is None:
return
assert self.task is not None, "task must be set"
valid_parameters = self.valid_parameters[self.task]
for k in valid_parameters:
if getattr(pooler_config, k, None) is None:
continue
if getattr(self, k, None) is None:
setattr(self, k, getattr(pooler_config, k))
def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
if self.task == "embed":
if self.normalize is None:
self.normalize = True
elif self.task == "encode":
if self.softmax is None:
self.softmax = True
else:
raise ValueError(f"Unknown pooling task: {self.task}")
def _verify_valid_parameters(self):
assert self.task is not None, "task must be set"
valid_parameters = self.valid_parameters[self.task]
invalid_parameters = []
for k in self._all_parameters:
if k in valid_parameters:
continue
if getattr(self, k, None) is not None:
invalid_parameters.append(k)
if invalid_parameters:
raise ValueError(
f"Task {self.task} only supports {valid_parameters} "
f"parameters, does not support "
f"{invalid_parameters} parameters"
)
def __repr__(self) -> str:
return (
f"PoolingParams("
f"task={self.task}, "
f"normalize={self.normalize}, "
f"dimensions={self.dimensions}, "
f"softmax={self.softmax}, "
f"step_tag_id={self.step_tag_id}, "
f"returned_token_ids={self.returned_token_ids}, "
f"requires_token_ids={self.requires_token_ids}, "
f"extra_kwargs={self.extra_kwargs})"
)
def __post_init__(self) -> None:
assert self.output_kind == RequestOutputKind.FINAL_ONLY, "For pooling output_kind has to be FINAL_ONLY"
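A minimal sketch of the verification flow above, assuming `PoolingParams` is constructed with keyword arguments and no `ModelConfig` is available (so only the per-task defaults apply):

```python
# An embedding request's pooling parameters being verified.
params = PoolingParams(dimensions=256)   # matryoshka truncation to 256 dims
params.verify(task="embed")              # fixes the task and defaults normalize to True

# Reward-model-only fields are rejected for the "embed" task, e.g.
# PoolingParams(step_tag_id=5).verify(task="embed") raises:
#   ValueError: Task embed only supports ['dimensions', 'normalize'] parameters, ...
```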

View File

@@ -18,6 +18,7 @@ from __future__ import annotations
import random
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, List, Optional, Union
@@ -268,3 +269,12 @@ class GuidedDecodingParams:
"You can only use one kind of guided decoding "
"('json', 'json_object', 'regex', 'choice', 'grammar', 'structural_tag')."
)
class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
# Return only deltas in each RequestOutput
DELTA = 1
# Do not return intermediate RequestOutput
FINAL_ONLY = 2

View File

@@ -0,0 +1,25 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Literal, get_args
GenerationTask = Literal["generate"]
GENERATION_TASKS = get_args(GenerationTask)
PoolingTask = Literal["encode", "embed"]
POOLING_TASKS = get_args(PoolingTask)
SupportedTask = Literal[GenerationTask, PoolingTask]

View File

@@ -146,3 +146,26 @@ class SiluAndMul(nn.Layer):
if self.bias is not None:
out = out + self.bias
return out
def get_act_fn(act_fn_name: str) -> nn.Layer:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name.startswith("paddle.nn.layer"):
activation_name = act_fn_name.split(".")[-1]
if activation_name == "identity":
return nn.Identity()
act_fn_name = activation_name
activation_map = {
"gelu": nn.GELU(),
"relu": nn.ReLU(),
"silu": nn.Silu(),
"tanh": nn.Tanh(),
"sigmoid": nn.Sigmoid(),
}
if act_fn_name in activation_map:
return activation_map[act_fn_name]
else:
raise ValueError(f"Activation function {act_fn_name!r} is not supported.")

View File

@@ -0,0 +1,85 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from dataclasses import dataclass
from typing import Optional
import paddle
from fastdeploy.engine.pooling_params import PoolingParams
@dataclass
class PoolingCursor:
index: list[int]
first_token_indices_gpu: paddle.Tensor
last_token_indices_gpu: paddle.Tensor
prompt_lens_cpu: paddle.Tensor
num_scheduled_tokens_cpu: paddle.Tensor
def __getitem__(self, indices: slice):
return PoolingCursor(
index=self.index[indices],
first_token_indices_gpu=self.first_token_indices_gpu[indices],
last_token_indices_gpu=self.last_token_indices_gpu[indices],
prompt_lens_cpu=self.prompt_lens_cpu[indices],
num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
)
def is_partial_prefill(self):
return not paddle.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu).item()
@dataclass
class PoolingMetadata:
"""Tensors for pooling."""
prompt_lens: paddle.Tensor # CPU Tensor
prompt_token_ids: Optional[paddle.Tensor]
pooling_params: list[PoolingParams]
pooling_cursor: Optional[PoolingCursor] = None
def __getitem__(self, indices: slice):
return PoolingMetadata(
prompt_lens=self.prompt_lens[indices],
prompt_token_ids=None if self.prompt_token_ids is None else self.prompt_token_ids[indices],
pooling_params=self.pooling_params[indices],
pooling_cursor=None if self.pooling_cursor is None else self.pooling_cursor[indices],
)
def build_pooling_cursor(self, num_scheduled_tokens: list[int], device: str):
self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens, self.prompt_lens, device)
def build_pooling_cursor(num_scheduled_tokens: list[int], prompt_lens: paddle.Tensor, device: str):
assert len(prompt_lens) == len(num_scheduled_tokens)
n_seq = len(num_scheduled_tokens)
index = list(range(n_seq))
num_scheduled_tokens = paddle.to_tensor(num_scheduled_tokens, device="cpu")
cumsum = paddle.zeros([n_seq + 1], dtype="int64", place=paddle.CPUPlace())
paddle.cumsum(num_scheduled_tokens, axis=0, out=cumsum[1:])
if device == "gpu":
cumsum_device = cumsum.cuda()
else:
cumsum_device = cumsum
return PoolingCursor(
index=index,
first_token_indices_gpu=cumsum_device[:n_seq],
last_token_indices_gpu=cumsum_device[1:] - 1,
prompt_lens_cpu=prompt_lens,
num_scheduled_tokens_cpu=num_scheduled_tokens,
)
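The index math behind `build_pooling_cursor`, shown on illustrative numbers (a minimal sketch, not the production call path):

```python
import paddle

# Three prompts of 4, 2 and 3 tokens, all fully scheduled in one step.
num_scheduled_tokens = paddle.to_tensor([4, 2, 3], dtype="int64")
cumsum = paddle.concat([paddle.zeros([1], dtype="int64"), paddle.cumsum(num_scheduled_tokens)])
first_token_indices = cumsum[:-1]      # [0, 4, 6]
last_token_indices = cumsum[1:] - 1    # [3, 5, 8]
# LastPool gathers hidden_states rows [3, 5, 8]; CLSPool gathers rows [0, 4, 6].
```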

View File

@@ -0,0 +1,550 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from abc import ABC, abstractmethod
from collections.abc import Mapping, Set
from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby
from typing import Callable, Optional, TypeVar, Union, cast
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from fastdeploy.config import FDConfig, ModelConfig, PoolerConfig
from fastdeploy.engine.tasks import PoolingTask
from fastdeploy.model_executor.layers.pool.metadata import (
PoolingCursor,
PoolingMetadata,
PoolingParams,
)
from fastdeploy.model_executor.models.adapters import _load_st_projector
from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput
from fastdeploy.utils import get_logger
logger = get_logger("pooler", "pooler.log")
PoolingFn = Callable[
[Union[paddle.Tensor, list[paddle.Tensor]], PoolingMetadata], Union[paddle.Tensor, list[paddle.Tensor]]
]
ClassifierFn = Callable[[paddle.Tensor], paddle.Tensor]
class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods."""
LAST = 0
ALL = 1
CLS = 2
STEP = 3
MEAN = 4
_T = TypeVar("_T", paddle.Tensor, list[paddle.Tensor])
@dataclass(frozen=True)
class ResolvedPoolingConfig:
pooling_type: PoolingType
task: PoolingTask
@classmethod
def from_config(
cls,
task: PoolingTask,
pooler_config: PoolerConfig,
) -> "ResolvedPoolingConfig":
assert pooler_config.pooling_type is not None
return cls(task=task, pooling_type=PoolingType[pooler_config.pooling_type])
def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
pooling_params = pooling_metadata.pooling_params
return pooling_params
def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
pooling_params = get_pooling_params(pooling_metadata)
tasks: list[PoolingTask] = [task for pooling_param in pooling_params if (task := pooling_param.task) is not None]
assert len(pooling_params) == len(tasks)
return tasks
def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[paddle.Tensor]:
assert (
pooling_metadata.prompt_token_ids is not None
), "Please set `requires_token_ids=True` in `get_pooling_updates`"
return [pooling_metadata.prompt_token_ids[i, :num] for i, num in enumerate(pooling_metadata.prompt_lens)]
@dataclass(frozen=True)
class PoolingParamsUpdate:
requires_token_ids: bool = False
"""Set this flag to enable `get_prompt_token_ids` for your pooler."""
def apply(self, params: PoolingParams) -> None:
params.requires_token_ids = self.requires_token_ids
class Pooler(nn.Layer, ABC):
"""The interface required for all poolers used in pooling models in FastDeploy."""
@staticmethod
def for_encode(pooler_config: PoolerConfig, model_config: Optional["ModelConfig"] = None):
if pooler_config.pooling_type == "STEP":
return StepPooler()
resolved_config = ResolvedPoolingConfig(task="encode", pooling_type=PoolingType.ALL)
return SimplePooler.from_config(resolved_config, model_config)
@staticmethod
def for_embed(pooler_config: PoolerConfig, model_config: Optional["ModelConfig"] = None):
resolved_config = ResolvedPoolingConfig.from_config(
task="embed",
pooler_config=pooler_config,
)
return SimplePooler.from_config(resolved_config, model_config)
@staticmethod
def for_classify(
pooler_config: PoolerConfig,
classify: Optional[ClassifierFn],
):
pass
@abstractmethod
def get_supported_tasks(self) -> Set[PoolingTask]:
"""Determine which pooling tasks are supported."""
raise NotImplementedError
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
"""
Construct the updated pooling parameters to use for a supported task.
"""
return PoolingParamsUpdate()
@abstractmethod
def forward(
self,
hidden_states: Union[list[paddle.Tensor], paddle.Tensor],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
raise NotImplementedError
class BasePoolerActivation(nn.Layer, ABC):
@abstractmethod
def forward(self, pooled_data: _T) -> _T:
# shape:
# classify (& score) -> (batch_size, num_classes)
# embed -> (batch_size, embedding_dim) or list(embedding_dim)
# (batch_size, dimensions) or list(dimensions) if using MRL
raise NotImplementedError
class PoolerActivation(BasePoolerActivation):
@staticmethod
def wraps(module: nn.Layer):
if isinstance(module, nn.Identity):
return PoolerIdentity()
if isinstance(module, (nn.Sigmoid, nn.Softmax)):
return PoolerClassify()
return LambdaPoolerActivation(module)
@abstractmethod
def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor:
raise NotImplementedError
def forward(self, pooled_data: _T) -> _T:
if isinstance(pooled_data, list):
return [self.forward_chunk(data) for data in pooled_data]
return self.forward_chunk(pooled_data)
class PoolerIdentity(PoolerActivation):
def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor:
return pooled_data
class PoolerClassify(PoolerActivation):
def __init__(self, *, static_num_labels: bool = True) -> None:
super().__init__()
if static_num_labels:
fd_config = FDConfig()
self.num_labels = getattr(fd_config.model_config, "num_labels", 0)
if self.num_labels == 0:
logger.warning(
"num_labels should be > 0 for classification"
"models, falling back to softmax. "
"Please check if the configuration is correct."
)
else:
self.num_labels = None
def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor:
num_labels = self.num_labels if self.num_labels is not None else pooled_data.shape[-1]
if num_labels < 2:
return F.sigmoid(pooled_data.astype("float32")).astype(pooled_data.dtype)
return F.softmax(pooled_data.astype("float32"), axis=-1).astype(pooled_data.dtype)
class LambdaPoolerActivation(PoolerActivation):
def __init__(self, fn: Callable[[paddle.Tensor], paddle.Tensor]):
super().__init__()
self.fn = fn
def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor:
return self.fn(pooled_data)
class PoolerHead(nn.Layer):
def __init__(self, activation: PoolerActivation) -> None:
super().__init__()
self.activation = activation
def forward(self, pooled_data: Union[list[paddle.Tensor], paddle.Tensor], pooling_metadata: PoolingMetadata):
return self.activation(pooled_data)
class EmbeddingPoolerHead(PoolerHead):
def __init__(self, model_config: Optional["ModelConfig"] = None) -> None:
super().__init__(activation=PoolerNormalize())
self.projector = _load_st_projector(model_config)
def forward(self, pooled_data: Union[list[paddle.Tensor], paddle.Tensor], pooling_metadata: PoolingMetadata):
if isinstance(pooled_data, list):
pooled_data = paddle.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension]
# Apply ST projector
if self.projector is not None:
projector = cast(nn.Layer, self.projector)
def _proj(x: paddle.Tensor) -> paddle.Tensor:
orig_dtype = x.dtype
y = projector(x.astype("float32"))
return y.astype(orig_dtype)
pooled_data = _proj(pooled_data)
# pooled_data shape: [batchsize, embedding_dimension]
pooling_params = get_pooling_params(pooling_metadata)
# for matryoshka representation
dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
if any(d is not None for d in dimensions_list):
# change the output dimension
assert len(pooled_data) == len(dimensions_list)
if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list):
# if all dimensions are the same
d = dimensions_list[0]
pooled_data = pooled_data[..., :d]
else:
pooled_data = [vecs if d is None else vecs[..., :d] for vecs, d in zip(pooled_data, dimensions_list)]
# for normalize
flags = [p.normalize for p in pooling_params]
if len(set(flags)) == 1:
if flags[0]:
pooled_data = self.activation(pooled_data)
else:
pooled_data = [self.activation(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)]
# pooled_data shape: [batchsize, embedding_dimension]
return pooled_data
class RewardPoolerHead(PoolerHead):
def __init__(self, model_config: Optional["ModelConfig"] = None) -> None:
super().__init__(activation=PoolerClassify(static_num_labels=False))
self.model_config = model_config
def forward(self, pooled_data: Union[list[paddle.Tensor], paddle.Tensor], pooling_metadata: PoolingMetadata):
pooling_params = get_pooling_params(pooling_metadata)
# for softmax
flags = [p.softmax for p in pooling_params]
if len(set(flags)) == 1:
if flags[0]:
pooled_data = self.activation(pooled_data)
else:
pooled_data = [self.activation(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)]
return pooled_data
def build_output(
all_data: Union[paddle.Tensor, list[paddle.Tensor]],
) -> PoolerOutput:
# Pooling models D2H & synchronize occurs here
if isinstance(all_data, list):
all_data = [d.cpu() for d in all_data]
else:
all_data = all_data.cpu()
all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
return PoolerOutput(outputs=all_outputs)
class PoolingMethod(nn.Layer, ABC):
@staticmethod
def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
if pooling_type == PoolingType.LAST:
return LastPool()
if pooling_type == PoolingType.ALL:
return AllPool()
if pooling_type == PoolingType.CLS:
return CLSPool()
if pooling_type == PoolingType.MEAN:
return MeanPool()
raise NotImplementedError(f"Unsupported method: {pooling_type}")
class LastPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
def forward_all(
self,
hidden_states: paddle.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[paddle.Tensor], paddle.Tensor]:
return hidden_states[pooling_cursor.last_token_indices_gpu]
class AllPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode"}
def forward_all(
self,
hidden_states: paddle.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[paddle.Tensor], paddle.Tensor]:
assert not pooling_cursor.is_partial_prefill(), "partial prefill not supported with ALL pooling"
hidden_states_lst = list(hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist()))
return [hidden_states_lst[i] for i in pooling_cursor.index]
class MeanPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
def forward_all(
self,
hidden_states: paddle.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[paddle.Tensor], paddle.Tensor]:
assert not pooling_cursor.is_partial_prefill(), "partial prefill not supported with MEAN pooling"
if hidden_states.place.is_gpu_place():
prompt_lens = pooling_cursor.prompt_lens_cpu.cuda()
else:
prompt_lens = pooling_cursor.prompt_lens_cpu
# Use float32 for paddle.cumsum in MeanPool,
# otherwise precision will be lost significantly.
cumsum = paddle.cumsum(hidden_states.astype("float32"), axis=0)
start_indices = pooling_cursor.first_token_indices_gpu
end_indices = pooling_cursor.last_token_indices_gpu
return (cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
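A tiny numeric check of the cumsum trick used by `MeanPool` (illustrative shapes only):

```python
import paddle

# mean over rows [first, last] == (cumsum[last] - cumsum[first] + h[first]) / length
h = paddle.arange(12, dtype="float32").reshape([6, 2])   # 6 tokens, hidden size 2
cumsum = paddle.cumsum(h, axis=0)
first, last, length = 2, 4, 3                             # one prompt spanning rows 2..4
mean_trick = (cumsum[last] - cumsum[first] + h[first]) / length
assert paddle.allclose(mean_trick, h[2:5].mean(axis=0))
```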
class CLSPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
def forward_all(
self,
hidden_states: paddle.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[paddle.Tensor], paddle.Tensor]:
assert not pooling_cursor.is_partial_prefill(), "partial prefill not supported with CLS pooling"
return hidden_states[pooling_cursor.first_token_indices_gpu]
class StepPooler(Pooler):
def __init__(
self,
) -> None:
super().__init__()
self.pooling = AllPool()
self.head = RewardPoolerHead()
def extract_states(
self,
hidden_states: Union[paddle.Tensor, list[paddle.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[paddle.Tensor], paddle.Tensor]:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = get_prompt_token_ids(pooling_metadata)
pooled_data = list[paddle.Tensor]()
pooling_params = get_pooling_params(pooling_metadata)
for data, token_id, pooling_param in zip(pooled_data_lst, prompt_token_ids, pooling_params):
step_tag_id = pooling_param.step_tag_id
returned_token_ids = pooling_param.returned_token_ids
if returned_token_ids is not None and len(returned_token_ids) > 0:
data = data[:, returned_token_ids]
if step_tag_id is not None:
data = data[token_id == step_tag_id]
pooled_data.append(data)
return pooled_data
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode"}
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return PoolingParamsUpdate(requires_token_ids=True)
def forward(
self,
hidden_states: Union[paddle.Tensor, list[paddle.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return build_output(pooled_data)
class SimplePooler(Pooler):
"""A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.
"""
@classmethod
def from_config(
cls,
pooler_config: ResolvedPoolingConfig,
model_config: Optional["ModelConfig"] = None,
) -> "SimplePooler":
pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
if pooler_config.task == "embed":
head = EmbeddingPoolerHead(model_config)
elif pooler_config.task == "encode":
head = RewardPoolerHead(model_config)
else:
raise NotImplementedError(f"Unknown task: {pooler_config.task}")
return cls(pooling, head)
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
super().__init__()
self.pooling = pooling
self.head = head
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def forward(
self,
hidden_states: Union[paddle.Tensor, list[paddle.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return build_output(pooled_data)
class PoolerNormalize(PoolerActivation):
def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor:
x = F.normalize(pooled_data.astype("float32"), p=2, axis=-1)
return x.astype(pooled_data.dtype)
class DispatchPooler(Pooler):
"""Dispatches calls to a sub-pooler based on the pooling task."""
def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
super().__init__()
for task, pooler in poolers_by_task.items():
if task not in pooler.get_supported_tasks():
raise ValueError(
f"{pooler=} does not support {task=}. " f"Supported tasks: {pooler.get_supported_tasks()}"
)
self.poolers_by_task = poolers_by_task
def get_supported_tasks(self) -> Set[PoolingTask]:
return set(self.poolers_by_task)
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.poolers_by_task[task].get_pooling_updates(task)
def forward(
self,
hidden_states: Union[paddle.Tensor, list[paddle.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
poolers_by_task = self.poolers_by_task
outputs = list[PoolingSequenceGroupOutput]()
offset = 0
for task, group in groupby(get_tasks(pooling_metadata)):
if not (pooler := poolers_by_task.get(task)):
raise ValueError(f"Unsupported task: {task} " f"Supported tasks: {self.get_supported_tasks()}")
num_items = len(list(group))
group_output: PoolerOutput = pooler(
hidden_states,
pooling_metadata[offset : offset + num_items],
)
outputs.extend(group_output.outputs)
offset += num_items
return PoolerOutput(outputs)
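A hedged sketch of how the pieces above are wired together for an embedding model; `pooler_config` and `model_config` are assumed to come from an `FDConfig`, as in the adapter code later in this PR:

```python
# Mirrors ModelForEmbedding._init_pooler in the adapters module below.
pooler = DispatchPooler(
    {
        "encode": Pooler.for_encode(pooler_config, model_config),
        "embed": Pooler.for_embed(pooler_config, model_config),
    }
)
# With pooler_config.pooling_type == "LAST", Pooler.for_embed resolves to
# SimplePooler(LastPool(), EmbeddingPoolerHead(model_config)), and forward()
# returns a PoolerOutput of normalized sentence embeddings.
```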

View File

@@ -61,6 +61,7 @@ class DefaultModelLoader(BaseModelLoader):
fd_config,
return_numpy=True,
)
model.set_state_dict(state_dict)
self.clean_memory_fragments(state_dict)

View File

@@ -16,7 +16,7 @@
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from typing_extensions import assert_never
from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
from fastdeploy.model_executor.load_weight_utils import (
@@ -27,6 +27,7 @@ from fastdeploy.model_executor.load_weight_utils import (
save_model,
)
from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
from fastdeploy.model_executor.models.adapters import as_embedding_model
from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.platforms import current_platform
@@ -54,11 +55,11 @@ class DefaultModelLoaderV1(BaseModelLoader):
load_weights_form_cache(model, weights_iterator)
else:
model.load_weights(weights_iterator)
self.clean_memory_fragments()
def load_model(self, fd_config: FDConfig) -> nn.Layer:
architectures = fd_config.model_config.architectures[0]
logger.info(f"Starting to load model {architectures}")
context = paddle.LazyGuard()
if fd_config.load_config.dynamic_load_weight:
# register rl model
@@ -70,6 +71,14 @@ class DefaultModelLoaderV1(BaseModelLoader):
with weight_cache_context:
with context:
model_cls = ModelRegistry.get_class(architectures)
convert_type = fd_config.model_config.convert_type
if convert_type == "none":
pass
elif convert_type == "embed":
model_cls = as_embedding_model(model_cls)
else:
assert_never(convert_type)
model = model_cls(fd_config)
model.eval()

View File

@@ -47,8 +47,10 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode
module = importlib.import_module(f"{register_path}.{module_file}")
for attr_name in dir(module):
attr = getattr(module, attr_name)
if inspect.isclass(attr) and issubclass(attr, ModelForCasualLM) and attr is not ModelForCasualLM:
ModelRegistry.register_model_class(attr)
if (
inspect.isclass(attr)
and issubclass(attr, PretrainedModel)
@@ -56,6 +58,7 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode
and hasattr(attr, "arch_name")
):
ModelRegistry.register_pretrained_model(attr)
except ImportError:
raise ImportError(f"{module_file=} import error")

View File

@@ -0,0 +1,214 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from collections.abc import Iterable
from typing import Optional, TypeVar
import paddle
import paddle.nn as nn
from fastdeploy.config import ModelConfig
from fastdeploy.model_executor.layers.activation import get_act_fn
from fastdeploy.model_executor.models.interfaces_base import is_pooling_model
from fastdeploy.transformer_utils.config import get_hf_file_to_dict
_T = TypeVar("_T", bound=type[nn.Layer])
_GENERATE_SUFFIXES = [
"ForCausalLM",
"ForConditionalGeneration",
"ChatModel",
"LMHeadModel",
]
def _load_dense_weights(linear: nn.Linear, folder: str, model_config: "ModelConfig") -> bool:
"""Load weights using vLLM's weight_loader pattern."""
from fastdeploy.model_executor.utils import default_weight_loader
filename = "model.safetensors"
file_path = f"{folder}/{filename}" if folder else filename
try:
file_bytes = get_hf_file_to_dict(file_path, model_config.model, model_config.revision)
if not file_bytes:
return False
state_dict = {}
if filename.endswith(".safetensors"):
import io
from safetensors.numpy import load as load_safetensors
numpy_tensors = load_safetensors(io.BytesIO(file_bytes))
for key, numpy_array in numpy_tensors.items():
state_dict[key] = paddle.to_tensor(numpy_array)
else:
import io
state_dict = paddle.load(io.BytesIO(file_bytes))
weight_keys = ["weight", "linear.weight", "dense.weight"]
for weight_key in weight_keys:
if weight_key in state_dict:
weight_loader = getattr(linear.weight, "weight_loader", default_weight_loader)
weight_loader(linear.weight, state_dict[weight_key].astype(paddle.float32))
bias_key = weight_key.replace("weight", "bias")
if linear.bias is not None and bias_key in state_dict:
bias_loader = getattr(linear.bias, "weight_loader", default_weight_loader)
bias_loader(linear.bias, state_dict[bias_key].astype(paddle.float32))
return True
except Exception as e:
print(f"Failed to load :{e}")
return False
return False
def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Layer]:
try:
modules = get_hf_file_to_dict("modules.json", model_config.model, model_config.revision)
if not modules:
return None
if isinstance(modules, dict):
modules = modules.get("modules", [])
dense_modules = [m for m in modules if m.get("type") == "sentence_transformers.models.Dense"]
if not dense_modules:
return None
layers = []
for module in dense_modules:
folder = module.get("path", "")
config_path = f"{folder}/config.json" if folder else "config.json"
layer_config = get_hf_file_to_dict(config_path, model_config.model, model_config.revision)
if not layer_config:
continue
linear = nn.Linear(
layer_config.get("in_features", 768),
layer_config.get("out_features", 768),
bias=layer_config.get("bias", True),
)
linear = linear.astype(paddle.float32)
if not _load_dense_weights(linear, folder, model_config):
continue
layers.append(linear)
if act_name := layer_config.get("activation_function"):
layers.append(get_act_fn(act_name))
return nn.Sequential(*layers).astype(paddle.float32)
except Exception as e:
print(f"ST projector loading failed:{e}")
return None
def _create_pooling_model_cls(orig_cls: _T) -> _T:
class ModelForPooling(orig_cls):
def __init__(self, fd_config, *args, **kwargs):
super().__init__(fd_config, *args, **kwargs)
self.fd_config = fd_config
self.is_pooling_model = True
# These are not used in pooling models
for attr in ("lm_head", "logits_processor"):
if hasattr(self, attr):
delattr(self, attr)
# If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "pooler", None):
self._init_pooler(fd_config)
def _init_pooler(self, fd_config):
raise NotImplementedError
def load_weights(self, weights: Iterable[tuple[str, paddle.Tensor]]):
# TODO: Support uninitialized params tracking
# We have deleted this attribute, so don't load it
weights = ((name, data) for name, data in weights if not name.startswith("lm_head."))
# If `*ForCausalLM` defines `load_weights` on the inner model
# and there are no other inner modules with parameters,
# we support loading from both `*Model` and `*ForCausalLM`
if hasattr(self, "model") and hasattr(self.model, "load_weights"):
# Whether only `self.model` contains parameters
model_is_only_param = all(
name == "model" or not any(child.parameters()) for name, child in self.named_children()
)
if model_is_only_param:
weights = ((name[6:], data) for name, data in weights if name.startswith("model."))
loaded_params = self.model.load_weights(weights)
loaded_params = {f"model.{name}" for name in loaded_params}
return loaded_params
# For most other models
if hasattr(orig_cls, "load_weights"):
return orig_cls.load_weights(self, weights) # type: ignore
# Fallback
else:
raise ValueError("No load_weights method found in the model.")
return ModelForPooling
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
model_name = orig_model_name
for generate_suffix in _GENERATE_SUFFIXES:
model_name = model_name.removesuffix(generate_suffix)
return model_name + pooling_suffix
def as_embedding_model(cls: _T) -> _T:
"""
Subclass an existing FastDeploy model to support embeddings.
By default, the embeddings of the whole prompt are extracted from the
normalized hidden state corresponding to the last token.
Note:
We assume that no extra layers are added to the original model;
please implement your own model if this is not the case.
"""
# Avoid modifying existing embedding models
if is_pooling_model(cls):
return cls
from fastdeploy.model_executor.layers.pooler import DispatchPooler, Pooler
class ModelForEmbedding(_create_pooling_model_cls(cls)):
def _init_pooler(self, fd_config, prefix: str = ""):
pooler_config = fd_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config, fd_config.model_config),
"embed": Pooler.for_embed(pooler_config, fd_config.model_config),
},
)
ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")
return ModelForEmbedding
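A hypothetical usage of `as_embedding_model`, mirroring what `DefaultModelLoaderV1` does when `convert_type == "embed"` (the Qwen3 class name and `fd_config` object here are assumptions for illustration):

```python
from fastdeploy.model_executor.models.adapters import as_embedding_model

Qwen3ForEmbedding = as_embedding_model(Qwen3ForCausalLM)   # name: "Qwen3" + "ForEmbedding"
model = Qwen3ForEmbedding(fd_config)   # pooler built from fd_config.model_config.pooler_config
```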

View File

@@ -48,7 +48,11 @@ from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding,
)
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
ModelRegistry,
)
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
@@ -588,6 +592,12 @@ class DeepSeekV3Model(nn.Layer):
return out
@ModelRegistry.register_model_class(
architecture="DeepseekV3ForCausalLM",
module_path="deepseek_v3",
category=ModelCategory.TEXT_GENERATION,
primary_use=ModelCategory.TEXT_GENERATION,
)
class DeepseekV3ForCausalLM(ModelForCasualLM):
"""
DeepseekV3ForCausalLM

View File

@@ -45,7 +45,11 @@ from fastdeploy.model_executor.layers.linear import (
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
ModelRegistry,
)
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
from fastdeploy.model_executor.models.utils import WeightMeta
@@ -471,6 +475,12 @@ class Ernie4_5_Model(nn.Layer):
return out
@ModelRegistry.register_model_class(
architecture="Ernie4_5_MoeForCausalLM",
module_path="ernie4_5_moe",
category=ModelCategory.TEXT_GENERATION,
primary_use=ModelCategory.TEXT_GENERATION,
)
class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
"""
Ernie4_5_MoeForCausalLM
@@ -646,6 +656,12 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
@ModelRegistry.register_model_class(
architecture="Ernie4_5_ForCausalLM",
module_path="ernie4_5_moe",
category=ModelCategory.TEXT_GENERATION,
primary_use=ModelCategory.TEXT_GENERATION,
)
class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
"""
Ernie4_5_ForCausalLM
@@ -659,6 +675,12 @@ class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
return "Ernie4_5_ForCausalLM"
@ModelRegistry.register_model_class(
architecture="Ernie4_5ForCausalLM",
module_path="ernie4_5_moe",
category=ModelCategory.TEXT_GENERATION,
primary_use=ModelCategory.TEXT_GENERATION,
)
class Ernie4_5ForCausalLM(Ernie4_5_ForCausalLM):
"""
Ernie4_5ForCausalLM 0.3B-PT

View File

@@ -31,7 +31,11 @@ from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
ModelRegistry,
)
class Ernie4_5_MTPPretrainedModel(PretrainedModel):
@@ -325,6 +329,12 @@ class Ernie4_5_MTPModel(nn.Layer):
return hidden_states
@ModelRegistry.register_model_class(
architecture="Ernie4_5_MTPForCausalLM",
module_path="ernie4_5_mtp",
category=ModelCategory.TEXT_GENERATION,
primary_use=ModelCategory.TEXT_GENERATION,
)
class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
"""
Ernie4_5_MTPForCausalLM

View File

@@ -44,7 +44,11 @@ from fastdeploy.model_executor.models.ernie4_5_moe import (
Ernie4_5_Attention,
Ernie4_5_MLP,
)
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
ModelRegistry,
)
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
@@ -792,6 +796,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
@ModelRegistry.register_model_class(
architecture="Ernie4_5_VLMoeForConditionalGeneration",
module_path="ernie4_5_vl.ernie4_5_vl_moe",
category=ModelCategory.MULTIMODAL,
primary_use=ModelCategory.MULTIMODAL,
)
class Ernie4_5_VLPretrainedModel(PretrainedModel):
"""
Ernie4_5_MoePretrainedModel

View File

@@ -39,7 +39,11 @@ from fastdeploy.model_executor.layers.linear import (
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
ModelRegistry,
)
class Glm4MoeMLP(nn.Layer):
@@ -363,6 +367,12 @@ class Glm4MoeModel(nn.Layer):
return out
@ModelRegistry.register_model_class(
architecture="Glm4MoeForCausalLM",
module_path="glm4_moe",
category=ModelCategory.TEXT_GENERATION,
primary_use=ModelCategory.TEXT_GENERATION,
)
class Glm4MoeForCausalLM(ModelForCasualLM):
"""
Glm4MoeForCausalLM

View File

@@ -0,0 +1,54 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Type
from paddle import nn
def is_text_generation_model(model_cls: Type[nn.Layer]) -> bool:
from .model_base import ModelForCasualLM
return issubclass(model_cls, ModelForCasualLM)
def is_pooling_model(model_cls: Type[nn.Layer]) -> bool:
class_name = model_cls.__name__
pooling_indicators = ["Embedding", "ForSequenceClassification"]
return (
any(indicator in class_name for indicator in pooling_indicators)
or hasattr(model_cls, "is_embedding_model")
and model_cls.is_embedding_model
)
def is_multimodal_model(class_name: str) -> bool:
multimodal_indicators = ["VL", "Vision", "ConditionalGeneration"]
return any(indicator in class_name for indicator in multimodal_indicators)
def determine_model_category(class_name: str):
from fastdeploy.model_executor.models.model_base import ModelCategory
if any(pattern in class_name for pattern in ["VL", "Vision", "ConditionalGeneration"]):
return ModelCategory.MULTIMODAL
elif any(pattern in class_name for pattern in ["Embedding", "ForSequenceClassification"]):
return ModelCategory.EMBEDDING
return ModelCategory.TEXT_GENERATION
def get_default_pooling_type(model_cls: Type[nn.Layer] = None) -> str:
if model_cls is not None:
return getattr(model_cls, "default_pooling_type", "LAST")
return "LAST"

View File

@@ -3,40 +3,269 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import importlib
from abc import ABC, abstractmethod
from typing import Dict, Union
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from typing import Dict, List, Optional, Tuple, Type, Union
import numpy as np
import paddle
from paddle import nn
from paddleformers.transformers import PretrainedModel
from fastdeploy.config import (
ModelConfig,
iter_architecture_defaults,
try_match_architecture_defaults,
)
from fastdeploy.model_executor.models.interfaces_base import (
determine_model_category,
get_default_pooling_type,
is_multimodal_model,
is_pooling_model,
is_text_generation_model,
)
class ModelCategory(Enum):
TEXT_GENERATION = "text_generation"
MULTIMODAL = "multimodal"
EMBEDDING = "embedding"
@dataclass(frozen=True)
class ModelInfo:
architecture: str
category: ModelCategory
is_text_generation: bool
is_multimodal: bool
is_pooling: bool
module_path: str
default_pooling_type: str
@staticmethod
def from_model_cls(model_cls: Type[nn.Layer], module_path: str = "") -> "ModelInfo":
return ModelInfo(
architecture=model_cls.__name__,
category=determine_model_category(model_cls.__name__),
is_text_generation=is_text_generation_model(model_cls),
is_multimodal=is_multimodal_model(model_cls.__name__),
is_pooling=is_pooling_model(model_cls),
default_pooling_type=get_default_pooling_type(model_cls),
module_path=module_path,
)
class BaseRegisteredModel(ABC):
"""Base class for registered models"""
@abstractmethod
def load_model_cls(self) -> Type[nn.Layer]:
raise NotImplementedError
@abstractmethod
def inspect_model_cls(self) -> ModelInfo:
raise NotImplementedError
@dataclass(frozen=True)
class LazyRegisteredModel(BaseRegisteredModel):
"""Lazy loaded model"""
module_name: str
class_name: str
def load_model_cls(self) -> Type[nn.Layer]:
try:
full_module = f"fastdeploy.model_executor.models.{self.module_name}"
module = importlib.import_module(full_module)
return getattr(module, self.class_name)
except (ImportError, AttributeError) as e:
raise ImportError(f"Failed to load {self.class_name}: {e}")
def inspect_model_cls(self) -> ModelInfo:
model_cls = self.load_model_cls()
return ModelInfo.from_model_cls(model_cls, self.module_name)
@dataclass(frozen=True)
class RegisteredModel(BaseRegisteredModel):
model_cls: Type[nn.Layer]
def load_model_cls(self) -> Type[nn.Layer]:
return self.model_cls
def inspect_model_cls(self) -> ModelInfo:
return ModelInfo.from_model_cls(self.model_cls)
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
model_arch: str,
model: BaseRegisteredModel,
) -> Optional[ModelInfo]:
try:
return model.inspect_model_cls()
except Exception:
print("Error in inspecting model architecture '%s'", model_arch)
return None
class ModelRegistry:
"""
Used to register and retrieve model classes.
"""
_arch_to_model_cls = {}
_arch_to_pretrained_model_cls = {}
_enhanced_models: Dict[str, Dict] = {}
def __init__(self):
self.models: Dict[str, BaseRegisteredModel] = {}
self.pretrained_models: Dict[str, Type[PretrainedModel]] = {}
self._registered_models: Dict[str, BaseRegisteredModel] = {}
self._register_enhanced_models()
def _register_enhanced_models(self):
for arch, model_info in self._enhanced_models.items():
model = LazyRegisteredModel(module_name=model_info["module_path"], class_name=model_info["class_name"])
self.models[arch] = model
self._registered_models[arch] = model
@lru_cache(maxsize=128)
def _try_load_model_cls(self, architecture: str) -> Optional[Type[nn.Layer]]:
if architecture not in self.models:
return None
try:
return self.models[architecture].load_model_cls()
except Exception as e:
print(f"Failed to load model {architecture}: {e}")
return None
@lru_cache(maxsize=128)
def _try_inspect_model_cls(self, model_arch: str) -> Optional[ModelInfo]:
if model_arch not in self.models:
return None
try:
return self.models[model_arch].inspect_model_cls()
except Exception as e:
print(f"Failed to inspect model {model_arch}: {e}")
return None
def _normalize_arch(self, architecture: str, model_config: ModelConfig) -> str:
if architecture in self.models:
return architecture
match = try_match_architecture_defaults(
architecture,
runner_type=getattr(model_config, "runner_type", None),
convert_type=getattr(model_config, "convert_type", None),
)
if match:
suffix, _ = match
for repl_suffix, _ in iter_architecture_defaults():
base_arch = architecture.replace(suffix, repl_suffix)
if base_arch in self.models:
return base_arch
return architecture
def _raise_for_unsupported(self, architectures: list[str]):
all_supported_archs = self.get_supported_archs()
if any(arch in all_supported_archs for arch in architectures):
raise ValueError(
f"Model architectures {architectures} failed to be inspected. "
"Please check the logs for more details."
)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {all_supported_archs}"
)
def inspect_model_cls(
self, architectures: Union[str, List[str]], model_config: ModelConfig = None
) -> Tuple[ModelInfo, str]:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
raise ValueError("No model architectures are specified")
for arch in architectures:
normalized_arch = self._normalize_arch(arch, model_config)
model_info = self._try_inspect_model_cls(normalized_arch)
if model_info is not None:
return (model_info, arch)
return self._raise_for_unsupported(architectures)
@classmethod
def register_model_class(cls, model_class):
"""register model class"""
if issubclass(model_class, ModelForCasualLM) and model_class is not ModelForCasualLM:
cls._arch_to_model_cls[model_class.name()] = model_class
return model_class
def register_model_class(
cls,
model_class=None,
*,
architecture: str = None,
module_path: str = None,
category: Union[ModelCategory, List[ModelCategory]] = ModelCategory.TEXT_GENERATION,
primary_use: ModelCategory = None,
):
"""
Enhanced model class registration supporting both traditional and decorator-style registration.
Can be used as:
1. Traditional decorator: @ModelRegistry.register_model_class
2. Enhanced decorator with metadata: @ModelRegistry.register_model_class(architecture="...", module_path="...")
Args:
model_class: The model class (when used as simple decorator)
architecture (str): Unique identifier for the model architecture
module_path (str): Relative path to the module containing the model
category: Model category or list of categories
primary_use: Primary category for multi-category models
"""
def _register(model_cls):
# Traditional registration for ModelForCasualLM subclasses
if issubclass(model_cls, ModelForCasualLM) and model_cls is not ModelForCasualLM:
cls._arch_to_model_cls[model_cls.name()] = model_cls
# Enhanced decorator-style registration
if architecture and module_path:
categories = category if isinstance(category, list) else [category]
# Register main entry
arch_key = architecture
cls._enhanced_models[arch_key] = {
"class_name": model_cls.__name__,
"module_path": module_path,
"category": primary_use or categories[0],
"class": model_cls,
}
# Register category-specific entries for multi-category models
if len(categories) > 1:
for cat in categories:
key = f"{arch_key}_{cat.value}"
cls._enhanced_models[key] = {
"class_name": model_cls.__name__,
"module_path": module_path,
"category": cat,
"primary_use": primary_use or categories[0],
"class": model_cls,
}
return model_cls
if model_class is not None:
return _register(model_class)
else:
return _register
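
As the docstring above notes, the classmethod supports both calling conventions; a hedged sketch with invented classes (the `module_path` below is hypothetical and only meaningful if such a module actually exists under fastdeploy/model_executor/models/):

```python
# Sketch only: MyForCausalLM / MyEmbedForCausalLM and the module_path are invented.
from fastdeploy.model_executor.models.model_base import (
    ModelCategory,
    ModelForCasualLM,
    ModelRegistry,
)


# 1. Traditional decorator: registers a concrete ModelForCasualLM subclass by its name().
@ModelRegistry.register_model_class
class MyForCausalLM(ModelForCasualLM):
    @classmethod
    def name(cls):
        return "MyForCausalLM"


# 2. Enhanced decorator with metadata: additionally records a lazy-import entry
#    (and one extra entry per category when several categories are given).
@ModelRegistry.register_model_class(
    architecture="MyEmbedForCausalLM",
    module_path="my_embed_model",  # hypothetical module path
    category=[ModelCategory.TEXT_GENERATION, ModelCategory.EMBEDDING],
    primary_use=ModelCategory.TEXT_GENERATION,
)
class MyEmbedForCausalLM(ModelForCasualLM):
    @classmethod
    def name(cls):
        return "MyEmbedForCausalLM"
```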
@classmethod
def register_pretrained_model(cls, pretrained_model):
@@ -50,11 +279,6 @@ class ModelRegistry:
return pretrained_model
@classmethod
def get_pretrain_cls(cls, architectures: str):
"""get_pretrain_cls"""
return cls._arch_to_pretrained_model_cls[architectures]
@classmethod
def get_class(cls, name):
"""get model class"""
@@ -62,12 +286,61 @@ class ModelRegistry:
raise ValueError(f"Model '{name}' is not registered!")
return cls._arch_to_model_cls[name]
@classmethod
def get_pretrain_cls(cls, architectures: str):
"""get_pretrain_cls"""
return cls._arch_to_pretrained_model_cls[architectures]
@classmethod
def get_supported_archs(cls):
assert len(cls._arch_to_model_cls) >= len(
cls._arch_to_pretrained_model_cls
), "model class num is more than pretrained model registry num"
return [key for key in cls._arch_to_model_cls.keys()]
traditional_archs = list(cls._arch_to_model_cls.keys())
enhanced_archs = list(cls._enhanced_models.keys())
return traditional_archs + enhanced_archs
def resolve_model_cls(self, architectures: Union[str, List[str]]) -> Tuple[Type[nn.Layer], str]:
"""Resolve model class"""
if isinstance(architectures, str):
architectures = [architectures]
for arch in architectures:
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return model_cls, arch
raise ValueError(f"Cannot find supported model: {architectures}")
def is_multimodal_model(self, architectures: Union[str, List[str]], model_config: ModelConfig = None) -> bool:
"""Check if it's a multimodal model"""
if isinstance(architectures, str):
architectures = [architectures]
for arch in architectures:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None:
return model_info.is_multimodal
return False
def is_text_generation_model(self, architectures: Union[str, List[str]], model_config: ModelConfig = None) -> bool:
"""Check if it's a text generation model"""
if isinstance(architectures, str):
architectures = [architectures]
for arch in architectures:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None:
return model_info.is_text_generation
return False
def is_pooling_model(self, architectures: Union[str, List[str]], model_config: ModelConfig = None) -> bool:
"""Check if it's a pooling model"""
if isinstance(architectures, str):
architectures = [architectures]
for arch in architectures:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None:
return model_info.is_pooling
return False
class ModelForCasualLM(nn.Layer, ABC):
@@ -88,7 +361,6 @@ class ModelForCasualLM(nn.Layer, ABC):
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
"""
Load model parameters from a given state dictionary.
Args:
state_dict (dict[str, np.ndarray | paddle.Tensor]):
A dictionary containing model parameters, where keys are parameter names
@@ -105,12 +377,10 @@ class ModelForCasualLM(nn.Layer, ABC):
):
"""
Defines the forward pass of the model for generating text.
Args:
input_ids (Tensor, optional): The input token ids to the model.
pos_emb (Tensor, optional): position Embeddings for model.
**model_kwargs: Additional keyword arguments for the model.
Returns:
Tensor or list of Tensors: Generated tokens or decoded outputs.
"""

View File

@@ -39,7 +39,11 @@ from fastdeploy.model_executor.layers.linear import (
)
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
ModelRegistry,
)
class Qwen2MLP(nn.Layer):
@@ -282,6 +286,12 @@ class Qwen2Model(nn.Layer):
return out
@ModelRegistry.register_model_class(
architecture="Qwen2ForCausalLM",
module_path="qwen2",
category=[ModelCategory.TEXT_GENERATION, ModelCategory.EMBEDDING],
primary_use=ModelCategory.TEXT_GENERATION,
)
class Qwen2ForCausalLM(ModelForCasualLM):
"""
Qwen2ForCausalLM

View File

@@ -33,7 +33,11 @@ from fastdeploy.model_executor.graph_optimization.decorator import (
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
ModelRegistry,
)
from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer
from fastdeploy.platforms import current_platform
@@ -157,6 +161,12 @@ class Qwen2_5_VLModel(nn.Layer):
return out
@ModelRegistry.register_model_class(
architecture="Qwen2_5_VLForConditionalGeneration",
module_path="qwen2_5_vl.qwen2_5_vl",
category=ModelCategory.MULTIMODAL,
primary_use=ModelCategory.MULTIMODAL,
)
class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
"""
Qwen2_5_VLForConditionalGeneration

View File

@@ -34,8 +34,13 @@ from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
ModelRegistry,
)
from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2MLP
from fastdeploy.transformer_utils.config import get_pooling_config
class Qwen3MLP(Qwen2MLP):
@@ -218,6 +223,12 @@ class Qwen3Model(nn.Layer):
return out
@ModelRegistry.register_model_class(
architecture="Qwen3ForCausalLM",
module_path="qwen3",
category=[ModelCategory.TEXT_GENERATION],
primary_use=ModelCategory.TEXT_GENERATION,
)
class Qwen3ForCausalLM(ModelForCasualLM):
"""
Qwen3ForCausalLM
@@ -260,6 +271,8 @@ class Qwen3ForCausalLM(ModelForCasualLM):
process_weights_after_loading,
)
is_pooling_model = hasattr(self, "is_pooling_model") and self.is_pooling_model
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -270,8 +283,18 @@ class Qwen3ForCausalLM(ModelForCasualLM):
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
]
params_dict = dict(self.named_parameters())
model_path = self.fd_config.model_config.model
revision = self.fd_config.model_config.revision
if is_pooling_model and get_pooling_config(model_path, revision):
params_dict = {
param_name[6:] if param_name.startswith("model.") else param_name: param
for param_name, param in params_dict.items()
}
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
for loaded_weight_name, loaded_weight in weights_iterator:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in loaded_weight_name:
@@ -282,6 +305,7 @@ class Qwen3ForCausalLM(ModelForCasualLM):
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight, shard_id)
break
else:
model_param_name = loaded_weight_name
@@ -290,10 +314,11 @@ class Qwen3ForCausalLM(ModelForCasualLM):
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
if self.tie_word_embeddings and not is_pooling_model:
self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})
@paddle.no_grad()
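
The prefix stripping above lets pooling checkpoints whose parameters are saved without the top-level `model.` wrapper still match FastDeploy's parameter names; a quick illustration with invented keys:

```python
# Invented parameter names, only to show the renaming applied for pooling models above.
params_dict = {
    "model.layers.0.self_attn.qkv_proj.weight": object(),
    "lm_head.linear.weight": object(),
}
params_dict = {
    name[6:] if name.startswith("model.") else name: param
    for name, param in params_dict.items()
}
print(list(params_dict))
# ['layers.0.self_attn.qkv_proj.weight', 'lm_head.linear.weight']
```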

View File

@@ -39,7 +39,11 @@ from fastdeploy.model_executor.layers.linear import (
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
ModelRegistry,
)
from fastdeploy.model_executor.models.qwen3 import Qwen3Attention
@@ -316,6 +320,12 @@ class Qwen3MoeModel(nn.Layer):
return out
@ModelRegistry.register_model_class(
architecture="Qwen3MoeForCausalLM",
module_path="qwen3moe",
category=ModelCategory.TEXT_GENERATION,
primary_use=ModelCategory.TEXT_GENERATION,
)
class Qwen3MoeForCausalLM(ModelForCasualLM):
"""
Qwen3MoeForCausalLM

View File

@@ -158,6 +158,7 @@ def default_weight_loader(fd_config: FDConfig) -> None:
def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
"""fn"""
output_dim = getattr(param, "output_dim", None)
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if weight_need_transpose:

View File

@@ -0,0 +1,69 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Any
import msgspec
import paddle
class PoolingSequenceGroupOutput(
msgspec.Struct,
omit_defaults=True,
array_like=True,
):
"""The model output associated with a pooling sequence group."""
# Annotated as Any to be compatible with msgspec
# The actual type is in SequenceGroup.pooled_data
data: Any
def get_data_nbytes(self) -> int:
if isinstance(self.data, paddle.Tensor):
return self.data.numel() * self.data.element_size()
elif hasattr(self.data, "nbytes"):
return self.data.nbytes
else:
return 0
def __repr__(self) -> str:
return f"PoolingSequenceGroupOutput(data={self.data}"
def __eq__(self, other: object) -> bool:
if not isinstance(other, PoolingSequenceGroupOutput):
raise NotImplementedError()
return self.data == other.data
class PoolerOutput(msgspec.Struct, omit_defaults=True, array_like=True):
"""The output from a pooling operation in the pooling model."""
outputs: list[PoolingSequenceGroupOutput]
def get_data_nbytes(self) -> int:
return sum(o.get_data_nbytes() for o in self.outputs)
def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
return self.outputs[idx]
def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
self.outputs[idx] = value
def __len__(self):
return len(self.outputs)
def __eq__(self, other: object):
return isinstance(other, self.__class__) and self.outputs == other.outputs
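
A minimal usage sketch for these structs, assuming they are imported from wherever this new module lives in the tree; the embedding size is arbitrary.

```python
# Illustration only: wrap two pooled embeddings and measure their payload size.
import paddle

outputs = [
    PoolingSequenceGroupOutput(data=paddle.randn([1024], dtype="float32")),
    PoolingSequenceGroupOutput(data=paddle.randn([1024], dtype="float32")),
]
pooler_output = PoolerOutput(outputs=outputs)

print(len(pooler_output))                  # 2
print(pooler_output[0].get_data_nbytes())  # 4096 bytes (1024 float32 values)
print(pooler_output.get_data_nbytes())     # 8192 bytes in total
```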

View File

@@ -0,0 +1,15 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

View File

@@ -0,0 +1,139 @@
import json
from pathlib import Path
from typing import Optional, Union
import huggingface_hub
from huggingface_hub import hf_hub_download, try_to_load_from_cache
from huggingface_hub.utils import (
EntryNotFoundError,
HfHubHTTPError,
LocalEntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from fastdeploy.utils import get_logger
logger = get_logger("transformer_config", "transformer_config.log")
def file_or_path_exists(model, config_name):
if (local_path := Path(model)).exists():
return (local_path / config_name).is_file()
return False
def get_pooling_config_name(pooling_name: str):
if "pooling_mode_" in pooling_name:
pooling_name = pooling_name.replace("pooling_mode_", "")
if "_" in pooling_name:
pooling_name = pooling_name.split("_")[0]
if "lasttoken" in pooling_name:
pooling_name = "last"
supported_pooling_types = ["LAST", "ALL", "CLS", "STEP", "MEAN"]
pooling_type_name = pooling_name.upper()
if pooling_type_name in supported_pooling_types:
return pooling_type_name
raise NotImplementedError(f"Pooling type {pooling_type_name} not supported")
def try_get_local_file(model: Union[str, Path], file_name: str, revision: Optional[str] = "main") -> Optional[Path]:
file_path = Path(model) / file_name
if file_path.is_file():
return file_path
else:
try:
cached_filepath = try_to_load_from_cache(repo_id=model, filename=file_name, revision=revision)
if isinstance(cached_filepath, str):
return Path(cached_filepath)
except ValueError:
...
return None
def get_hf_file_to_dict(file_name: str, model: Union[str, Path], revision: Optional[str] = "main"):
"""
Downloads a file from the Hugging Face Hub and returns
its contents as a dictionary.
Parameters:
- file_name (str): The name of the file to download.
- model (str): The name of the model on the Hugging Face Hub.
- revision (str): The specific version of the model.
Returns:
- config_dict (dict): A dictionary containing
the contents of the downloaded file.
"""
file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)
if file_path is None:
try:
hf_hub_file = hf_hub_download(model, file_name, revision=revision)
except huggingface_hub.errors.OfflineModeIsEnabled:
return None
except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError, LocalEntryNotFoundError) as e:
logger.debug("File or repository not found in hf_hub_download", e)
return None
except HfHubHTTPError as e:
logger.warning(
"Cannot connect to Hugging Face Hub. Skipping file " "download for '%s':", file_name, exc_info=e
)
return None
file_path = Path(hf_hub_file)
if file_path is not None and file_path.is_file():
with open(file_path) as file:
return json.load(file)
return None
def get_pooling_config(model: str, revision: Optional[str] = "main"):
"""
This function gets the pooling and normalize
config from the model - only applies to
sentence-transformers models.
Args:
model (str): The name of the Hugging Face model.
revision (str, optional): The specific version
of the model to use. Defaults to 'main'.
Returns:
dict: A dictionary containing the pooling
type and whether normalization is used.
"""
modules_file_name = "modules.json"
modules_dict = None
if file_or_path_exists(model, config_name=modules_file_name):
modules_dict = get_hf_file_to_dict(modules_file_name, model)
if modules_dict is None:
return None
pooling = next((item for item in modules_dict if item["type"] == "sentence_transformers.models.Pooling"), None)
normalize = bool(
next((item for item in modules_dict if item["type"] == "sentence_transformers.models.Normalize"), False)
)
if pooling:
pooling_file_name = "{}/config.json".format(pooling["path"])
pooling_dict = get_hf_file_to_dict(pooling_file_name, model)
pooling_type_name = next((item for item, val in pooling_dict.items() if val is True), None)
if pooling_type_name is not None:
pooling_type_name = get_pooling_config_name(pooling_type_name)
return {"pooling_type": pooling_type_name, "normalize": normalize}
return None
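
For a sentence-transformers style checkpoint laid out like the hypothetical example below, the function would report the pooling mode and whether embeddings are normalized:

```python
# Hypothetical local checkpoint (sentence-transformers layout); paths and contents are invented.
#   my-embedding-model/
#     modules.json          -> [..., {"type": "sentence_transformers.models.Pooling", "path": "1_Pooling"},
#                                     {"type": "sentence_transformers.models.Normalize"}]
#     1_Pooling/config.json -> {"pooling_mode_mean_tokens": true, "pooling_mode_cls_token": false}
from fastdeploy.transformer_utils.config import get_pooling_config

print(get_pooling_config("my-embedding-model"))
# {'pooling_type': 'MEAN', 'normalize': True}
```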

View File

@@ -51,6 +51,7 @@ from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse
from fastdeploy.logger.logger import FastDeployLogger
T = TypeVar("T")
from typing import Callable, Optional
# [N,2] -> every line is [config_name, enable_xxx_name]
# Make sure enable_xxx equals config.enable_xxx
@@ -852,3 +853,24 @@ api_server_logger = get_logger("api_server", "api_server.log")
console_logger = get_logger("console", "console.log", print_to_console=True)
spec_logger = get_logger("speculate", "speculate.log")
zmq_client_logger = get_logger("zmq_client", "zmq_client.log")
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
def _parse_type(val: str) -> T:
try:
return return_type(val)
except ValueError as e:
raise argparse.ArgumentTypeError(f"Value {val} cannot be converted to {return_type}.") from e
return _parse_type
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
def _optional_type(val: str) -> Optional[T]:
if val == "" or val == "None":
return None
return parse_type(return_type)(val)
return _optional_type
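
These helpers back the new `--override-pooler-config` argument below; a small sketch of how `optional_type(json.loads)` behaves with argparse:

```python
import argparse
import json

from fastdeploy.utils import optional_type

parser = argparse.ArgumentParser()
parser.add_argument("--override-pooler-config", type=optional_type(json.loads), default=None)

args = parser.parse_args(["--override-pooler-config", '{"pooling_type": "MEAN", "normalize": true}'])
print(args.override_pooler_config)  # {'pooling_type': 'MEAN', 'normalize': True}

args = parser.parse_args(["--override-pooler-config", "None"])
print(args.override_pooler_config)  # None ("" and "None" are treated as unset)
```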

View File

@@ -1319,8 +1319,12 @@ class GPUModelRunner(ModelRunnerBase):
self.parallel_config.max_model_len,
)
# 4. Execute spec decode
logits = self.model.compute_logits(hidden_states)
logits = None
if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model:
pass
else:
# 4. Execute spec decode
logits = self.model.compute_logits(hidden_states)
if not self.speculative_decoding:
set_value_by_flags_and_idx(
@@ -1625,8 +1629,13 @@ class GPUModelRunner(ModelRunnerBase):
self.parallel_config.max_model_len,
)
logits = None
# 4. Compute logits, Sample
logits = self.model.compute_logits(hidden_states)
if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model:
pass
else:
# 4. Compute logits, Sample
logits = self.model.compute_logits(hidden_states)
if not self.speculative_decoding:
set_value_by_flags_and_idx(

View File

@@ -45,7 +45,7 @@ from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.layers.quantization import parse_quant_config
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import get_logger
from fastdeploy.utils import get_logger, optional_type
from fastdeploy.worker.worker_base import WorkerBase
logger = get_logger("worker_process", "worker_process.log")
@@ -643,6 +643,27 @@ def parse_args():
help="Flag to specify dtype of lm_head as FP32",
)
parser.add_argument(
"--runner",
type=str,
default="auto",
help="The type of model runner to use.Each FD instance only supports one model runner.even if the same model can be used for multiple types.",
)
parser.add_argument(
"--convert",
type=str,
default="auto",
help="Convert the model using adapters. The most common use case is to adapt a text generation model to be used for pooling tasks.",
)
parser.add_argument(
"--override-pooler-config",
type=optional_type(json.loads),
default=None,
help="Override configuration for the pooler.",
)
args = parser.parse_args()
return args

View File

@@ -39,3 +39,4 @@ opentelemetry-distro 
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
partial_json_parser
msgspec

View File

@@ -14,9 +14,8 @@
from paddleformers.transformers import PretrainedModel
from fastdeploy import ModelRegistry
from fastdeploy.config import ErnieArchitectures
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.model_base import ModelForCasualLM, ModelRegistry
class MyPretrainedModel(PretrainedModel):

View File

@@ -0,0 +1,182 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
import paddle
import pytest
from fastdeploy.config import (
CacheConfig,
FDConfig,
GraphOptimizationConfig,
LoadConfig,
ModelConfig,
ParallelConfig,
)
from fastdeploy.model_executor.models.model_base import ModelRegistry
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from tests.model_loader.utils import get_torch_model_path
class TestModelLoader:
@pytest.fixture(scope="session", autouse=True)
def setup_paddle(self):
if not paddle.is_compiled_with_cuda():
print("CUDA not available, using CPU")
paddle.set_device("cpu")
else:
print("Using CUDA device")
paddle.set_device("gpu")
yield
@pytest.fixture(scope="session")
def model_path(self):
try:
torch_model_path = get_torch_model_path("Qwen3-0.6B")
if os.path.exists(torch_model_path):
return torch_model_path
except Exception as e:
print(f"Could not get torch model path: {e}")
@pytest.fixture
def model_config(self, model_path):
model_args = {
"model": model_path,
"dtype": "bfloat16",
"max_model_len": 8192,
"tensor_parallel_size": 1,
"runner": "auto",
"convert": "auto",
}
try:
return ModelConfig(model_args)
except Exception as e:
print(f"Could not create ModelConfig: {e}")
@pytest.fixture
def fd_config(self, model_config):
try:
cache_args = {
"block_size": 64,
"gpu_memory_utilization": 0.9,
"cache_dtype": "bfloat16",
"model_cfg": model_config,
"tensor_parallel_size": 1,
}
cache_config = CacheConfig(cache_args)
parallel_args = {
"tensor_parallel_size": 1,
"data_parallel_size": 1,
}
parallel_config = ParallelConfig(parallel_args)
load_args = {}
load_config = LoadConfig(load_args)
graph_opt_args = {
"enable_cudagraph": False,
"cudagraph_capture_sizes": None,
}
graph_opt_config = GraphOptimizationConfig(graph_opt_args)
return FDConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
load_config=load_config,
graph_opt_config=graph_opt_config,
test_mode=True,
)
except Exception as e:
print(f"Could not create FDConfig: {e}")
@pytest.fixture
def model_json_config(self, model_path):
config_path = os.path.join(model_path, "config.json")
if os.path.exists(config_path):
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
return None
def test_embedding_with_none_convert_type(self, fd_config, model_json_config):
if model_json_config is None:
pytest.skip("Model config not available")
if fd_config is None:
pytest.skip("FDConfig not available")
print("=" * 60)
print("Testing initialize_model with convert_type='none'")
print("=" * 60)
architectures = model_json_config.get("architectures", [])
if not architectures:
pytest.skip("No architectures found in model config")
fd_config.model_config.convert_type = "none"
try:
model_cls = ModelRegistry.get_class(architectures[0])
if hasattr(model_cls, "__name__"):
assert (
"ForEmbedding" not in model_cls.__name__
), f"Standard model should not have 'ForEmbedding' in name, but got: {model_cls.__name__}"
print(f"Confirmed standard model type (no ForEmbedding): {model_cls.__name__}")
standard_methods = set(dir(model_cls))
assert "_init_pooler" not in standard_methods, "Standard model should not have _init_pooler method"
except Exception as e:
print(f"Error in none: {e}")
def test_embedding_with_embed_convert_type(self, fd_config, model_json_config):
if model_json_config is None:
pytest.skip("Model config not available")
if fd_config is None:
pytest.skip("FDConfig not available")
print("=" * 60)
print("Testing embedding with convert_type='embed'")
print("=" * 60)
architectures = model_json_config.get("architectures", [])
if not architectures:
pytest.skip("No architectures found in model config")
fd_config.model_config.convert_type = "embed"
try:
model_cls = ModelRegistry.get_class(architectures[0])
if hasattr(model_cls, "__name__"):
assert "ForEmbedding" in model_cls.__name__, "Embedding model should have 'ForEmbedding' in name"
print(f"Confirmed embedding model type: {model_cls.__name__}")
embedding_methods = set(dir(model_cls))
assert "_init_pooler" in embedding_methods, "Embedding model should have _init_pooler method"
except Exception as e:
print(f"Error in convert embed: {e}")