[Feature] support pool (#3827)
* support pool
* update pooling
* add pooler_config and check
* update
* support AutoWeightsLoader load weight
* fix
* update
* delete print
* update pre-commit
* fix
* fix xpu
* fix ModelRegistry->model_registry
* fix Copilot review
* fix pooler.py
* delete StepPooler
* fix abstract
* fix default_loader_v1
* fix Pre Commit
* support torch qwen3 dense
* add test and fix torch-qwen
* fix
* fix
* adapter ci:
* fix review
* fix pooling_params.py
* fix
* fix tasks.py 2025
* fix print and logger
* Modify ModelRegistry and delete AutoWeightsLoader
* fix logger
* fix test_embedding
* fix ci bug
* ernie4_5 model_registry
* fix test
* support Qwen3-Embedding-0.6B tp=1 load
* fix extra code
* fix
* delete fix vocab_size
* delete prepare_params_dict
* fix:
@@ -18,7 +18,7 @@ Assuming you have a custom model class `MyModelForCasualLM` and a pretrained class

 ```python
 # File: fd_add_dummy_model/__init__.py or fd_add_dummy_model/register.py
-from fastdeploy.model_registry import ModelRegistry
+from fastdeploy.model_executor.models.model_base import ModelRegistry
 from my_custom_model import MyModelForCasualLM, MyPretrainedModel
 from fastdeploy.config import ErnieArchitectures

@@ -18,7 +18,7 @@ FastDeploy uses Python's `entry_points` mechanism to discover and load plugins. Developers

 ```python
 # File: fd_add_dummy_model/__init__.py
-from fastdeploy.model_registry import ModelRegistry
+from fastdeploy.model_executor.models.model_base import ModelRegistry
 from my_custom_model import MyModelForCasualLM, MyPretrainedModel

 def register():
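For context, the registration hook both hunks point at looks roughly like the sketch below. Only the import path and the `register()` entry point come from this diff; the registry method names and the duplicate-registration guard are illustrative assumptions, since the exact `ModelRegistry` API is not shown here.

```python
# Hedged sketch of a plugin module. ModelRegistry method names are
# assumptions; the import path and register() hook come from the diff.
from fastdeploy.model_executor.models.model_base import ModelRegistry

from my_custom_model import MyModelForCasualLM, MyPretrainedModel


def register():
    # Entry points can fire more than once; guard against re-registration.
    if "MyModelForCasualLM" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model_class(MyModelForCasualLM)
```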
@@ -18,12 +18,14 @@ from __future__ import annotations

 import json
 import os
+from dataclasses import field
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Union

 import paddle
 import paddle.distributed as dist
 from paddleformers.transformers.configuration_utils import PretrainedConfig
+from typing_extensions import assert_never

 import fastdeploy
 from fastdeploy import envs
@@ -31,11 +33,68 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfig
 from fastdeploy.multimodal.registry import MultimodalRegistry
 from fastdeploy.platforms import current_platform
 from fastdeploy.scheduler import SchedulerConfig
+from fastdeploy.transformer_utils.config import get_pooling_config
 from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger

 logger = get_logger("config", "config.log")

-TaskOption = Literal["generate"]
+TaskOption = Literal["auto", "generate", "embedding", "embed"]
+
+RunnerType = Literal["generate", "pooling"]
+
+RunnerOption = Literal["auto", "generate", "pooling"]
+
+ConvertOption = Literal["auto", "none", "embed"]
+
+ConvertType = Literal["none", "embed"]
+
+_ResolvedTask = Literal["generate", "encode", "embed"]
+
+_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
+    "generate": [],
+    "pooling": ["embed"],
+}
+
+# Some model suffixes are based on auto classes from Transformers:
+# https://huggingface.co/docs/transformers/en/model_doc/auto
+# NOTE: Items higher on this list take priority over lower ones
+_SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
+    ("ForCausalLM", ("generate", "none")),
+    ("ForConditionalGeneration", ("generate", "none")),
+    ("ChatModel", ("generate", "none")),
+    ("LMHeadModel", ("generate", "none")),
+    ("ForTextEncoding", ("pooling", "embed")),
+    ("EmbeddingModel", ("pooling", "embed")),
+    ("ForSequenceClassification", ("pooling", "classify")),
+    ("ForAudioClassification", ("pooling", "classify")),
+    ("ForImageClassification", ("pooling", "classify")),
+    ("ForVideoClassification", ("pooling", "classify")),
+    ("ClassificationModel", ("pooling", "classify")),
+    ("ForRewardModeling", ("pooling", "reward")),
+    ("RewardModel", ("pooling", "reward")),
+    # Let other `*Model`s take priority
+    ("Model", ("pooling", "embed")),
+]
+
+
+def iter_architecture_defaults():
+    yield from _SUFFIX_TO_DEFAULTS
+
+
+def try_match_architecture_defaults(
+    architecture: str,
+    *,
+    runner_type: Optional[RunnerType] = None,
+    convert_type: Optional[ConvertType] = None,
+):
+    for suffix, (default_runner_type, default_convert_type) in iter_architecture_defaults():
+        if (
+            (runner_type is None or runner_type == default_runner_type)
+            and (convert_type is None or convert_type == default_convert_type)
+            and architecture.endswith(suffix)
+        ):
+            return suffix, (default_runner_type, default_convert_type)
+    return None
+
+
 class MoEPhase:
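The suffix table and matcher above are self-contained, so their behavior can be checked directly. A quick sanity sketch (architecture names are illustrative):

```python
# Pure-Python check of try_match_architecture_defaults as defined above.
assert try_match_architecture_defaults("Qwen3ForCausalLM") == (
    "ForCausalLM", ("generate", "none"),
)
# "Model" sits last in the table, so generic *Model names fall through
# to pooling/embed:
assert try_match_architecture_defaults("XLMRobertaModel") == (
    "Model", ("pooling", "embed"),
)
# Constraining runner_type filters out defaults that do not match:
assert try_match_architecture_defaults("FooForCausalLM", runner_type="pooling") is None
```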
@@ -133,6 +192,12 @@ class ModelConfig:
         self.eos_tokens_lens: int = 2
         self.lm_head_fp32: bool = False
         self.model_format = "auto"
+        self.runner = "auto"
+        self.convert = "auto"
+        self.pooler_config: Optional["PoolerConfig"] = field(init=False)
+        self.override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
+        self.revision = None
+
         self.partial_rotary_factor: float = 1.0
         self.num_nextn_predict_layers = 0
         for key, value in args.items():
@@ -161,6 +226,7 @@ class ModelConfig:
         self.ori_vocab_size = args.get("ori_vocab_size", self.vocab_size)
+
         architectures = self.architectures[0]

         if MultimodalRegistry.contains_model(architectures):
             self.enable_mm = True
         else:
@@ -171,6 +237,43 @@ class ModelConfig:
         self.override_name_from_config()
         self.read_from_env()
         self.read_model_config()
+        self.runner_type = self._get_runner_type(self.architectures, self.runner)
+        self.convert_type = self._get_convert_type(self.architectures, self.runner_type, self.convert)
+
+        registry = self.registry
+        is_generative_model = registry.is_text_generation_model(self.architectures, self)
+        is_pooling_model = registry.is_pooling_model(self.architectures, self)
+        is_multimodal_model = registry.is_multimodal_model(self.architectures, self)
+
+        if self.runner_type == "generate" and not is_generative_model:
+            if is_multimodal_model:
+                pass
+            else:
+                generate_converts = _RUNNER_CONVERTS["generate"]
+                if self.convert_type not in generate_converts:
+                    raise ValueError("This model does not support `--runner generate`.")
+        if self.runner_type == "pooling" and not is_pooling_model:
+            pooling_converts = _RUNNER_CONVERTS["pooling"]
+            if self.convert_type not in pooling_converts:
+                convert_option = "<" + "|".join(pooling_converts) + ">"
+                raise ValueError(
+                    "This model does not support `--runner pooling`. "
+                    f"You can pass `--convert {convert_option}` to adapt "
+                    "it into a pooling model."
+                )
+
+        self.supported_tasks = self._get_supported_tasks(self.architectures, self.runner_type, self.convert_type)
+        model_info, arch = registry.inspect_model_cls(self.architectures, self)
+        self._model_info = model_info
+        self._architecture = arch
+
+        self.pooler_config = self._init_pooler_config()
+
+    @property
+    def registry(self):
+        from fastdeploy.model_executor.models.model_base import ModelRegistry
+
+        return ModelRegistry()
+
     def override_name_from_config(self):
         """
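As a concrete trace of the pooling guard above, the hint string is assembled from `_RUNNER_CONVERTS["pooling"]`:

```python
# The hint assembled in the pooling branch above.
pooling_converts = ["embed"]  # _RUNNER_CONVERTS["pooling"]
convert_option = "<" + "|".join(pooling_converts) + ">"
assert convert_option == "<embed>"
# -> ValueError: This model does not support `--runner pooling`.
#    You can pass `--convert <embed>` to adapt it into a pooling model.
```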
@@ -194,7 +297,6 @@ class ModelConfig:
     def read_from_env(self):
         """
         Read configuration information from environment variables and update the object's attributes.
-
         If an attribute is not present or is an empty string in the environment variables, use the default value.
         """
         self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
@@ -235,6 +337,165 @@ class ModelConfig:
                 f"Config file path: {config_path}"
             )

+    def _get_default_runner_type(
+        self,
+        architectures: list[str],
+    ) -> RunnerType:
+        registry = self.registry
+        if get_pooling_config(self.model, self.revision):
+            return "pooling"
+        for arch in architectures:
+            if arch in registry.get_supported_archs():
+                if registry.is_pooling_model(architectures, self):
+                    return "pooling"
+                if registry.is_text_generation_model(architectures, self):
+                    return "generate"
+            match = try_match_architecture_defaults(arch)
+            if match:
+                _, (runner_type, _) = match
+                return runner_type
+        return "generate"
+
+    def _get_default_convert_type(
+        self,
+        architectures: list[str],
+        runner_type: RunnerType,
+    ) -> ConvertType:
+        registry = self.registry
+
+        for arch in architectures:
+            if arch in registry.get_supported_archs():
+                if runner_type == "generate" and registry.is_text_generation_model(architectures, self):
+                    return "none"
+                if runner_type == "pooling" and registry.is_pooling_model(architectures, self):
+                    return "none"
+            match = try_match_architecture_defaults(arch, runner_type=runner_type)
+            if match:
+                _, (_, convert_type) = match
+                return convert_type
+
+        # This is to handle Sentence Transformers models that use *ForCausalLM
+        # and also multi-modal pooling models which are not defined as
+        # Sentence Transformers models
+        if runner_type == "pooling":
+            return "embed"
+
+        return "none"
+
+    def _get_runner_type(
+        self,
+        architectures: list[str],
+        runner: RunnerOption,
+    ) -> RunnerType:
+        if runner != "auto":
+            return runner
+
+        runner_type = self._get_default_runner_type(architectures)
+        if runner_type != "generate":
+            logger.info(
+                "Resolved `--runner auto` to `--runner %s`. Pass the value explicitly to silence this message.",
+                runner_type,
+            )
+
+        return runner_type
+
+    def _get_convert_type(
+        self,
+        architectures: list[str],
+        runner_type: RunnerType,
+        convert: ConvertOption,
+    ) -> ConvertType:
+        if convert != "auto":
+            return convert
+
+        convert_type = self._get_default_convert_type(architectures, runner_type)
+
+        if convert_type != "none":
+            logger.info(
+                "Resolved `--convert auto` to `--convert %s`. Pass the value explicitly to silence this message.",
+                convert_type,
+            )
+
+        return convert_type
+
+    def _get_supported_generation_tasks(
+        self,
+        architectures: list[str],
+        convert_type: ConvertType,
+    ) -> list[_ResolvedTask]:
+        registry = self.registry
+
+        supported_tasks = list[_ResolvedTask]()
+        if registry.is_text_generation_model(architectures, self) or convert_type in _RUNNER_CONVERTS["generate"]:
+            supported_tasks.append("generate")
+
+        # TODO: Temporarily does not support transcription.
+        return supported_tasks
+
+    def _get_default_pooling_task(
+        self,
+        architectures: list[str],
+    ) -> Literal["embed"]:
+        # Temporarily does not support classification and reward.
+        for arch in architectures:
+            match = try_match_architecture_defaults(arch, runner_type="pooling")
+            if match:
+                _, (_, convert_type) = match
+                assert convert_type != "none"
+                return convert_type
+
+        return "embed"
+
+    def _get_supported_pooling_tasks(
+        self,
+        architectures: list[str],
+        convert_type: ConvertType,
+    ) -> list[_ResolvedTask]:
+        registry = self.registry
+
+        supported_tasks = list[_ResolvedTask]()
+        if registry.is_pooling_model(architectures, self) or convert_type in _RUNNER_CONVERTS["pooling"]:
+            supported_tasks.append("encode")
+
+            extra_task = self._get_default_pooling_task(architectures) if convert_type == "none" else convert_type
+            supported_tasks.append(extra_task)
+
+        return supported_tasks
+
+    def _get_supported_tasks(
+        self,
+        architectures: list[str],
+        runner_type: RunnerType,
+        convert_type: ConvertType,
+    ) -> list[_ResolvedTask]:
+        if runner_type == "generate":
+            return self._get_supported_generation_tasks(architectures, convert_type)
+        if runner_type == "pooling":
+            return self._get_supported_pooling_tasks(architectures, convert_type)
+
+        assert_never(runner_type)
+
+    def _init_pooler_config(self) -> Optional["PoolerConfig"]:
+        if self.runner_type == "pooling":
+            if isinstance(self.override_pooler_config, dict):
+                self.override_pooler_config = PoolerConfig(**self.override_pooler_config)
+
+            pooler_config = self.override_pooler_config or PoolerConfig()
+
+            base_config = get_pooling_config(self.model, self.revision)
+            if base_config is not None:
+                for k, v in base_config.items():
+                    if getattr(pooler_config, k) is None:
+                        setattr(pooler_config, k, v)
+
+            default_pooling_type = self._model_info.default_pooling_type
+            if pooler_config.pooling_type is None:
+                pooler_config.pooling_type = default_pooling_type
+
+            return pooler_config
+
+        return None
+
     def _get_download_model(self, model_name, model_type="default"):
         # TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
         pass
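A hedged trace of the resolution order in `_get_runner_type` above, with the registry lookups stubbed out (this is a simplified standalone sketch, not the method itself):

```python
# Resolution order: explicit --runner wins; otherwise a repo-level pooling
# config, then the architecture-suffix defaults, then "generate" as fallback.
def resolve_runner(runner, architectures, has_pooling_config):
    if runner != "auto":
        return runner
    if has_pooling_config:
        return "pooling"
    for arch in architectures:
        match = try_match_architecture_defaults(arch)
        if match:
            return match[1][0]  # (suffix, (runner_type, convert_type))
    return "generate"

assert resolve_runner("auto", ["Qwen3ForCausalLM"], False) == "generate"
assert resolve_runner("auto", ["Qwen3ForCausalLM"], True) == "pooling"
assert resolve_runner("generate", ["AnyModel"], True) == "generate"
```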
@@ -846,6 +1107,41 @@ class LoadConfig:
         setattr(self, key, value)


+class PoolerConfig:
+    """Controls the behavior of output pooling in pooling models."""
+
+    pooling_type: Optional[str] = None
+    """
+    The pooling method of the pooling model.
+    """
+    # for embeddings models
+    normalize: Optional[bool] = None
+    """
+    Whether to normalize the embeddings outputs. Defaults to True.
+    """
+    dimensions: Optional[int] = None
+    """
+    Reduce the dimensions of embeddings if the model
+    supports matryoshka representation. Defaults to None.
+    """
+    enable_chunked_processing: Optional[bool] = None
+    """
+    Whether to enable chunked processing for long inputs that exceed the model's
+    maximum position embeddings. When enabled, long inputs will be split into
+    chunks, processed separately, and then aggregated using weighted averaging.
+    This allows embedding models to handle arbitrarily long text without CUDA
+    errors. Defaults to False.
+    """
+    max_embed_len: Optional[int] = None
+    """
+    Maximum input length allowed for embedding generation. When set, allows
+    inputs longer than max_embed_len to be accepted for embedding models.
+    When an input exceeds max_embed_len, it will be handled according to
+    the original max_model_len validation logic.
+    Defaults to None (i.e. set to max_model_len).
+    """
+
+
 class LoRAConfig:
     """LoRA Config"""

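To make the precedence in `_init_pooler_config` concrete: CLI overrides win over the repository's pooling config, which in turn wins over the model's default pooling type. A minimal sketch, assuming `PoolerConfig` accepts keyword arguments the way `PoolerConfig(**override_pooler_config)` implies:

```python
# Hedged sketch of the merge order in _init_pooler_config above.
pooler_config = PoolerConfig(pooling_type="MEAN")  # user override (kw-init assumed)

base_config = {"pooling_type": "LAST", "normalize": True}  # from the model repo
for k, v in base_config.items():
    if getattr(pooler_config, k) is None:  # only fill fields left unset
        setattr(pooler_config, k, v)

assert pooler_config.pooling_type == "MEAN"  # override wins
assert pooler_config.normalize is True       # repo value fills the gap
```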
@@ -18,13 +18,14 @@ import argparse
 import json
 from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 import paddle

 from fastdeploy import envs
 from fastdeploy.config import (
     CacheConfig,
+    ConvertOption,
     EarlyStopConfig,
     FDConfig,
     GraphOptimizationConfig,
@@ -32,6 +33,8 @@ from fastdeploy.config import (
     MobaAttentionConfig,
     ModelConfig,
     ParallelConfig,
+    PoolerConfig,
+    RunnerOption,
     SpeculativeConfig,
     TaskOption,
 )
@@ -95,6 +98,20 @@ class EngineArgs:
     """
     The task to be executed by the model.
     """
+    runner: RunnerOption = "auto"
+    """
+    The type of model runner to use. Each FD instance supports only one model
+    runner, even if the same model can be used for multiple types.
+    """
+    convert: ConvertOption = "auto"
+    """
+    Convert the model using adapters. The most common use case is to
+    adapt a text generation model to be used for pooling tasks.
+    """
+    override_pooler_config: Optional[Union[dict, PoolerConfig]] = None
+    """
+    Override configuration for the pooler.
+    """
     max_num_seqs: int = 8
     """
     Maximum number of sequences per iteration.
@@ -473,6 +490,21 @@ class EngineArgs:
             default=EngineArgs.task,
             help="Task to be executed by the model.",
         )
+        model_group.add_argument(
+            "--runner",
+            type=str,
+            default=EngineArgs.runner,
+            help="The type of model runner to use.",
+        )
+        model_group.add_argument(
+            "--convert", type=str, default=EngineArgs.convert, help="Convert the model using adapters."
+        )
+        model_group.add_argument(
+            "--override-pooler-config",
+            type=json.loads,
+            default=EngineArgs.override_pooler_config,
+            help="Override the pooler configuration with a JSON string.",
+        )
         model_group.add_argument(
             "--use-warmup",
             type=int,
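How the three new flags parse, in isolation; the argument names and `type=json.loads` come from this hunk, the rest is a self-contained sketch:

```python
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--runner", type=str, default="auto")
parser.add_argument("--convert", type=str, default="auto")
parser.add_argument("--override-pooler-config", type=json.loads, default=None)

args = parser.parse_args([
    "--runner", "pooling",
    "--convert", "embed",
    "--override-pooler-config", '{"pooling_type": "LAST", "normalize": true}',
])
# json.loads turns the flag value into a dict before it reaches PoolerConfig.
assert args.override_pooler_config["pooling_type"] == "LAST"
```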
@@ -498,6 +498,9 @@ class LLMEngine:
             f" --load_choices {self.cfg.load_config.load_choices}"
             f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
             f" --ips {ips}"
+            f" --runner {self.cfg.model_config.runner}"
+            f" --convert {self.cfg.model_config.convert}"
+            f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
         )

         worker_append_flag = {
new file: fastdeploy/engine/pooling_params.py (170 lines)
@@ -0,0 +1,170 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from copy import deepcopy
+from typing import TYPE_CHECKING, Annotated, Any, Optional
+
+import msgspec
+
+from fastdeploy.engine.sampling_params import RequestOutputKind
+from fastdeploy.engine.tasks import PoolingTask
+
+if TYPE_CHECKING:
+    from fastdeploy.config import ModelConfig
+
+
+class PoolingParams:
+    """API parameters for pooling models.
+
+    Attributes:
+        normalize: Whether to normalize the embeddings outputs.
+        dimensions: Reduce the dimensions of embeddings
+            if the model supports matryoshka representation.
+        activation: Whether to apply activation function to
+            the classification outputs.
+        softmax: Whether to apply softmax to the reward outputs.
+        step_tag_id: Step tag ID for process reward models to identify
+            specific steps in multi-step reasoning tasks.
+        returned_token_ids: List of token IDs to return rewards for,
+            used for fine-grained reward calculation.
+        task: Internal use only. Specifies the pooling task type
+            ("embed" for embeddings, "encode" for reward models).
+        requires_token_ids: Internal use only. Whether token ID information
+            is required for processing.
+        extra_kwargs: Internal use only. Dictionary for storing additional
+            custom parameters for extended functionality.
+        output_kind: Output type specification, fixed to FINAL_ONLY
+            (only final outputs are returned).
+    """
+
+    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None
+    """If set to -1, will use the truncation size supported by the model. If
+    set to an integer k, will use only the last k tokens from the prompt
+    (i.e., left truncation). If set to `None`, truncation is disabled."""
+
+    # for embeddings models
+    dimensions: Optional[int] = None
+    normalize: Optional[bool] = None
+
+    # for reward models
+    softmax: Optional[bool] = None
+    step_tag_id: Optional[int] = None
+    returned_token_ids: Optional[list[int]] = None
+
+    task: Optional[PoolingTask] = None
+    """Internal use only."""
+
+    requires_token_ids: bool = False
+    """Internal use only."""
+
+    extra_kwargs: Optional[dict[str, Any]] = None
+    """Internal use only."""
+
+    output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
+
+    @property
+    def _all_parameters(self) -> list[str]:
+        return ["dimensions", "normalize", "softmax", "step_tag_id", "returned_token_ids"]
+
+    @property
+    def valid_parameters(self):
+        return {
+            "embed": ["dimensions", "normalize"],
+            "encode": ["softmax", "step_tag_id", "returned_token_ids"],
+        }
+
+    def clone(self) -> "PoolingParams":
+        """Returns a deep copy of the PoolingParams instance."""
+        return deepcopy(self)
+
+    def verify(self, task: PoolingTask, model_config: Optional["ModelConfig"] = None) -> None:
+
+        if self.task is None:
+            self.task = task
+        elif self.task != task:
+            msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
+            raise ValueError(msg)
+
+        # NOTE: Task validation needs to be done against the model instance,
+        # which is not available in model config. So, it's not included
+        # in this method.
+
+        self._merge_default_parameters(model_config)
+        self._set_default_parameters(model_config)
+        self._verify_valid_parameters()
+
+    def _merge_default_parameters(self, model_config: Optional["ModelConfig"] = None) -> None:
+
+        if model_config is None:
+            return
+
+        pooler_config = model_config.pooler_config
+        if pooler_config is None:
+            return
+
+        assert self.task is not None, "task must be set"
+        valid_parameters = self.valid_parameters[self.task]
+
+        for k in valid_parameters:
+            if getattr(pooler_config, k, None) is None:
+                continue
+
+            if getattr(self, k, None) is None:
+                setattr(self, k, getattr(pooler_config, k))
+
+    def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
+        if self.task == "embed":
+            if self.normalize is None:
+                self.normalize = True
+        elif self.task == "encode":
+            if self.softmax is None:
+                self.softmax = True
+        else:
+            raise ValueError(f"Unknown pooling task: {self.task}")
+
+    def _verify_valid_parameters(self):
+        assert self.task is not None, "task must be set"
+        valid_parameters = self.valid_parameters[self.task]
+        invalid_parameters = []
+        for k in self._all_parameters:
+            if k in valid_parameters:
+                continue
+
+            if getattr(self, k, None) is not None:
+                invalid_parameters.append(k)
+
+        if invalid_parameters:
+            raise ValueError(
+                f"Task {self.task} only supports {valid_parameters} "
+                f"parameters, does not support "
+                f"{invalid_parameters} parameters"
+            )
+
+    def __repr__(self) -> str:
+        return (
+            f"PoolingParams("
+            f"task={self.task}, "
+            f"normalize={self.normalize}, "
+            f"dimensions={self.dimensions}, "
+            f"softmax={self.softmax}, "
+            f"step_tag_id={self.step_tag_id}, "
+            f"returned_token_ids={self.returned_token_ids}, "
+            f"requires_token_ids={self.requires_token_ids}, "
+            f"extra_kwargs={self.extra_kwargs})"
+        )
+
+    def __post_init__(self) -> None:
+        assert self.output_kind == RequestOutputKind.FINAL_ONLY, "For pooling output_kind has to be FINAL_ONLY"
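A usage sketch of the validation flow above. The `__post_init__` hook suggests the class is meant to be dataclass-like, so keyword construction is assumed here:

```python
# Hedged sketch of PoolingParams.verify(); keyword construction is assumed.
params = PoolingParams(dimensions=256)
params.verify(task="embed")      # no model_config: defaults only
assert params.normalize is True  # filled in by _set_default_parameters

bad = PoolingParams(softmax=True)  # softmax is an encode-only parameter
try:
    bad.verify(task="embed")
except ValueError as exc:
    print(exc)  # Task embed only supports ['dimensions', 'normalize'] ...
```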
@@ -18,6 +18,7 @@ from __future__ import annotations

 import random
 from dataclasses import dataclass, fields
+from enum import Enum
 from typing import Any, List, Optional, Union


@@ -268,3 +269,12 @@ class GuidedDecodingParams:
                 "You can only use one kind of guided decoding "
                 "('json', 'json_object', 'regex', 'choice', 'grammar', 'structural_tag')."
             )
+
+
+class RequestOutputKind(Enum):
+    # Return entire output so far in every RequestOutput
+    CUMULATIVE = 0
+    # Return only deltas in each RequestOutput
+    DELTA = 1
+    # Do not return intermediate RequestOutput
+    FINAL_ONLY = 2
new file: fastdeploy/engine/tasks.py (25 lines)
@@ -0,0 +1,25 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from typing import Literal, get_args
+
+GenerationTask = Literal["generate"]
+GENERATION_TASKS = get_args(GenerationTask)
+
+PoolingTask = Literal["encode", "embed"]
+POOLING_TASKS = get_args(PoolingTask)
+
+SupportedTask = Literal[GenerationTask, PoolingTask]
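The `Literal`/`get_args` pairing above gives both static checking and runtime-iterable constants, for example:

```python
from typing import Literal, get_args

PoolingTask = Literal["encode", "embed"]
POOLING_TASKS = get_args(PoolingTask)  # ("encode", "embed")

def is_pooling_task(task: str) -> bool:
    # Runtime membership check against the same source of truth
    # that type checkers use statically.
    return task in POOLING_TASKS

assert is_pooling_task("embed") and not is_pooling_task("generate")
```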
@@ -146,3 +146,26 @@ class SiluAndMul(nn.Layer):
         if self.bias is not None:
             out = out + self.bias
         return out
+
+
+def get_act_fn(act_fn_name: str) -> nn.Layer:
+    """Get an activation function by name."""
+    act_fn_name = act_fn_name.lower()
+
+    # act_fn_name has been lowered, so compare against a lowercase prefix.
+    if act_fn_name.startswith("paddle.nn.layer"):
+        activation_name = act_fn_name.split(".")[-1]
+        if activation_name == "identity":
+            return nn.Identity()
+        act_fn_name = activation_name
+
+    activation_map = {
+        "gelu": nn.GELU(),
+        "relu": nn.ReLU(),
+        "silu": nn.Silu(),
+        "tanh": nn.Tanh(),
+        "sigmoid": nn.Sigmoid(),
+    }
+    if act_fn_name in activation_map:
+        return activation_map[act_fn_name]
+    else:
+        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
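Usage of `get_act_fn` as defined above (the module path in the import is assumed; the hunk does not name the file):

```python
import paddle.nn as nn

# Path assumed; the diff only shows the module contains SiluAndMul.
from fastdeploy.model_executor.layers.activation import get_act_fn

act = get_act_fn("GELU")                      # case-insensitive -> nn.GELU()
assert isinstance(act, nn.GELU)

act = get_act_fn("paddle.nn.layer.identity")  # dotted paths are unwrapped
assert isinstance(act, nn.Identity)

try:
    get_act_fn("swiglu")                      # unsupported names raise
except ValueError as exc:
    print(exc)
```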
new file: fastdeploy/model_executor/layers/pool/metadata.py (85 lines)
@@ -0,0 +1,85 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from dataclasses import dataclass
+from typing import Optional
+
+import paddle
+
+from fastdeploy.engine.pooling_params import PoolingParams
+
+
+@dataclass
+class PoolingCursor:
+    index: list[int]
+    first_token_indices_gpu: paddle.Tensor
+    last_token_indices_gpu: paddle.Tensor
+    prompt_lens_cpu: paddle.Tensor
+    num_scheduled_tokens_cpu: paddle.Tensor
+
+    def __getitem__(self, indices: slice):
+        return PoolingCursor(
+            index=self.index[indices],
+            first_token_indices_gpu=self.first_token_indices_gpu[indices],
+            last_token_indices_gpu=self.last_token_indices_gpu[indices],
+            prompt_lens_cpu=self.prompt_lens_cpu[indices],
+            num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
+        )
+
+    def is_partial_prefill(self):
+        return not paddle.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu).item()
+
+
+@dataclass
+class PoolingMetadata:
+    """Tensors for pooling."""
+
+    prompt_lens: paddle.Tensor  # CPU tensor
+    prompt_token_ids: Optional[paddle.Tensor]
+    pooling_params: list[PoolingParams]
+    pooling_cursor: Optional[PoolingCursor] = None
+
+    def __getitem__(self, indices: slice):
+        return PoolingMetadata(
+            prompt_lens=self.prompt_lens[indices],
+            prompt_token_ids=None if self.prompt_token_ids is None else self.prompt_token_ids[indices],
+            pooling_params=self.pooling_params[indices],
+            pooling_cursor=None if self.pooling_cursor is None else self.pooling_cursor[indices],
+        )
+
+    def build_pooling_cursor(self, num_scheduled_tokens: list[int], device: str):
+        self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens, self.prompt_lens, device)
+
+
+def build_pooling_cursor(num_scheduled_tokens: list[int], prompt_lens: paddle.Tensor, device: str):
+    assert len(prompt_lens) == len(num_scheduled_tokens)
+
+    n_seq = len(num_scheduled_tokens)
+    index = list(range(n_seq))
+    num_scheduled_tokens = paddle.to_tensor(num_scheduled_tokens, place=paddle.CPUPlace())
+    cumsum = paddle.zeros([n_seq + 1], dtype="int64", place=paddle.CPUPlace())
+    paddle.cumsum(num_scheduled_tokens, axis=0, out=cumsum[1:])
+    if device == "gpu":
+        cumsum_device = cumsum.cuda()
+    else:
+        cumsum_device = cumsum
+    return PoolingCursor(
+        index=index,
+        first_token_indices_gpu=cumsum_device[:n_seq],
+        last_token_indices_gpu=cumsum_device[1:] - 1,
+        prompt_lens_cpu=prompt_lens,
+        num_scheduled_tokens_cpu=num_scheduled_tokens,
+    )
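A small worked example of the cursor arithmetic above, on CPU, for two fully-prefilled prompts of lengths 3 and 5 flattened into one 8-token batch (assuming a Paddle build that accepts the tensor-creation keywords used in the hunk):

```python
import paddle

from fastdeploy.model_executor.layers.pool.metadata import build_pooling_cursor

prompt_lens = paddle.to_tensor([3, 5], place=paddle.CPUPlace())
cursor = build_pooling_cursor([3, 5], prompt_lens, device="cpu")

# cumsum over [3, 5] gives [0, 3, 8]: firsts are cumsum[:n], lasts cumsum[1:] - 1.
print(cursor.first_token_indices_gpu.tolist())  # [0, 3]
print(cursor.last_token_indices_gpu.tolist())   # [2, 7]
print(cursor.is_partial_prefill())              # False (all tokens scheduled)
```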
550
fastdeploy/model_executor/layers/pooler.py
Normal file
550
fastdeploy/model_executor/layers/pooler.py
Normal file
@@ -0,0 +1,550 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Mapping, Set
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import IntEnum
|
||||||
|
from itertools import groupby
|
||||||
|
from typing import Callable, Optional, TypeVar, Union, cast
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
|
||||||
|
from fastdeploy.config import FDConfig, ModelConfig, PoolerConfig
|
||||||
|
from fastdeploy.engine.tasks import PoolingTask
|
||||||
|
from fastdeploy.model_executor.layers.pool.metadata import (
|
||||||
|
PoolingCursor,
|
||||||
|
PoolingMetadata,
|
||||||
|
PoolingParams,
|
||||||
|
)
|
||||||
|
from fastdeploy.model_executor.models.adapters import _load_st_projector
|
||||||
|
from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput
|
||||||
|
from fastdeploy.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("pooler", "pooler.log")
|
||||||
|
|
||||||
|
PoolingFn = Callable[
|
||||||
|
[Union[paddle.Tensor, list[paddle.Tensor]], PoolingMetadata], Union[paddle.Tensor, list[paddle.Tensor]]
|
||||||
|
]
|
||||||
|
ClassifierFn = Callable[[paddle.Tensor], paddle.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class PoolingType(IntEnum):
|
||||||
|
"""Enumeration for different types of pooling methods."""
|
||||||
|
|
||||||
|
LAST = 0
|
||||||
|
ALL = 1
|
||||||
|
CLS = 2
|
||||||
|
STEP = 3
|
||||||
|
MEAN = 4
|
||||||
|
|
||||||
|
|
||||||
|
_T = TypeVar("_T", paddle.Tensor, list[paddle.Tensor])
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ResolvedPoolingConfig:
|
||||||
|
pooling_type: PoolingType
|
||||||
|
task: PoolingTask
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
task: PoolingTask,
|
||||||
|
pooler_config: PoolerConfig,
|
||||||
|
) -> "ResolvedPoolingConfig":
|
||||||
|
assert pooler_config.pooling_type is not None
|
||||||
|
return cls(task=task, pooling_type=PoolingType[pooler_config.pooling_type])
|
||||||
|
|
||||||
|
|
||||||
|
def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
|
||||||
|
pooling_params = pooling_metadata.pooling_params
|
||||||
|
return pooling_params
|
||||||
|
|
||||||
|
|
||||||
|
def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
|
||||||
|
pooling_params = get_pooling_params(pooling_metadata)
|
||||||
|
|
||||||
|
tasks: list[PoolingTask] = [task for pooling_param in pooling_params if (task := pooling_param.task) is not None]
|
||||||
|
assert len(pooling_params) == len(tasks)
|
||||||
|
|
||||||
|
return tasks
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[paddle.Tensor]:
|
||||||
|
assert (
|
||||||
|
pooling_metadata.prompt_token_ids is not None
|
||||||
|
), "Please set `requires_token_ids=True` in `get_pooling_updates`"
|
||||||
|
|
||||||
|
return [pooling_metadata.prompt_token_ids[i, :num] for i, num in enumerate(pooling_metadata.prompt_lens)]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PoolingParamsUpdate:
|
||||||
|
requires_token_ids: bool = False
|
||||||
|
"""Set this flag to enable `get_prompt_token_ids` for your pooler."""
|
||||||
|
|
||||||
|
def apply(self, params: PoolingParams) -> None:
|
||||||
|
params.requires_token_ids = self.requires_token_ids
|
||||||
|
|
||||||
|
|
||||||
|
class Pooler(nn.Layer, ABC):
|
||||||
|
"""The interface required for all poolers used in pooling models in FastDeploy."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def for_encode(pooler_config: PoolerConfig, model_config: Optional["ModelConfig"] = None):
|
||||||
|
if pooler_config.pooling_type == "STEP":
|
||||||
|
return StepPooler()
|
||||||
|
|
||||||
|
resolved_config = ResolvedPoolingConfig(task="encode", pooling_type=PoolingType.ALL)
|
||||||
|
return SimplePooler.from_config(resolved_config, model_config)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def for_embed(pooler_config: PoolerConfig, model_config: Optional["ModelConfig"] = None):
|
||||||
|
resolved_config = ResolvedPoolingConfig.from_config(
|
||||||
|
task="embed",
|
||||||
|
pooler_config=pooler_config,
|
||||||
|
)
|
||||||
|
return SimplePooler.from_config(resolved_config, model_config)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def for_classify(
|
||||||
|
pooler_config: PoolerConfig,
|
||||||
|
classify: Optional[ClassifierFn],
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
|
"""Determine which pooling tasks are supported."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||||
|
"""
|
||||||
|
Construct the updated pooling parameters to use for a supported task.
|
||||||
|
"""
|
||||||
|
return PoolingParamsUpdate()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: Union[list[paddle.Tensor], paddle.Tensor],
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> PoolerOutput:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class BasePoolerActication(nn.Layer, ABC):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def forward(self, pooled_data: _T) -> _T:
|
||||||
|
# shape:
|
||||||
|
# classify (& score) -> (batch_size, num_classes)
|
||||||
|
# embed -> (batch_size, embedding_dim) or list(embedding_dim)
|
||||||
|
# (batch_size, dimensions) or list(dimensions) if using MRL
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class PoolerActivation(BasePoolerActication):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def wraps(module: nn.Layer):
|
||||||
|
if isinstance(module, nn.Identity):
|
||||||
|
return PoolerIdentity()
|
||||||
|
if isinstance(module, (nn.Sigmoid, nn.Softmax)):
|
||||||
|
return PoolerClassify()
|
||||||
|
|
||||||
|
return LambdaPoolerActivation(module)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward(self, pooled_data: _T) -> _T:
|
||||||
|
if isinstance(pooled_data, list):
|
||||||
|
return [self.forward_chunk(data) for data in pooled_data]
|
||||||
|
|
||||||
|
return self.forward_chunk(pooled_data)
|
||||||
|
|
||||||
|
|
||||||
|
class PoolerIdentity(PoolerActivation):
|
||||||
|
|
||||||
|
def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
return pooled_data
|
||||||
|
|
||||||
|
|
||||||
|
class PoolerClassify(PoolerActivation):
|
||||||
|
|
||||||
|
def __init__(self, *, static_num_labels: bool = True) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if static_num_labels:
|
||||||
|
fd_config = FDConfig()
|
||||||
|
self.num_labels = getattr(fd_config.model_config, "num_labels", 0)
|
||||||
|
if self.num_labels == 0:
|
||||||
|
logger.warning(
|
||||||
|
"num_labels should be > 0 for classification"
|
||||||
|
"models, falling back to softmax. "
|
||||||
|
"Please check if the configuration is correct."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.num_labels = None
|
||||||
|
|
||||||
|
def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
num_labels = self.num_labels if self.num_labels is not None else pooled_data.shape[-1]
|
||||||
|
if num_labels < 2:
|
||||||
|
return F.sigmoid(pooled_data.astype("float32")).astype(pooled_data.dtype)
|
||||||
|
|
||||||
|
return F.softmax(pooled_data.astype("float32"), axis=-1).astype(pooled_data.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class LambdaPoolerActivation(PoolerActivation):
|
||||||
|
|
||||||
|
def __init__(self, fn: Callable[[paddle.Tensor], paddle.Tensor]):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
return self.fn(pooled_data)
|
||||||
|
|
||||||
|
|
||||||
|
class PoolerHead(nn.Layer):
|
||||||
|
|
||||||
|
def __init__(self, activation: PoolerActivation) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
|
def forward(self, pooled_data: Union[list[paddle.Tensor], paddle.Tensor], pooling_metadata: PoolingMetadata):
|
||||||
|
|
||||||
|
return self.activation(pooled_data)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingPoolerHead(PoolerHead):
|
||||||
|
|
||||||
|
def __init__(self, model_config: Optional["ModelConfig"] = None) -> None:
|
||||||
|
super().__init__(activation=PoolerNormalize())
|
||||||
|
|
||||||
|
self.projector = _load_st_projector(model_config)
|
||||||
|
|
||||||
|
def forward(self, pooled_data: Union[list[paddle.Tensor], paddle.Tensor], pooling_metadata: PoolingMetadata):
|
||||||
|
|
||||||
|
if isinstance(pooled_data, list):
|
||||||
|
pooled_data = paddle.stack(pooled_data)
|
||||||
|
# pooled_data shape: [batchsize, hidden_dimension]
|
||||||
|
|
||||||
|
# Apply ST projector
|
||||||
|
if self.projector is not None:
|
||||||
|
projector = cast(nn.Layer, self.projector)
|
||||||
|
|
||||||
|
def _proj(x: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
orig_dtype = x.dtype
|
||||||
|
y = projector(x.astype("float32"))
|
||||||
|
return y.astype(orig_dtype)
|
||||||
|
|
||||||
|
pooled_data = _proj(pooled_data)
|
||||||
|
# pooled_data shape: [batchsize, embedding_dimension]
|
||||||
|
|
||||||
|
pooling_params = get_pooling_params(pooling_metadata)
|
||||||
|
|
||||||
|
# for matryoshka representation
|
||||||
|
dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
|
||||||
|
if any(d is not None for d in dimensions_list):
|
||||||
|
# change the output dimension
|
||||||
|
assert len(pooled_data) == len(dimensions_list)
|
||||||
|
if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list):
|
||||||
|
# if all dimensions are the same
|
||||||
|
d = dimensions_list[0]
|
||||||
|
pooled_data = pooled_data[..., :d]
|
||||||
|
else:
|
||||||
|
pooled_data = [vecs if d is None else vecs[..., :d] for vecs, d in zip(pooled_data, dimensions_list)]
|
||||||
|
# for normalize
|
||||||
|
flags = [p.normalize for p in pooling_params]
|
||||||
|
if len(set(flags)) == 1:
|
||||||
|
if flags[0]:
|
||||||
|
pooled_data = self.activation(pooled_data)
|
||||||
|
else:
|
||||||
|
pooled_data = [self.activation(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)]
|
||||||
|
|
||||||
|
# pooled_data shape: [batchsize, embedding_dimension]
|
||||||
|
return pooled_data
|
||||||
|
|
||||||
|
|
||||||
|
class RewardPoolerHead(PoolerHead):
|
||||||
|
|
||||||
|
def __init__(self, model_config: Optional["ModelConfig"] = None) -> None:
|
||||||
|
super().__init__(activation=PoolerClassify(static_num_labels=False))
|
||||||
|
self.model_config = model_config
|
||||||
|
|
||||||
|
def forward(self, pooled_data: Union[list[paddle.Tensor], paddle.Tensor], pooling_metadata: PoolingMetadata):
|
||||||
|
pooling_params = get_pooling_params(pooling_metadata)
|
||||||
|
|
||||||
|
# for softmax
|
||||||
|
flags = [p.softmax for p in pooling_params]
|
||||||
|
if len(set(flags)) == 1:
|
||||||
|
if flags[0]:
|
||||||
|
pooled_data = self.activation(pooled_data)
|
||||||
|
else:
|
||||||
|
pooled_data = [self.activation(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)]
|
||||||
|
|
||||||
|
return pooled_data
|
||||||
|
|
||||||
|
|
||||||
|
def build_output(
|
||||||
|
all_data: Union[paddle.Tensor, list[paddle.Tensor]],
|
||||||
|
) -> PoolerOutput:
|
||||||
|
# Pooling models D2H & synchronize occurs here
|
||||||
|
if isinstance(all_data, list):
|
||||||
|
all_data = [d.cpu() for d in all_data]
|
||||||
|
else:
|
||||||
|
all_data = all_data.cpu()
|
||||||
|
|
||||||
|
all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
|
||||||
|
return PoolerOutput(outputs=all_outputs)
|
||||||
|
|
||||||
|
|
||||||
|
class PoolingMethod(nn.Layer, ABC):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
|
||||||
|
if pooling_type == PoolingType.LAST:
|
||||||
|
return LastPool()
|
||||||
|
if pooling_type == PoolingType.ALL:
|
||||||
|
return AllPool()
|
||||||
|
if pooling_type == PoolingType.CLS:
|
||||||
|
return CLSPool()
|
||||||
|
if pooling_type == PoolingType.MEAN:
|
||||||
|
return MeanPool()
|
||||||
|
raise NotImplementedError(f"Unsupported method: {pooling_type}")
|
||||||
|
|
||||||
|
|
||||||
|
class LastPool(PoolingMethod):
|
||||||
|
|
||||||
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
|
return {"encode", "embed", "classify", "score"}
|
||||||
|
|
||||||
|
def forward_all(
|
||||||
|
self,
|
||||||
|
hidden_states: paddle.Tensor,
|
||||||
|
pooling_cursor: PoolingCursor,
|
||||||
|
) -> Union[list[paddle.Tensor], paddle.Tensor]:
|
||||||
|
return hidden_states[pooling_cursor.last_token_indices_gpu]
|
||||||
|
|
||||||
|
|
||||||
|
class AllPool(PoolingMethod):
|
||||||
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
|
return {"encode"}
|
||||||
|
|
||||||
|
def forward_all(
|
||||||
|
self,
|
||||||
|
hidden_states: paddle.Tensor,
|
||||||
|
pooling_cursor: PoolingCursor,
|
||||||
|
) -> Union[list[paddle.Tensor], paddle.Tensor]:
|
||||||
|
|
||||||
|
assert not pooling_cursor.is_partial_prefill(), "partial prefill not supported with ALL pooling"
|
||||||
|
|
||||||
|
hidden_states_lst = list(hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist()))
|
||||||
|
return [hidden_states_lst[i] for i in pooling_cursor.index]
|
||||||
|
|
||||||
|
|
||||||
|
class MeanPool(PoolingMethod):
|
||||||
|
|
||||||
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
|
return {"encode", "embed", "classify", "score"}
|
||||||
|
|
||||||
|
def forward_all(
|
||||||
|
self,
|
||||||
|
hidden_states: paddle.Tensor,
|
||||||
|
pooling_cursor: PoolingCursor,
|
||||||
|
) -> Union[list[paddle.Tensor], paddle.Tensor]:
|
||||||
|
|
||||||
|
assert not pooling_cursor.is_partial_prefill(), "partial prefill not supported with MEAN pooling"
|
||||||
|
|
||||||
|
if hidden_states.place.is_gpu_place():
|
||||||
|
prompt_lens = pooling_cursor.prompt_lens_cpu.cuda()
|
||||||
|
else:
|
||||||
|
prompt_lens = pooling_cursor.prompt_lens_cpu
|
||||||
|
|
||||||
|
# Use float32 for paddle.cumsum in MeanPool,
|
||||||
|
# otherwise precision will be lost significantly.
|
||||||
|
cumsum = paddle.cumsum(hidden_states.astype("float32"), axis=0)
|
||||||
|
|
||||||
|
start_indices = pooling_cursor.first_token_indices_gpu
|
||||||
|
end_indices = pooling_cursor.last_token_indices_gpu
|
||||||
|
return (cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
class CLSPool(PoolingMethod):
|
||||||
|
|
||||||
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
|
return {"encode", "embed", "classify", "score"}
|
||||||
|
|
||||||
|
def forward_all(
|
||||||
|
self,
|
||||||
|
hidden_states: paddle.Tensor,
|
||||||
|
pooling_cursor: PoolingCursor,
|
||||||
|
) -> Union[list[paddle.Tensor], paddle.Tensor]:
|
||||||
|
assert not pooling_cursor.is_partial_prefill(), "partial prefill not supported with CLS pooling"
        return hidden_states[pooling_cursor.first_token_indices_gpu]


class StepPooler(Pooler):

    def __init__(
        self,
    ) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = RewardPoolerHead()

    def extract_states(
        self,
        hidden_states: Union[paddle.Tensor, list[paddle.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[list[paddle.Tensor], paddle.Tensor]:
        pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = get_prompt_token_ids(pooling_metadata)

        pooled_data = list[paddle.Tensor]()

        pooling_params = get_pooling_params(pooling_metadata)

        for data, token_id, pooling_param in zip(pooled_data_lst, prompt_token_ids, pooling_params):
            step_tag_id = pooling_param.step_tag_id
            returned_token_ids = pooling_param.returned_token_ids

            if returned_token_ids is not None and len(returned_token_ids) > 0:
                data = data[:, returned_token_ids]

            if step_tag_id is not None:
                data = data[token_id == step_tag_id]
            pooled_data.append(data)

        return pooled_data

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: Union[paddle.Tensor, list[paddle.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.extract_states(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on the pooling method.
    2. Normalizes output if specified.
    3. Returns structured results as `PoolerOutput`.
    """

    @classmethod
    def from_config(
        cls,
        pooler_config: ResolvedPoolingConfig,
        model_config: Optional["ModelConfig"] = None,
    ) -> "SimplePooler":
        pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
        if pooler_config.task == "embed":
            head = EmbeddingPoolerHead(model_config)
        elif pooler_config.task == "encode":
            head = RewardPoolerHead(model_config)
        else:
            raise NotImplementedError(f"Unknown task: {pooler_config.task}")
        return cls(pooling, head)

    def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[paddle.Tensor, list[paddle.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return build_output(pooled_data)


class PoolerNormalize(PoolerActivation):

    def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor:
        x = F.normalize(pooled_data.astype("float32"), p=2, axis=-1)
        return x.astype(pooled_data.dtype)


class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. Supported tasks: {pooler.get_supported_tasks()}"
                )

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: Union[paddle.Tensor, list[paddle.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        poolers_by_task = self.poolers_by_task

        outputs = list[PoolingSequenceGroupOutput]()
        offset = 0
        for task, group in groupby(get_tasks(pooling_metadata)):
            if not (pooler := poolers_by_task.get(task)):
                raise ValueError(f"Unsupported task: {task}. Supported tasks: {self.get_supported_tasks()}")

            num_items = len(list(group))
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[offset : offset + num_items],
            )
            outputs.extend(group_output.outputs)
            offset += num_items

        return PoolerOutput(outputs)
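Taken together, `SimplePooler` serves a single task while `DispatchPooler` fans requests out by task. A minimal composition sketch follows; the `pooler_config`, `model_config`, `hidden_states`, and `pooling_metadata` objects are assumed to exist, and `Pooler.for_encode`/`Pooler.for_embed` are the factory helpers used by `as_embedding_model` later in this change:

```python
from fastdeploy.model_executor.layers.pooler import DispatchPooler, Pooler

# Route "encode" requests to a reward head and "embed" requests to a
# normalizing embedding head. DispatchPooler validates at construction
# that each sub-pooler supports the task it was registered under.
pooler = DispatchPooler(
    {
        "encode": Pooler.for_encode(pooler_config, model_config),
        "embed": Pooler.for_embed(pooler_config, model_config),
    }
)
pooler_output = pooler(hidden_states, pooling_metadata)  # -> PoolerOutput
```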
@@ -61,6 +61,7 @@ class DefaultModelLoader(BaseModelLoader):
             fd_config,
             return_numpy=True,
         )

         model.set_state_dict(state_dict)
         self.clean_memory_fragments(state_dict)
@@ -16,7 +16,7 @@

 import paddle
 from paddle import nn
-from paddleformers.utils.log import logger
+from typing_extensions import assert_never

 from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
 from fastdeploy.model_executor.load_weight_utils import (
@@ -27,6 +27,7 @@ from fastdeploy.model_executor.load_weight_utils import (
     save_model,
 )
 from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
+from fastdeploy.model_executor.models.adapters import as_embedding_model
 from fastdeploy.model_executor.models.model_base import ModelRegistry
 from fastdeploy.platforms import current_platform

@@ -54,11 +55,11 @@ class DefaultModelLoaderV1(BaseModelLoader):
             load_weights_form_cache(model, weights_iterator)
         else:
             model.load_weights(weights_iterator)

         self.clean_memory_fragments()

     def load_model(self, fd_config: FDConfig) -> nn.Layer:
         architectures = fd_config.model_config.architectures[0]
-        logger.info(f"Starting to load model {architectures}")
         context = paddle.LazyGuard()
         if fd_config.load_config.dynamic_load_weight:
             # register rl model
@@ -70,6 +71,14 @@ class DefaultModelLoaderV1(BaseModelLoader):
         with weight_cache_context:
             with context:
                 model_cls = ModelRegistry.get_class(architectures)
+                convert_type = fd_config.model_config.convert_type
+                if convert_type == "none":
+                    pass
+                elif convert_type == "embed":
+                    model_cls = as_embedding_model(model_cls)
+                else:
+                    assert_never(convert_type)
+
                 model = model_cls(fd_config)

                 model.eval()
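The `convert_type` branch above leans on `typing_extensions.assert_never` for exhaustiveness checking. A sketch of the idiom, with the `Literal` type assumed for illustration (the real config field may admit more values):

```python
from typing import Literal

from typing_extensions import assert_never

ConvertType = Literal["none", "embed"]  # assumed for this sketch

def convert_model_cls(model_cls, convert_type: ConvertType):
    if convert_type == "none":
        return model_cls
    elif convert_type == "embed":
        return as_embedding_model(model_cls)
    else:
        # A static type checker flags any unhandled Literal member here;
        # at runtime an unexpected value raises an AssertionError.
        assert_never(convert_type)
```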
@@ -47,8 +47,10 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode
             module = importlib.import_module(f"{register_path}.{module_file}")
             for attr_name in dir(module):
                 attr = getattr(module, attr_name)
+
                 if inspect.isclass(attr) and issubclass(attr, ModelForCasualLM) and attr is not ModelForCasualLM:
                     ModelRegistry.register_model_class(attr)
+
                 if (
                     inspect.isclass(attr)
                     and issubclass(attr, PretrainedModel)
@@ -56,6 +58,7 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode
                     and hasattr(attr, "arch_name")
                 ):
                     ModelRegistry.register_pretrained_model(attr)
+
         except ImportError:
             raise ImportError(f"{module_file=} import error")
214  fastdeploy/model_executor/models/adapters.py  Normal file
@@ -0,0 +1,214 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from collections.abc import Iterable
from typing import Optional, TypeVar

import paddle
import paddle.nn as nn

from fastdeploy.config import ModelConfig
from fastdeploy.model_executor.layers.activation import get_act_fn
from fastdeploy.model_executor.models.interfaces_base import is_pooling_model
from fastdeploy.transformer_utils.config import get_hf_file_to_dict

_T = TypeVar("_T", bound=type[nn.Layer])

_GENERATE_SUFFIXES = [
    "ForCausalLM",
    "ForConditionalGeneration",
    "ChatModel",
    "LMHeadModel",
]


def _load_dense_weights(linear: nn.Linear, folder: str, model_config: "ModelConfig") -> bool:
    """Load Sentence-Transformers Dense-layer weights via the weight_loader pattern."""

    from fastdeploy.model_executor.utils import default_weight_loader

    filename = "model.safetensors"
    file_path = f"{folder}/{filename}" if folder else filename

    try:
        file_bytes = get_hf_file_to_dict(file_path, model_config.model, model_config.revision)
        if not file_bytes:
            return False

        state_dict = {}
        if filename.endswith(".safetensors"):
            import io

            from safetensors.numpy import load as load_safetensors

            numpy_tensors = load_safetensors(io.BytesIO(file_bytes))
            for key, numpy_array in numpy_tensors.items():
                state_dict[key] = paddle.to_tensor(numpy_array)
        else:
            import io

            state_dict = paddle.load(io.BytesIO(file_bytes))

        weight_keys = ["weight", "linear.weight", "dense.weight"]

        for weight_key in weight_keys:
            if weight_key in state_dict:
                weight_loader = getattr(linear.weight, "weight_loader", default_weight_loader)
                weight_loader(linear.weight, state_dict[weight_key].astype(paddle.float32))
                bias_key = weight_key.replace("weight", "bias")
                if linear.bias is not None and bias_key in state_dict:
                    bias_loader = getattr(linear.bias, "weight_loader", default_weight_loader)
                    bias_loader(linear.bias, state_dict[bias_key].astype(paddle.float32))
                return True
    except Exception as e:
        print(f"Failed to load dense weights: {e}")
        return False
    return False


def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Layer]:
    try:
        modules = get_hf_file_to_dict("modules.json", model_config.model, model_config.revision)
        if not modules:
            return None

        if isinstance(modules, dict):
            modules = modules.get("modules", [])

        dense_modules = [m for m in modules if m.get("type") == "sentence_transformers.models.Dense"]
        if not dense_modules:
            return None

        layers = []
        for module in dense_modules:
            folder = module.get("path", "")
            config_path = f"{folder}/config.json" if folder else "config.json"
            layer_config = get_hf_file_to_dict(config_path, model_config.model, model_config.revision)
            if not layer_config:
                continue
            linear = nn.Linear(
                layer_config.get("in_features", 768),
                layer_config.get("out_features", 768),
                bias_attr=layer_config.get("bias", True),
            )
            linear = linear.astype(paddle.float32)

            if not _load_dense_weights(linear, folder, model_config):
                continue

            layers.append(linear)
            if act_name := layer_config.get("activation_function"):
                layers.append(get_act_fn(act_name))
        return nn.Sequential(*layers).astype(paddle.float32)
    except Exception as e:
        print(f"ST projector loading failed: {e}")

    return None


def _create_pooling_model_cls(orig_cls: _T) -> _T:

    class ModelForPooling(orig_cls):

        def __init__(self, fd_config, *args, **kwargs):
            super().__init__(fd_config, *args, **kwargs)
            self.fd_config = fd_config
            self.is_pooling_model = True

            # These are not used in pooling models
            for attr in ("lm_head", "logits_processor"):
                if hasattr(self, attr):
                    delattr(self, attr)

            # If the model already defines a pooler instance, don't overwrite it
            if not getattr(self, "pooler", None):
                self._init_pooler(fd_config)

        def _init_pooler(self, fd_config):
            raise NotImplementedError

        def load_weights(self, weights: Iterable[tuple[str, paddle.Tensor]]):
            # TODO: Support uninitialized params tracking

            # We have deleted this attribute, so don't load it
            weights = ((name, data) for name, data in weights if not name.startswith("lm_head."))

            # If `*ForCausalLM` defines `load_weights` on the inner model
            # and there are no other inner modules with parameters,
            # we support loading from both `*Model` and `*ForCausalLM`
            if hasattr(self, "model") and hasattr(self.model, "load_weights"):
                # Whether only `self.model` contains parameters
                model_is_only_param = all(
                    name == "model" or not any(child.parameters()) for name, child in self.named_children()
                )
                if model_is_only_param:
                    weights = ((name[6:], data) for name, data in weights if name.startswith("model."))
                    loaded_params = self.model.load_weights(weights)
                    loaded_params = {f"model.{name}" for name in loaded_params}
                    return loaded_params

            # For most other models
            if hasattr(orig_cls, "load_weights"):
                return orig_cls.load_weights(self, weights)  # type: ignore
            # Fallback
            else:
                raise ValueError("No load_weights method found in the model.")

    return ModelForPooling


def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    model_name = orig_model_name

    for generate_suffix in _GENERATE_SUFFIXES:
        model_name = model_name.removesuffix(generate_suffix)
    return model_name + pooling_suffix


def as_embedding_model(cls: _T) -> _T:
    """
    Subclass an existing FastDeploy model to support embeddings.

    By default, the embeddings of the whole prompt are extracted from the
    normalized hidden state corresponding to the last token.

    Note:
        We assume that no extra layers are added to the original model;
        please implement your own model if this is not the case.
    """
    # Avoid modifying existing embedding models
    if is_pooling_model(cls):
        return cls

    from fastdeploy.model_executor.layers.pooler import DispatchPooler, Pooler

    class ModelForEmbedding(_create_pooling_model_cls(cls)):

        def _init_pooler(self, fd_config, prefix: str = ""):
            pooler_config = fd_config.model_config.pooler_config
            assert pooler_config is not None

            self.pooler = DispatchPooler(
                {
                    "encode": Pooler.for_encode(pooler_config, fd_config.model_config),
                    "embed": Pooler.for_embed(pooler_config, fd_config.model_config),
                },
            )

    ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")

    return ModelForEmbedding
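A hypothetical usage sketch (`Qwen3ForCausalLM` is registered elsewhere in this change): the wrapper keeps the backbone, drops `lm_head`/`logits_processor`, and renames the class by swapping the generation suffix.

```python
from fastdeploy.model_executor.models.adapters import as_embedding_model
from fastdeploy.model_executor.models.qwen3 import Qwen3ForCausalLM

Qwen3ForEmbedding = as_embedding_model(Qwen3ForCausalLM)
print(Qwen3ForEmbedding.__name__)  # "Qwen3ForEmbedding"

# Already a pooling model (name contains "Embedding"), so it is returned unchanged.
assert as_embedding_model(Qwen3ForEmbedding) is Qwen3ForEmbedding
```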
@@ -48,7 +48,11 @@ from fastdeploy.model_executor.layers.normalization import RMSNorm
 from fastdeploy.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding,
 )
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import (
+    ModelCategory,
+    ModelForCasualLM,
+    ModelRegistry,
+)
 from fastdeploy.platforms import current_platform

 if current_platform.is_cuda():
@@ -588,6 +592,12 @@ class DeepSeekV3Model(nn.Layer):
         return out


+@ModelRegistry.register_model_class(
+    architecture="DeepseekV3ForCausalLM",
+    module_path="deepseek_v3",
+    category=ModelCategory.TEXT_GENERATION,
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
 class DeepseekV3ForCausalLM(ModelForCasualLM):
     """
     DeepseekV3ForCausalLM
@@ -45,7 +45,11 @@ from fastdeploy.model_executor.layers.linear import (
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.moe.moe import FusedMoE
 from fastdeploy.model_executor.layers.normalization import RMSNorm
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import (
+    ModelCategory,
+    ModelForCasualLM,
+    ModelRegistry,
+)
 from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
 from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
 from fastdeploy.model_executor.models.utils import WeightMeta
@@ -471,6 +475,12 @@ class Ernie4_5_Model(nn.Layer):
         return out


+@ModelRegistry.register_model_class(
+    architecture="Ernie4_5_MoeForCausalLM",
+    module_path="ernie4_5_moe",
+    category=ModelCategory.TEXT_GENERATION,
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
 class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
     """
     Ernie4_5_MoeForCausalLM
@@ -646,6 +656,12 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
         self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)


+@ModelRegistry.register_model_class(
+    architecture="Ernie4_5_ForCausalLM",
+    module_path="ernie4_5_moe",
+    category=ModelCategory.TEXT_GENERATION,
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
 class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
     """
     Ernie4_5_ForCausalLM
@@ -659,6 +675,12 @@ class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
         return "Ernie4_5_ForCausalLM"


+@ModelRegistry.register_model_class(
+    architecture="Ernie4_5ForCausalLM",
+    module_path="ernie4_5_moe",
+    category=ModelCategory.TEXT_GENERATION,
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
 class Ernie4_5ForCausalLM(Ernie4_5_ForCausalLM):
     """
     Ernie4_5ForCausalLM 0.3B-PT
@@ -31,7 +31,11 @@ from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection
 from fastdeploy.model_executor.layers.normalization import RMSNorm
 from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import (
+    ModelCategory,
+    ModelForCasualLM,
+    ModelRegistry,
+)


 class Ernie4_5_MTPPretrainedModel(PretrainedModel):
@@ -325,6 +329,12 @@ class Ernie4_5_MTPModel(nn.Layer):
         return hidden_states


+@ModelRegistry.register_model_class(
+    architecture="Ernie4_5_MTPForCausalLM",
+    module_path="ernie4_5_mtp",
+    category=ModelCategory.TEXT_GENERATION,
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
 class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
     """
     Ernie4_5_MTPForCausalLM
@@ -44,7 +44,11 @@ from fastdeploy.model_executor.models.ernie4_5_moe import (
     Ernie4_5_Attention,
     Ernie4_5_MLP,
 )
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import (
+    ModelCategory,
+    ModelForCasualLM,
+    ModelRegistry,
+)
 from fastdeploy.platforms import current_platform

 if current_platform.is_cuda():
@@ -792,6 +796,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
         self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)


+@ModelRegistry.register_model_class(
+    architecture="Ernie4_5_VLMoeForConditionalGeneration",
+    module_path="ernie4_5_vl.ernie4_5_vl_moe",
+    category=ModelCategory.MULTIMODAL,
+    primary_use=ModelCategory.MULTIMODAL,
+)
 class Ernie4_5_VLPretrainedModel(PretrainedModel):
     """
     Ernie4_5_MoePretrainedModel
@@ -39,7 +39,11 @@ from fastdeploy.model_executor.layers.linear import (
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.moe.moe import FusedMoE
 from fastdeploy.model_executor.layers.normalization import RMSNorm
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import (
+    ModelCategory,
+    ModelForCasualLM,
+    ModelRegistry,
+)


 class Glm4MoeMLP(nn.Layer):
@@ -363,6 +367,12 @@ class Glm4MoeModel(nn.Layer):
         return out


+@ModelRegistry.register_model_class(
+    architecture="Glm4MoeForCausalLM",
+    module_path="glm4_moe",
+    category=ModelCategory.TEXT_GENERATION,
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
 class Glm4MoeForCausalLM(ModelForCasualLM):
     """
     Glm4MoeForCausalLM
54  fastdeploy/model_executor/models/interfaces_base.py  Normal file
@@ -0,0 +1,54 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Type

from paddle import nn


def is_text_generation_model(model_cls: Type[nn.Layer]) -> bool:
    from .model_base import ModelForCasualLM

    return issubclass(model_cls, ModelForCasualLM)


def is_pooling_model(model_cls: Type[nn.Layer]) -> bool:
    class_name = model_cls.__name__
    pooling_indicators = ["Embedding", "ForSequenceClassification"]
    return any(indicator in class_name for indicator in pooling_indicators) or (
        hasattr(model_cls, "is_embedding_model") and model_cls.is_embedding_model
    )


def is_multimodal_model(class_name: str) -> bool:
    multimodal_indicators = ["VL", "Vision", "ConditionalGeneration"]
    return any(indicator in class_name for indicator in multimodal_indicators)


def determine_model_category(class_name: str):
    from fastdeploy.model_executor.models.model_base import ModelCategory

    if any(pattern in class_name for pattern in ["VL", "Vision", "ConditionalGeneration"]):
        return ModelCategory.MULTIMODAL
    elif any(pattern in class_name for pattern in ["Embedding", "ForSequenceClassification"]):
        return ModelCategory.EMBEDDING
    return ModelCategory.TEXT_GENERATION


def get_default_pooling_type(model_cls: Type[nn.Layer] = None) -> str:
    if model_cls is not None:
        return getattr(model_cls, "default_pooling_type", "LAST")
    return "LAST"
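These helpers classify purely by class-name substrings rather than by inspecting the class hierarchy, so their behavior can be sketched with plain strings (a minimal illustration, not part of the change itself):

```python
from fastdeploy.model_executor.models.interfaces_base import (
    determine_model_category,
    is_multimodal_model,
)

# "VL" triggers the multimodal heuristic; "Embedding" triggers the pooling one.
assert is_multimodal_model("Qwen2_5_VLForConditionalGeneration")
assert not is_multimodal_model("Qwen3ForCausalLM")
assert determine_model_category("Qwen3EmbeddingModel").value == "embedding"
```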
@@ -3,40 +3,269 @@
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 # http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """

+import importlib
 from abc import ABC, abstractmethod
-from typing import Dict, Union
+from dataclasses import dataclass
+from enum import Enum
+from functools import lru_cache
+from typing import Dict, List, Optional, Tuple, Type, Union

 import numpy as np
 import paddle
 from paddle import nn
 from paddleformers.transformers import PretrainedModel

+from fastdeploy.config import (
+    ModelConfig,
+    iter_architecture_defaults,
+    try_match_architecture_defaults,
+)
+from fastdeploy.model_executor.models.interfaces_base import (
+    determine_model_category,
+    get_default_pooling_type,
+    is_multimodal_model,
+    is_pooling_model,
+    is_text_generation_model,
+)
+
+
+class ModelCategory(Enum):
+    TEXT_GENERATION = "text_generation"
+    MULTIMODAL = "multimodal"
+    EMBEDDING = "embedding"
+
+
+@dataclass(frozen=True)
+class ModelInfo:
+    architecture: str
+    category: ModelCategory
+    is_text_generation: bool
+    is_multimodal: bool
+    is_pooling: bool
+    module_path: str
+    default_pooling_type: str
+
+    @staticmethod
+    def from_model_cls(model_cls: Type[nn.Layer], module_path: str = "") -> "ModelInfo":
+        return ModelInfo(
+            architecture=model_cls.__name__,
+            category=determine_model_category(model_cls.__name__),
+            is_text_generation=is_text_generation_model(model_cls),
+            is_multimodal=is_multimodal_model(model_cls.__name__),
+            is_pooling=is_pooling_model(model_cls),
+            default_pooling_type=get_default_pooling_type(model_cls),
+            module_path=module_path,
+        )
+
+
+class BaseRegisteredModel(ABC):
+    """Base class for registered models"""
+
+    @abstractmethod
+    def load_model_cls(self) -> Type[nn.Layer]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def inspect_model_cls(self) -> ModelInfo:
+        raise NotImplementedError
+
+
+@dataclass(frozen=True)
+class LazyRegisteredModel(BaseRegisteredModel):
+    """Lazily loaded model"""
+
+    module_name: str
+    class_name: str
+
+    def load_model_cls(self) -> Type[nn.Layer]:
+        try:
+            full_module = f"fastdeploy.model_executor.models.{self.module_name}"
+            module = importlib.import_module(full_module)
+            return getattr(module, self.class_name)
+        except (ImportError, AttributeError) as e:
+            raise ImportError(f"Failed to load {self.class_name}: {e}")
+
+    def inspect_model_cls(self) -> ModelInfo:
+        model_cls = self.load_model_cls()
+        return ModelInfo.from_model_cls(model_cls, self.module_name)
+
+
+@dataclass(frozen=True)
+class RegisteredModel(BaseRegisteredModel):
+
+    model_cls: Type[nn.Layer]
+
+    def load_model_cls(self) -> Type[nn.Layer]:
+        return self.model_cls
+
+    def inspect_model_cls(self) -> ModelInfo:
+        return ModelInfo.from_model_cls(self.model_cls)
+
+
+@lru_cache(maxsize=128)
+def _try_inspect_model_cls(
+    model_arch: str,
+    model: BaseRegisteredModel,
+) -> Optional[ModelInfo]:
+    try:
+        return model.inspect_model_cls()
+    except Exception:
+        print(f"Error in inspecting model architecture '{model_arch}'")
+        return None
+
+
 class ModelRegistry:
-    """
-    Used to register and retrieve model classes.
-    """
-
     _arch_to_model_cls = {}
     _arch_to_pretrained_model_cls = {}
+    _enhanced_models: Dict[str, Dict] = {}
+
+    def __init__(self):
+        self.models: Dict[str, BaseRegisteredModel] = {}
+        self.pretrained_models: Dict[str, Type[PretrainedModel]] = {}
+        self._registered_models: Dict[str, BaseRegisteredModel] = {}
+        self._register_enhanced_models()
+
+    def _register_enhanced_models(self):
+        for arch, model_info in self._enhanced_models.items():
+            model = LazyRegisteredModel(module_name=model_info["module_path"], class_name=model_info["class_name"])
+            self.models[arch] = model
+            self._registered_models[arch] = model
+
+    @lru_cache(maxsize=128)
+    def _try_load_model_cls(self, architecture: str) -> Optional[Type[nn.Layer]]:
+        if architecture not in self.models:
+            return None
+        try:
+            return self.models[architecture].load_model_cls()
+        except Exception as e:
+            print(f"Failed to load model {architecture}: {e}")
+            return None
+
+    @lru_cache(maxsize=128)
+    def _try_inspect_model_cls(self, model_arch: str) -> Optional[ModelInfo]:
+        if model_arch not in self.models:
+            return None
+        try:
+            return self.models[model_arch].inspect_model_cls()
+        except Exception as e:
+            print(f"Failed to inspect model {model_arch}: {e}")
+            return None
+
+    def _normalize_arch(self, architecture: str, model_config: ModelConfig) -> str:
+        if architecture in self.models:
+            return architecture
+
+        match = try_match_architecture_defaults(
+            architecture,
+            runner_type=getattr(model_config, "runner_type", None),
+            convert_type=getattr(model_config, "convert_type", None),
+        )
+        if match:
+            suffix, _ = match
+            for repl_suffix, _ in iter_architecture_defaults():
+                base_arch = architecture.replace(suffix, repl_suffix)
+                if base_arch in self.models:
+                    return base_arch
+
+        return architecture
+
+    def _raise_for_unsupported(self, architectures: list[str]):
+        all_supported_archs = self.get_supported_archs()
+
+        if any(arch in all_supported_archs for arch in architectures):
+            raise ValueError(
+                f"Model architectures {architectures} failed to be inspected. "
+                "Please check the logs for more details."
+            )
+
+        raise ValueError(
+            f"Model architectures {architectures} are not supported for now. "
+            f"Supported architectures: {all_supported_archs}"
+        )
+
+    def inspect_model_cls(
+        self, architectures: Union[str, List[str]], model_config: ModelConfig = None
+    ) -> Tuple[ModelInfo, str]:
+        if isinstance(architectures, str):
+            architectures = [architectures]
+
+        if not architectures:
+            raise ValueError("No model architectures are specified")
+
+        for arch in architectures:
+            normalized_arch = self._normalize_arch(arch, model_config)
+            model_info = self._try_inspect_model_cls(normalized_arch)
+            if model_info is not None:
+                return (model_info, arch)
+
+        return self._raise_for_unsupported(architectures)
+
     @classmethod
-    def register_model_class(cls, model_class):
-        """register model class"""
-        if issubclass(model_class, ModelForCasualLM) and model_class is not ModelForCasualLM:
-            cls._arch_to_model_cls[model_class.name()] = model_class
-        return model_class
+    def register_model_class(
+        cls,
+        model_class=None,
+        *,
+        architecture: str = None,
+        module_path: str = None,
+        category: Union[ModelCategory, List[ModelCategory]] = ModelCategory.TEXT_GENERATION,
+        primary_use: ModelCategory = None,
+    ):
+        """
+        Enhanced model class registration supporting both traditional and decorator-style registration.
+
+        Can be used as:
+        1. Traditional decorator: @ModelRegistry.register_model_class
+        2. Enhanced decorator with metadata: @ModelRegistry.register_model_class(architecture="...", module_path="...")
+
+        Args:
+            model_class: The model class (when used as a simple decorator)
+            architecture (str): Unique identifier for the model architecture
+            module_path (str): Relative path to the module containing the model
+            category: Model category or list of categories
+            primary_use: Primary category for multi-category models
+        """
+
+        def _register(model_cls):
+            # Traditional registration for ModelForCasualLM subclasses
+            if issubclass(model_cls, ModelForCasualLM) and model_cls is not ModelForCasualLM:
+                cls._arch_to_model_cls[model_cls.name()] = model_cls
+
+            # Enhanced decorator-style registration
+            if architecture and module_path:
+                categories = category if isinstance(category, list) else [category]
+
+                # Register the main entry
+                arch_key = architecture
+                cls._enhanced_models[arch_key] = {
+                    "class_name": model_cls.__name__,
+                    "module_path": module_path,
+                    "category": primary_use or categories[0],
+                    "class": model_cls,
+                }
+
+                # Register category-specific entries for multi-category models
+                if len(categories) > 1:
+                    for cat in categories:
+                        key = f"{arch_key}_{cat.value}"
+                        cls._enhanced_models[key] = {
+                            "class_name": model_cls.__name__,
+                            "module_path": module_path,
+                            "category": cat,
+                            "primary_use": primary_use or categories[0],
+                            "class": model_cls,
+                        }
+            return model_cls
+
+        if model_class is not None:
+            return _register(model_class)
+        else:
+            return _register

     @classmethod
     def register_pretrained_model(cls, pretrained_model):
@@ -50,11 +279,6 @@ class ModelRegistry:

         return pretrained_model

-    @classmethod
-    def get_pretrain_cls(cls, architectures: str):
-        """get_pretrain_cls"""
-        return cls._arch_to_pretrained_model_cls[architectures]
-
     @classmethod
     def get_class(cls, name):
         """get model class"""
@@ -62,12 +286,61 @@ class ModelRegistry:
             raise ValueError(f"Model '{name}' is not registered!")
         return cls._arch_to_model_cls[name]

+    @classmethod
+    def get_pretrain_cls(cls, architectures: str):
+        """get_pretrain_cls"""
+        return cls._arch_to_pretrained_model_cls[architectures]
+
     @classmethod
     def get_supported_archs(cls):
-        assert len(cls._arch_to_model_cls) >= len(
-            cls._arch_to_pretrained_model_cls
-        ), "model class num is more than pretrained model registry num"
-        return [key for key in cls._arch_to_model_cls.keys()]
+        traditional_archs = list(cls._arch_to_model_cls.keys())
+        enhanced_archs = list(cls._enhanced_models.keys())
+        return traditional_archs + enhanced_archs
+
+    def resolve_model_cls(self, architectures: Union[str, List[str]]) -> Tuple[Type[nn.Layer], str]:
+        """Resolve a model class from a list of candidate architectures."""
+        if isinstance(architectures, str):
+            architectures = [architectures]
+
+        for arch in architectures:
+            model_cls = self._try_load_model_cls(arch)
+            if model_cls is not None:
+                return model_cls, arch
+
+        raise ValueError(f"Cannot find supported model: {architectures}")
+
+    def is_multimodal_model(self, architectures: Union[str, List[str]], model_config: ModelConfig = None) -> bool:
+        """Check whether the architecture is a multimodal model."""
+        if isinstance(architectures, str):
+            architectures = [architectures]
+
+        for arch in architectures:
+            model_info = self._try_inspect_model_cls(arch)
+            if model_info is not None:
+                return model_info.is_multimodal
+        return False
+
+    def is_text_generation_model(self, architectures: Union[str, List[str]], model_config: ModelConfig = None) -> bool:
+        """Check whether the architecture is a text generation model."""
+        if isinstance(architectures, str):
+            architectures = [architectures]
+
+        for arch in architectures:
+            model_info = self._try_inspect_model_cls(arch)
+            if model_info is not None:
+                return model_info.is_text_generation
+        return False
+
+    def is_pooling_model(self, architectures: Union[str, List[str]], model_config: ModelConfig = None) -> bool:
+        """Check whether the architecture is a pooling model."""
+        if isinstance(architectures, str):
+            architectures = [architectures]
+
+        for arch in architectures:
+            model_info = self._try_inspect_model_cls(arch)
+            if model_info is not None:
+                return model_info.is_pooling
+        return False


 class ModelForCasualLM(nn.Layer, ABC):
@@ -88,7 +361,6 @@ class ModelForCasualLM(nn.Layer, ABC):
     def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
         """
         Load model parameters from a given state dictionary.
-
         Args:
             state_dict (dict[str, np.ndarray | paddle.Tensor]):
                 A dictionary containing model parameters, where keys are parameter names
@@ -105,12 +377,10 @@ class ModelForCasualLM(nn.Layer, ABC):
     ):
         """
         Defines the forward pass of the model for generating text.
-
         Args:
            input_ids (Tensor, optional): The input token ids to the model.
            pos_emb (Tensor, optional): position embeddings for the model.
            **model_kwargs: Additional keyword arguments for the model.
-
         Returns:
            Tensor or list of Tensors: Generated tokens or decoded outputs.
        """
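Both registration styles from the `register_model_class` docstring above, sketched on hypothetical model classes (not part of this diff; the abstract `name()` classmethod is stubbed so the traditional path has a usable key):

```python
@ModelRegistry.register_model_class  # 1. traditional style, keyed by name()
class MyModelForCasualLM(ModelForCasualLM):
    @classmethod
    def name(cls):
        return "MyModelForCausalLM"


@ModelRegistry.register_model_class(  # 2. enhanced style with lazy-import metadata
    architecture="MyLazyModelForCausalLM",
    module_path="my_lazy_model",
    category=[ModelCategory.TEXT_GENERATION, ModelCategory.EMBEDDING],
    primary_use=ModelCategory.TEXT_GENERATION,
)
class MyLazyModelForCasualLM(MyModelForCasualLM):
    @classmethod
    def name(cls):
        return "MyLazyModelForCausalLM"
```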
@@ -39,7 +39,11 @@ from fastdeploy.model_executor.layers.linear import (
 )
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.normalization import RMSNorm
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import (
+    ModelCategory,
+    ModelForCasualLM,
+    ModelRegistry,
+)


 class Qwen2MLP(nn.Layer):
@@ -282,6 +286,12 @@ class Qwen2Model(nn.Layer):
         return out


+@ModelRegistry.register_model_class(
+    architecture="Qwen2ForCausalLM",
+    module_path="qwen2",
+    category=[ModelCategory.TEXT_GENERATION, ModelCategory.EMBEDDING],
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
 class Qwen2ForCausalLM(ModelForCasualLM):
     """
     Qwen2ForCausalLM
@@ -33,7 +33,11 @@ from fastdeploy.model_executor.graph_optimization.decorator import (
 from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.normalization import RMSNorm
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import (
+    ModelCategory,
+    ModelForCasualLM,
+    ModelRegistry,
+)
 from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer
 from fastdeploy.platforms import current_platform

@@ -157,6 +161,12 @@ class Qwen2_5_VLModel(nn.Layer):
         return out


+@ModelRegistry.register_model_class(
+    architecture="Qwen2_5_VLForConditionalGeneration",
+    module_path="qwen2_5_vl.qwen2_5_vl",
+    category=ModelCategory.MULTIMODAL,
+    primary_use=ModelCategory.MULTIMODAL,
+)
 class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
     """
     Qwen2_5_VLForConditionalGeneration
@@ -34,8 +34,13 @@ from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
 from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.normalization import RMSNorm
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import (
+    ModelCategory,
+    ModelForCasualLM,
+    ModelRegistry,
+)
 from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2MLP
+from fastdeploy.transformer_utils.config import get_pooling_config


 class Qwen3MLP(Qwen2MLP):
@@ -218,6 +223,12 @@ class Qwen3Model(nn.Layer):
         return out


+@ModelRegistry.register_model_class(
+    architecture="Qwen3ForCausalLM",
+    module_path="qwen3",
+    category=[ModelCategory.TEXT_GENERATION],
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
 class Qwen3ForCausalLM(ModelForCasualLM):
     """
     Qwen3ForCausalLM
@@ -260,6 +271,8 @@ class Qwen3ForCausalLM(ModelForCasualLM):
             process_weights_after_loading,
         )

+        is_pooling_model = hasattr(self, "is_pooling_model") and self.is_pooling_model
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -270,8 +283,18 @@ class Qwen3ForCausalLM(ModelForCasualLM):
             ("embed_tokens.embeddings", "embed_tokens", None),
             ("lm_head.linear", "lm_head", None),
         ]
+
         params_dict = dict(self.named_parameters())
+        model_path = self.fd_config.model_config.model
+        revision = self.fd_config.model_config.revision
+        if is_pooling_model and get_pooling_config(model_path, revision):
+            params_dict = {
+                param_name[6:] if param_name.startswith("model.") else param_name: param
+                for param_name, param in params_dict.items()
+            }
+
         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
+
         for loaded_weight_name, loaded_weight in weights_iterator:
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in loaded_weight_name:
@@ -282,6 +305,7 @@ class Qwen3ForCausalLM(ModelForCasualLM):
                 param = params_dict[model_param_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
                 weight_loader(param, loaded_weight, shard_id)
+
                 break
             else:
                 model_param_name = loaded_weight_name
@@ -290,10 +314,11 @@ class Qwen3ForCausalLM(ModelForCasualLM):
                 param = params_dict[model_param_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
                 weight_loader(param, loaded_weight)
+
                 model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
                 process_weights_after_loading_fn(model_sublayer_name, param)

-        if self.tie_word_embeddings:
+        if self.tie_word_embeddings and not is_pooling_model:
             self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})

     @paddle.no_grad()
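The re-keying above exists because pooling checkpoints name weights relative to the backbone (`layers.0...`), while the causal-LM class registers them under `model.layers.0...`. The transformation can be illustrated in isolation, with plain strings standing in for parameters:

```python
params = {"model.layers.0.self_attn.qkv_proj.weight": "w0", "lm_head.weight": "w1"}
# Strip the 6-character "model." prefix where present; leave other keys alone.
params = {k[6:] if k.startswith("model.") else k: v for k, v in params.items()}
assert "layers.0.self_attn.qkv_proj.weight" in params
assert "lm_head.weight" in params
```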
@@ -39,7 +39,11 @@ from fastdeploy.model_executor.layers.linear import (
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.moe.moe import FusedMoE
 from fastdeploy.model_executor.layers.normalization import RMSNorm
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import (
+    ModelCategory,
+    ModelForCasualLM,
+    ModelRegistry,
+)
 from fastdeploy.model_executor.models.qwen3 import Qwen3Attention


@@ -316,6 +320,12 @@ class Qwen3MoeModel(nn.Layer):
         return out


+@ModelRegistry.register_model_class(
+    architecture="Qwen3MoeForCausalLM",
+    module_path="qwen3moe",
+    category=ModelCategory.TEXT_GENERATION,
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
 class Qwen3MoeForCausalLM(ModelForCasualLM):
     """
     Qwen3MoeForCausalLM
@@ -158,6 +158,7 @@ def default_weight_loader(fd_config: FDConfig) -> None:

     def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
         """fn"""
+
         output_dim = getattr(param, "output_dim", None)
         weight_need_transpose = getattr(param, "weight_need_transpose", False)
         if weight_need_transpose:
69
fastdeploy/output/pooler.py
Normal file
69
fastdeploy/output/pooler.py
Normal file
@@ -0,0 +1,69 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Any

import msgspec
import paddle


class PoolingSequenceGroupOutput(
    msgspec.Struct,
    omit_defaults=True,
    array_like=True,
):
    """The model output associated with a pooling sequence group."""

    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def get_data_nbytes(self) -> int:
        if isinstance(self.data, paddle.Tensor):
            return self.data.numel() * self.data.element_size()
        elif hasattr(self.data, "nbytes"):
            return self.data.nbytes
        else:
            return 0

    def __repr__(self) -> str:
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PoolingSequenceGroupOutput):
            raise NotImplementedError()
        return self.data == other.data


class PoolerOutput(msgspec.Struct, omit_defaults=True, array_like=True):
    """The output from a pooling operation in the pooling model."""

    outputs: list[PoolingSequenceGroupOutput]

    def get_data_nbytes(self) -> int:
        return sum(o.get_data_nbytes() for o in self.outputs)

    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        return self.outputs[idx]

    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        return isinstance(other, self.__class__) and self.outputs == other.outputs
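A minimal usage sketch for these structs (the shape and dtype are illustrative; only `paddle` and the two classes above are assumed):

```python
import paddle

from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput

# One pooled vector per sequence group, e.g. a 4-dim bfloat16 embedding.
emb = paddle.ones([4], dtype="bfloat16")
output = PoolerOutput(outputs=[PoolingSequenceGroupOutput(data=emb)])

assert len(output) == 1
assert output[0].data is emb
assert output.get_data_nbytes() == 4 * 2  # 4 elements x 2 bytes for bfloat16
```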

fastdeploy/transformer_utils/__init__.py (new file)
@@ -0,0 +1,15 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

fastdeploy/transformer_utils/config.py (new file)
@@ -0,0 +1,139 @@
import json
from pathlib import Path
from typing import Optional, Union

import huggingface_hub
from huggingface_hub import hf_hub_download, try_to_load_from_cache
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    LocalEntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
)

from fastdeploy.utils import get_logger

logger = get_logger("transformer_config", "transformer_config.log")


def file_or_path_exists(model, config_name):
    if (local_path := Path(model)).exists():
        return (local_path / config_name).is_file()

    return False


def get_pooling_config_name(pooling_name: str):

    if "pooling_mode_" in pooling_name:
        pooling_name = pooling_name.replace("pooling_mode_", "")

    if "_" in pooling_name:
        pooling_name = pooling_name.split("_")[0]

    if "lasttoken" in pooling_name:
        pooling_name = "last"

    supported_pooling_types = ["LAST", "ALL", "CLS", "STEP", "MEAN"]
    pooling_type_name = pooling_name.upper()

    if pooling_type_name in supported_pooling_types:
        return pooling_type_name

    raise NotImplementedError(f"Pooling type {pooling_type_name} not supported")


def try_get_local_file(model: Union[str, Path], file_name: str, revision: Optional[str] = "main") -> Optional[Path]:
    file_path = Path(model) / file_name
    if file_path.is_file():
        return file_path
    else:
        try:
            cached_filepath = try_to_load_from_cache(repo_id=model, filename=file_name, revision=revision)
            if isinstance(cached_filepath, str):
                return Path(cached_filepath)
        except ValueError:
            ...
    return None


def get_hf_file_to_dict(file_name: str, model: Union[str, Path], revision: Optional[str] = "main"):
    """
    Downloads a file from the Hugging Face Hub and returns
    its contents as a dictionary.

    Parameters:
    - file_name (str): The name of the file to download.
    - model (str): The name of the model on the Hugging Face Hub.
    - revision (str): The specific version of the model.

    Returns:
    - config_dict (dict): A dictionary containing
      the contents of the downloaded file.
    """
    file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)

    if file_path is None:
        try:
            hf_hub_file = hf_hub_download(model, file_name, revision=revision)
        except huggingface_hub.errors.OfflineModeIsEnabled:
            return None
        except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError, LocalEntryNotFoundError) as e:
            logger.debug("File or repository not found in hf_hub_download: %s", e)
            return None
        except HfHubHTTPError as e:
            logger.warning(
                "Cannot connect to Hugging Face Hub. Skipping file download for '%s':", file_name, exc_info=e
            )
            return None
        file_path = Path(hf_hub_file)

    if file_path is not None and file_path.is_file():
        with open(file_path) as file:
            return json.load(file)

    return None


def get_pooling_config(model: str, revision: Optional[str] = "main"):
    """
    This function gets the pooling and normalize
    config from the model - only applies to
    sentence-transformers models.

    Args:
        model (str): The name of the Hugging Face model.
        revision (str, optional): The specific version
            of the model to use. Defaults to 'main'.

    Returns:
        dict: A dictionary containing the pooling
        type and whether normalization is used.
    """

    modules_file_name = "modules.json"

    modules_dict = None
    if file_or_path_exists(model, config_name=modules_file_name):
        modules_dict = get_hf_file_to_dict(modules_file_name, model)

    if modules_dict is None:
        return None

    pooling = next((item for item in modules_dict if item["type"] == "sentence_transformers.models.Pooling"), None)

    normalize = bool(
        next((item for item in modules_dict if item["type"] == "sentence_transformers.models.Normalize"), False)
    )

    if pooling:
        pooling_file_name = "{}/config.json".format(pooling["path"])
        pooling_dict = get_hf_file_to_dict(pooling_file_name, model)
        pooling_type_name = next((item for item, val in pooling_dict.items() if val is True), None)

        if pooling_type_name is not None:
            pooling_type_name = get_pooling_config_name(pooling_type_name)

        return {"pooling_type": pooling_type_name, "normalize": normalize}

    return None
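To make the name normalization concrete, a hedged sketch of what this module reads and returns (the local repo path is illustrative; the expected layout is the usual sentence-transformers one with `modules.json` plus `1_Pooling/config.json`):

```python
from fastdeploy.transformer_utils.config import (
    get_pooling_config,
    get_pooling_config_name,
)

# "pooling_mode_lasttoken" -> strip "pooling_mode_" -> "lasttoken" -> "last" -> "LAST"
assert get_pooling_config_name("pooling_mode_lasttoken") == "LAST"
# "pooling_mode_mean_tokens" -> "mean_tokens" -> split on "_" -> "mean" -> "MEAN"
assert get_pooling_config_name("pooling_mode_mean_tokens") == "MEAN"

# For a local checkout that ships modules.json and 1_Pooling/config.json,
# the result would look like {"pooling_type": "LAST", "normalize": True}.
pooling_cfg = get_pooling_config("./Qwen3-Embedding-0.6B")
```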
@@ -51,6 +51,7 @@ from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse
 from fastdeploy.logger.logger import FastDeployLogger

 T = TypeVar("T")
+from typing import Callable, Optional

 # [N,2] -> every line is [config_name, enable_xxx_name]
 # Make sure enable_xxx equal to config.enable_xxx
@@ -852,3 +853,24 @@ api_server_logger = get_logger("api_server", "api_server.log")
 console_logger = get_logger("console", "console.log", print_to_console=True)
 spec_logger = get_logger("speculate", "speculate.log")
 zmq_client_logger = get_logger("zmq_client", "zmq_client.log")
+
+
+def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
+
+    def _parse_type(val: str) -> T:
+        try:
+            return return_type(val)
+        except ValueError as e:
+            raise argparse.ArgumentTypeError(f"Value {val} cannot be converted to {return_type}.") from e
+
+    return _parse_type
+
+
+def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
+
+    def _optional_type(val: str) -> Optional[T]:
+        if val == "" or val == "None":
+            return None
+        return parse_type(return_type)(val)
+
+    return _optional_type
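A hedged sketch of how these helpers compose with `argparse` (the flag name is borrowed from the worker wiring further below; values are illustrative):

```python
import argparse
import json

from fastdeploy.utils import optional_type

parser = argparse.ArgumentParser()
parser.add_argument("--override-pooler-config", type=optional_type(json.loads), default=None)

# "" or "None" short-circuits to None; anything else is parsed as JSON, and a
# json.JSONDecodeError (a ValueError subclass) surfaces as ArgumentTypeError.
ns = parser.parse_args(["--override-pooler-config", '{"pooling_type": "MEAN", "normalize": false}'])
assert ns.override_pooler_config == {"pooling_type": "MEAN", "normalize": False}

ns = parser.parse_args(["--override-pooler-config", "None"])
assert ns.override_pooler_config is None
```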
@@ -1319,8 +1319,12 @@ class GPUModelRunner(ModelRunnerBase):
             self.parallel_config.max_model_len,
         )

-        # 4. Execute spec decode
-        logits = self.model.compute_logits(hidden_states)
+        logits = None
+        if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model:
+            pass
+        else:
+            # 4. Execute spec decode
+            logits = self.model.compute_logits(hidden_states)

         if not self.speculative_decoding:
             set_value_by_flags_and_idx(
@@ -1625,8 +1629,13 @@ class GPUModelRunner(ModelRunnerBase):
             self.parallel_config.max_model_len,
         )

+        logits = None
         # 4. Compute logits, Sample
-        logits = self.model.compute_logits(hidden_states)
+        if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model:
+            pass
+        else:
+            # 4. Execute spec decode
+            logits = self.model.compute_logits(hidden_states)

         if not self.speculative_decoding:
             set_value_by_flags_and_idx(
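Both runner hunks gate logits computation on the same duck-typed flag; a minimal sketch of the convention (the two class names here are illustrative assumptions, not the PR's classes):

```python
class Qwen3ForEmbedding:
    # Pooling models advertise themselves with a class-level flag;
    # the runner then skips compute_logits for them.
    is_pooling_model = True


class Qwen3ForCausalLM:
    pass  # no flag, so the generation path computes logits


def skips_logits(model) -> bool:
    # Mirrors the runner's hasattr(...) and model.is_pooling_model check.
    return getattr(model, "is_pooling_model", False)


assert skips_logits(Qwen3ForEmbedding())
assert not skips_logits(Qwen3ForCausalLM())
```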
@@ -45,7 +45,7 @@ from fastdeploy.inter_communicator import IPCSignal
 from fastdeploy.model_executor.layers.quantization import parse_quant_config
 from fastdeploy.platforms import current_platform
 from fastdeploy.scheduler import SchedulerConfig
-from fastdeploy.utils import get_logger
+from fastdeploy.utils import get_logger, optional_type
 from fastdeploy.worker.worker_base import WorkerBase

 logger = get_logger("worker_process", "worker_process.log")
@@ -643,6 +643,27 @@ def parse_args():
         help="Flag to specify dtype of lm_head as FP32",
     )
+
+    parser.add_argument(
+        "--runner",
+        type=str,
+        default="auto",
+        help="The type of model runner to use. Each FD instance only supports one model runner, even if the same model can be used for multiple types.",
+    )
+
+    parser.add_argument(
+        "--convert",
+        type=str,
+        default="auto",
+        help="Convert the model using adapters. The most common use case is to adapt a text generation model to be used for pooling tasks.",
+    )
+
+    parser.add_argument(
+        "--override-pooler-config",
+        type=optional_type(json.loads),
+        default=None,
+        help="Override configuration for the pooler.",
+    )
+
     args = parser.parse_args()
     return args
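Putting the three new flags together, a hedged, self-contained replica of just this slice of the parser (the `"embed"` value comes from the tests below; other accepted `--runner`/`--convert` values are not spelled out in the hunk):

```python
import argparse
import json

from fastdeploy.utils import optional_type

parser = argparse.ArgumentParser()
parser.add_argument("--runner", type=str, default="auto")
parser.add_argument("--convert", type=str, default="auto")
parser.add_argument("--override-pooler-config", type=optional_type(json.loads), default=None)

# e.g. adapt a text-generation checkpoint for embedding-style pooling:
ns = parser.parse_args(["--convert", "embed", "--override-pooler-config", '{"normalize": true}'])
assert ns.runner == "auto" and ns.convert == "embed"
assert ns.override_pooler_config == {"normalize": True}
```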

@@ -39,3 +39,4 @@ opentelemetry-distro
 opentelemetry-exporter-otlp
 opentelemetry-instrumentation-fastapi
 partial_json_parser
+msgspec
@@ -14,9 +14,8 @@
 from paddleformers.transformers import PretrainedModel

-from fastdeploy import ModelRegistry
 from fastdeploy.config import ErnieArchitectures
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import ModelForCasualLM, ModelRegistry


 class MyPretrainedModel(PretrainedModel):

tests/pooling/test_embedding.py (new file)
@@ -0,0 +1,182 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import sys

import paddle
import pytest

from fastdeploy.config import (
    CacheConfig,
    FDConfig,
    GraphOptimizationConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
)
from fastdeploy.model_executor.models.model_base import ModelRegistry

current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from tests.model_loader.utils import get_torch_model_path


class TestModelLoader:

    @pytest.fixture(scope="session", autouse=True)
    def setup_paddle(self):
        if not paddle.is_compiled_with_cuda():
            print("CUDA not available, using CPU")
            paddle.set_device("cpu")
        else:
            print("Using CUDA device")
            paddle.set_device("gpu")
        yield

    @pytest.fixture(scope="session")
    def model_path(self):
        try:
            torch_model_path = get_torch_model_path("Qwen3-0.6B")
            if os.path.exists(torch_model_path):
                return torch_model_path
        except Exception as e:
            print(f"Could not get torch model path: {e}")

    @pytest.fixture
    def model_config(self, model_path):
        model_args = {
            "model": model_path,
            "dtype": "bfloat16",
            "max_model_len": 8192,
            "tensor_parallel_size": 1,
            "runner": "auto",
            "convert": "auto",
        }

        try:
            return ModelConfig(model_args)
        except Exception as e:
            print(f"Could not create ModelConfig: {e}")

    @pytest.fixture
    def fd_config(self, model_config):
        try:
            cache_args = {
                "block_size": 64,
                "gpu_memory_utilization": 0.9,
                "cache_dtype": "bfloat16",
                "model_cfg": model_config,
                "tensor_parallel_size": 1,
            }
            cache_config = CacheConfig(cache_args)

            parallel_args = {
                "tensor_parallel_size": 1,
                "data_parallel_size": 1,
            }
            parallel_config = ParallelConfig(parallel_args)

            load_args = {}
            load_config = LoadConfig(load_args)

            graph_opt_args = {
                "enable_cudagraph": False,
                "cudagraph_capture_sizes": None,
            }
            graph_opt_config = GraphOptimizationConfig(graph_opt_args)

            return FDConfig(
                model_config=model_config,
                cache_config=cache_config,
                parallel_config=parallel_config,
                load_config=load_config,
                graph_opt_config=graph_opt_config,
                test_mode=True,
            )
        except Exception as e:
            print(f"Could not create FDConfig: {e}")

    @pytest.fixture
    def model_json_config(self, model_path):
        config_path = os.path.join(model_path, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                return json.load(f)
        return None

    def test_embedding_with_none_convert_type(self, fd_config, model_json_config):
        if model_json_config is None:
            pytest.skip("Model config not available")

        if fd_config is None:
            pytest.skip("FDConfig not available")

        print("=" * 60)
        print("Testing initialize_model with convert_type='none'")
        print("=" * 60)

        architectures = model_json_config.get("architectures", [])
        if not architectures:
            pytest.skip("No architectures found in model config")

        fd_config.model_config.convert_type = "none"

        try:
            model_cls = ModelRegistry.get_class(architectures)

            if hasattr(model_cls, "__name__"):
                assert (
                    "ForEmbedding" not in model_cls.__name__
                ), f"Standard model should not have 'ForEmbedding' in name, but got: {model_cls.__name__}"
                print(f"Confirmed standard model type (no ForEmbedding): {model_cls.__name__}")

                standard_methods = set(dir(model_cls))
                assert "_init_pooler" not in standard_methods, "Standard model should not have _init_pooler method"

        except Exception as e:
            print(f"Error in none: {e}")

    def test_embedding_with_embed_convert_type(self, fd_config, model_json_config):
        if model_json_config is None:
            pytest.skip("Model config not available")

        if fd_config is None:
            pytest.skip("FDConfig not available")

        print("=" * 60)
        print("Testing embedding with convert_type='embed'")
        print("=" * 60)

        architectures = model_json_config.get("architectures", [])
        if not architectures:
            pytest.skip("No architectures found in model config")

        fd_config.model_config.convert_type = "embed"

        try:
            model_cls = ModelRegistry.get_class(architectures)
            if hasattr(model_cls, "__name__"):
                assert "ForEmbedding" in model_cls.__name__, "Embedding model should have 'ForEmbedding' in name"
                print(f"Confirmed embedding model type: {model_cls.__name__}")

                embedding_methods = set(dir(model_cls))
                assert "_init_pooler" in embedding_methods, "Embedding model should have _init_pooler method"

        except Exception as e:
            print(f"Error in convert embed: {e}")
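One way to drive this suite without leaving Python, equivalent to `pytest -v` on the file (assumes the repository root as the working directory and a Qwen3-0.6B torch checkpoint resolvable via `tests/model_loader/utils`):

```python
import pytest

# Runs the embedding-conversion tests verbosely.
raise SystemExit(pytest.main(["-v", "tests/pooling/test_embedding.py"]))
```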