[Feature] support pool (#3827)

* support pool

* update pooling

* add pooler_config and check

* update

* support AutoWeightsLoader load weight

* fix

* update

* delete print

* update pre-commit

* fix

* fix xpu

* fix ModelRegistry->model_registry

* fix Copilot review

* fix pooler.py

* delete StepPooler

* fix abstract

* fix default_loader_v1

* fix Pre Commit

* support torch qwen3 dense

* add test and fix torch-qwen

* fix

* fix

* adapter ci:

* fix review

* fix pooling_params.py

* fix

* fix tasks.py 2025

* fix print and logger

* Modify ModelRegistry and delete AutoWeightsLoader

* fix logger

* fix test_embedding

* fix ci bug

* ernie4_5 model_registry

* fix test

* support Qwen3-Embedding-0.6B tp=1 load

* fix extra code

* fix

* delete fix vocab_size

* delete prepare_params_dict

* fix:
This commit is contained in:
lizexu123
2025-09-22 14:09:09 +08:00
committed by GitHub
parent da74a5f0b3
commit c86945ef49
36 changed files with 2371 additions and 51 deletions


@@ -18,12 +18,14 @@ from __future__ import annotations
import json
import os
from dataclasses import field
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
import paddle
import paddle.distributed as dist
from paddleformers.transformers.configuration_utils import PretrainedConfig
from typing_extensions import assert_never
import fastdeploy
from fastdeploy import envs
@@ -31,11 +33,68 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfig
from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.transformer_utils.config import get_pooling_config
from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger
logger = get_logger("config", "config.log")
TaskOption = Literal["generate"]
TaskOption = Literal["auto", "generate", "embedding", "embed"]
RunnerType = Literal["generate", "pooling"]
RunnerOption = Literal["auto", "generate", "pooling"]
ConvertOption = Literal["auto", "none", "embed"]
ConvertType = Literal["none", "embed"]
_ResolvedTask = Literal["generate", "encode", "embed"]
_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
    "generate": [],
    "pooling": ["embed"],
}
# Some model suffixes are based on auto classes from Transformers:
# https://huggingface.co/docs/transformers/en/model_doc/auto
# NOTE: Items higher on this list take priority over lower ones
_SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
    ("ForCausalLM", ("generate", "none")),
    ("ForConditionalGeneration", ("generate", "none")),
    ("ChatModel", ("generate", "none")),
    ("LMHeadModel", ("generate", "none")),
    ("ForTextEncoding", ("pooling", "embed")),
    ("EmbeddingModel", ("pooling", "embed")),
    ("ForSequenceClassification", ("pooling", "classify")),
    ("ForAudioClassification", ("pooling", "classify")),
    ("ForImageClassification", ("pooling", "classify")),
    ("ForVideoClassification", ("pooling", "classify")),
    ("ClassificationModel", ("pooling", "classify")),
    ("ForRewardModeling", ("pooling", "reward")),
    ("RewardModel", ("pooling", "reward")),
    # Let other `*Model`s take priority
    ("Model", ("pooling", "embed")),
]
def iter_architecture_defaults():
    yield from _SUFFIX_TO_DEFAULTS


def try_match_architecture_defaults(
    architecture: str,
    *,
    runner_type: Optional[RunnerType] = None,
    convert_type: Optional[ConvertType] = None,
):
    for suffix, (default_runner_type, default_convert_type) in iter_architecture_defaults():
        if (
            (runner_type is None or runner_type == default_runner_type)
            and (convert_type is None or convert_type == default_convert_type)
            and architecture.endswith(suffix)
        ):
            return suffix, (default_runner_type, default_convert_type)
    return None
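For illustration (not part of the patch): how the suffix table above resolves defaults for an architecture name. The import path is assumed to be fastdeploy.config, inferred from the imports in this file.

from fastdeploy.config import try_match_architecture_defaults  # module path assumed

# The first matching suffix wins, so more specific entries must sit above "Model".
print(try_match_architecture_defaults("Qwen2ForCausalLM"))
# -> ("ForCausalLM", ("generate", "none"))
print(try_match_architecture_defaults("Qwen3EmbeddingModel"))
# -> ("EmbeddingModel", ("pooling", "embed"))
print(try_match_architecture_defaults("Qwen2ForCausalLM", runner_type="pooling"))
# -> None: "ForCausalLM" defaults to the "generate" runner, so nothing matches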
class MoEPhase:
@@ -133,6 +192,12 @@ class ModelConfig:
        self.eos_tokens_lens: int = 2
        self.lm_head_fp32: bool = False
        self.model_format = "auto"
        self.runner = "auto"
        self.convert = "auto"
        self.pooler_config: Optional["PoolerConfig"] = field(init=False)
        self.override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
        self.revision = None
        self.partial_rotary_factor: float = 1.0
        self.num_nextn_predict_layers = 0
        for key, value in args.items():
@@ -161,6 +226,7 @@ class ModelConfig:
        self.ori_vocab_size = args.get("ori_vocab_size", self.vocab_size)
        architectures = self.architectures[0]
        if MultimodalRegistry.contains_model(architectures):
            self.enable_mm = True
        else:
@@ -171,6 +237,43 @@ class ModelConfig:
        self.override_name_from_config()
        self.read_from_env()
        self.read_model_config()
        self.runner_type = self._get_runner_type(self.architectures, self.runner)
        self.convert_type = self._get_convert_type(self.architectures, self.runner_type, self.convert)
        registry = self.registry
        is_generative_model = registry.is_text_generation_model(self.architectures, self)
        is_pooling_model = registry.is_pooling_model(self.architectures, self)
        is_multimodal_model = registry.is_multimodal_model(self.architectures, self)
        if self.runner_type == "generate" and not is_generative_model:
            if is_multimodal_model:
                pass
            else:
                generate_converts = _RUNNER_CONVERTS["generate"]
                if self.convert_type not in generate_converts:
                    raise ValueError("This model does not support `--runner generate`.")
        if self.runner_type == "pooling" and not is_pooling_model:
            pooling_converts = _RUNNER_CONVERTS["pooling"]
            if self.convert_type not in pooling_converts:
                convert_option = "<" + "|".join(pooling_converts) + ">"
                raise ValueError(
                    "This model does not support `--runner pooling`. "
                    f"You can pass `--convert {convert_option}` to adapt "
                    "it into a pooling model."
                )
        self.supported_tasks = self._get_supported_tasks(self.architectures, self.runner_type, self.convert_type)
        model_info, arch = registry.inspect_model_cls(self.architectures, self)
        self._model_info = model_info
        self._architecture = arch
        self.pooler_config = self._init_pooler_config()
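The two checks above decide whether the requested runner is actually usable. A standalone sketch of the pooling branch (names mirror the diff; this is not the FastDeploy API itself):

# Sketch only: mirrors the `--runner pooling` validation in __init__ above.
_RUNNER_CONVERTS = {"generate": [], "pooling": ["embed"]}

def check_pooling_runner(is_pooling_model: bool, convert_type: str) -> None:
    # A non-pooling model is still allowed when it is being converted, e.g. `--convert embed`.
    if not is_pooling_model and convert_type not in _RUNNER_CONVERTS["pooling"]:
        raise ValueError("This model does not support `--runner pooling`.")

check_pooling_runner(is_pooling_model=True, convert_type="none")    # native pooling model: OK
check_pooling_runner(is_pooling_model=False, convert_type="embed")  # converted model: OK
# check_pooling_runner(False, "none") would raise ValueError

This is how an embedding checkpoint that still reports a *ForCausalLM architecture (e.g. the Qwen3-Embedding checkpoint mentioned in the commit log, assuming it exposes such an architecture) can be served as a pooling model once `--convert` resolves to `embed`.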
    @property
    def registry(self):
        from fastdeploy.model_executor.models.model_base import ModelRegistry

        return ModelRegistry()
    def override_name_from_config(self):
        """
@@ -194,7 +297,6 @@ class ModelConfig:
    def read_from_env(self):
        """
        Read configuration information from environment variables and update the object's attributes.
        If an attribute is not present or is an empty string in the environment variables, use the default value.
        """
        self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
@@ -235,6 +337,165 @@ class ModelConfig:
f"Config file path: {config_path}"
)
    def _get_default_runner_type(
        self,
        architectures: list[str],
    ) -> RunnerType:
        registry = self.registry
        if get_pooling_config(self.model, self.revision):
            return "pooling"
        for arch in architectures:
            if arch in registry.get_supported_archs():
                if registry.is_pooling_model(architectures, self):
                    return "pooling"
                if registry.is_text_generation_model(architectures, self):
                    return "generate"
            match = try_match_architecture_defaults(arch)
            if match:
                _, (runner_type, _) = match
                return runner_type
        return "generate"
    def _get_default_convert_type(
        self,
        architectures: list[str],
        runner_type: RunnerType,
    ) -> ConvertType:
        registry = self.registry
        for arch in architectures:
            if arch in registry.get_supported_archs():
                if runner_type == "generate" and registry.is_text_generation_model(architectures, self):
                    return "none"
                if runner_type == "pooling" and registry.is_pooling_model(architectures, self):
                    return "none"
            match = try_match_architecture_defaults(arch, runner_type=runner_type)
            if match:
                _, (_, convert_type) = match
                return convert_type
        # This is to handle Sentence Transformers models that use *ForCausalLM
        # and also multi-modal pooling models which are not defined as
        # Sentence Transformers models
        if runner_type == "pooling":
            return "embed"
        return "none"
    def _get_runner_type(
        self,
        architectures: list[str],
        runner: RunnerOption,
    ) -> RunnerType:
        if runner != "auto":
            return runner
        runner_type = self._get_default_runner_type(architectures)
        if runner_type != "generate":
            logger.info(
                "Resolved `--runner auto` to `--runner %s`. Pass the value explicitly to silence this message.",
                runner_type,
            )
        return runner_type
    def _get_convert_type(
        self,
        architectures: list[str],
        runner_type: RunnerType,
        convert: ConvertOption,
    ) -> ConvertType:
        if convert != "auto":
            return convert
        convert_type = self._get_default_convert_type(architectures, runner_type)
        if convert_type != "none":
            logger.info(
                "Resolved `--convert auto` to `--convert %s`. Pass the value explicitly to silence this message.",
                convert_type,
            )
        return convert_type
    def _get_supported_generation_tasks(
        self,
        architectures: list[str],
        convert_type: ConvertType,
    ) -> list[_ResolvedTask]:
        registry = self.registry
        supported_tasks = list[_ResolvedTask]()
        if registry.is_text_generation_model(architectures, self) or convert_type in _RUNNER_CONVERTS["generate"]:
            supported_tasks.append("generate")
        # TODO: transcription is not supported yet.
        return supported_tasks
    def _get_default_pooling_task(
        self,
        architectures: list[str],
    ) -> Literal["embed"]:
        # Classification and reward are not supported yet.
        for arch in architectures:
            match = try_match_architecture_defaults(arch, runner_type="pooling")
            if match:
                _, (_, convert_type) = match
                assert convert_type != "none"
                return convert_type
        return "embed"
    def _get_supported_pooling_tasks(
        self,
        architectures: list[str],
        convert_type: ConvertType,
    ) -> list[_ResolvedTask]:
        registry = self.registry
        supported_tasks = list[_ResolvedTask]()
        if registry.is_pooling_model(architectures, self) or convert_type in _RUNNER_CONVERTS["pooling"]:
            supported_tasks.append("encode")
            extra_task = self._get_default_pooling_task(architectures) if convert_type == "none" else convert_type
            supported_tasks.append(extra_task)
        return supported_tasks
    def _get_supported_tasks(
        self,
        architectures: list[str],
        runner_type: RunnerType,
        convert_type: ConvertType,
    ) -> list[_ResolvedTask]:
        if runner_type == "generate":
            return self._get_supported_generation_tasks(architectures, convert_type)
        if runner_type == "pooling":
            return self._get_supported_pooling_tasks(architectures, convert_type)
        assert_never(runner_type)
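For orientation, a compact mirror of what the pooling branch resolves to (a sketch under the definitions above, not the library code):

def supported_pooling_tasks(is_pooling_model: bool, convert_type: str) -> list[str]:
    # Mirrors _get_supported_pooling_tasks: "encode" plus one embedding-style task.
    tasks = []
    if is_pooling_model or convert_type == "embed":
        tasks.append("encode")
        tasks.append("embed" if convert_type == "none" else convert_type)
    return tasks

print(supported_pooling_tasks(True, "none"))    # ['encode', 'embed'] for a native pooling model
print(supported_pooling_tasks(False, "embed"))  # ['encode', 'embed'] for a converted model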
    def _init_pooler_config(self) -> Optional["PoolerConfig"]:
        if self.runner_type == "pooling":
            if isinstance(self.override_pooler_config, dict):
                self.override_pooler_config = PoolerConfig(**self.override_pooler_config)
            pooler_config = self.override_pooler_config or PoolerConfig()
            base_config = get_pooling_config(self.model, self.revision)
            if base_config is not None:
                for k, v in base_config.items():
                    if getattr(pooler_config, k) is None:
                        setattr(pooler_config, k, v)
            default_pooling_type = self._model_info.default_pooling_type
            if pooler_config.pooling_type is None:
                pooler_config.pooling_type = default_pooling_type
            return pooler_config
        return None
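The merge order above is: explicit override_pooler_config fields first, then the repository's pooling config from get_pooling_config, and finally the model's default pooling type. A small sketch with plain dicts (not the real PoolerConfig object; values assumed):

override = {"normalize": False, "pooling_type": None}   # user-supplied override_pooler_config
repo_cfg = {"pooling_type": "MEAN", "normalize": True}  # e.g. a sentence-transformers pooling config
model_default_pooling_type = "LAST"                     # _model_info.default_pooling_type (value assumed)

merged = dict(override)
for k, v in repo_cfg.items():
    if merged.get(k) is None:  # the repo config only fills fields the user left unset
        merged[k] = v
if merged.get("pooling_type") is None:
    merged["pooling_type"] = model_default_pooling_type
print(merged)  # {'normalize': False, 'pooling_type': 'MEAN'}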
    def _get_download_model(self, model_name, model_type="default"):
        # TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
        pass
@@ -846,6 +1107,41 @@ class LoadConfig:
setattr(self, key, value)
class PoolerConfig:
    """Controls the behavior of output pooling in pooling models."""

    pooling_type: Optional[str] = None
    """
    The pooling method of the pooling model.
    """

    # for embedding models
    normalize: Optional[bool] = None
    """
    Whether to normalize the embedding outputs. Defaults to True.
    """

    dimensions: Optional[int] = None
    """
    Reduce the dimensions of embeddings if the model supports
    matryoshka representation. Defaults to None.
    """

    enable_chunked_processing: Optional[bool] = None
    """
    Whether to enable chunked processing for long inputs that exceed the model's
    maximum position embeddings. When enabled, long inputs are split into
    chunks, processed separately, and then aggregated using weighted averaging.
    This allows embedding models to handle arbitrarily long text without CUDA
    errors. Defaults to False.
    """

    max_embed_len: Optional[int] = None
    """
    Maximum input length allowed for embedding generation. When set, inputs
    longer than max_model_len are accepted for embedding models as long as they
    do not exceed max_embed_len; inputs that exceed max_embed_len are handled
    according to the original max_model_len validation logic.
    Defaults to None (i.e. set to max_model_len).
    """
class LoRAConfig:
    """LoRA Config"""