mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-19 06:54:41 +08:00
[Feature] support pool (#3827)
* support pool * update pooling * add pooler_config and check * update * support AutoWeightsLoader load weight * fix * update * delete print * update pre-commit * fix * fix xpu * fix ModelRegistry->model_registry * fix Copilot review * fix pooler.py * delete StepPooler * fix abstract * fix default_loader_v1 * fix Pre Commit * support torch qwen3 dense * add test and fix torch-qwen * fix * fix * adapter ci: * fix review * fix pooling_params.py * fix * fix tasks.py 2025 * fix print and logger * Modefy ModelRegistry and delete AutoWeightsLoader * fix logger * fix test_embedding * fix ci bug * ernie4_5 model_registry * fix test * support Qwen3-Embedding-0.6B tp=1 load * fix extra code * fix * delete fix vocab_size * delete prepare_params_dict * fix:
This commit is contained in:
@@ -47,8 +47,10 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode
|
||||
module = importlib.import_module(f"{register_path}.{module_file}")
|
||||
for attr_name in dir(module):
|
||||
attr = getattr(module, attr_name)
|
||||
|
||||
if inspect.isclass(attr) and issubclass(attr, ModelForCasualLM) and attr is not ModelForCasualLM:
|
||||
ModelRegistry.register_model_class(attr)
|
||||
|
||||
if (
|
||||
inspect.isclass(attr)
|
||||
and issubclass(attr, PretrainedModel)
|
||||
@@ -56,6 +58,7 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode
|
||||
and hasattr(attr, "arch_name")
|
||||
):
|
||||
ModelRegistry.register_pretrained_model(attr)
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(f"{module_file=} import error")
|
||||
|
||||
|
214
fastdeploy/model_executor/models/adapters.py
Normal file
214
fastdeploy/model_executor/models/adapters.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
from fastdeploy.config import ModelConfig
|
||||
from fastdeploy.model_executor.layers.activation import get_act_fn
|
||||
from fastdeploy.model_executor.models.interfaces_base import is_pooling_model
|
||||
from fastdeploy.transformer_utils.config import get_hf_file_to_dict
|
||||
|
||||
_T = TypeVar("_T", bound=type[nn.Layer])
|
||||
|
||||
_GENERATE_SUFFIXES = [
|
||||
"ForCausalLM",
|
||||
"ForConditionalGeneration",
|
||||
"ChatModel",
|
||||
"LMHeadModel",
|
||||
]
|
||||
|
||||
|
||||
def _load_dense_weights(linear: nn.Linear, folder: str, model_config: "ModelConfig") -> bool:
|
||||
"""Load weights using vLLM's weight_loader pattern."""
|
||||
|
||||
from fastdeploy.model_executor.utils import default_weight_loader
|
||||
|
||||
filename = "model.safetensors"
|
||||
file_path = f"{folder}/{filename}" if folder else filename
|
||||
|
||||
try:
|
||||
file_bytes = get_hf_file_to_dict(file_path, model_config.model, model_config.revision)
|
||||
if not file_bytes:
|
||||
return False
|
||||
|
||||
state_dict = {}
|
||||
if filename.endswith(".safetensors"):
|
||||
import io
|
||||
|
||||
from safetensors.numpy import load as load_safetensors
|
||||
|
||||
numpy_tensors = load_safetensors(io.BytesIO(file_bytes))
|
||||
for key, numpy_array in numpy_tensors.items():
|
||||
state_dict[key] = paddle.to_tensor(numpy_array)
|
||||
else:
|
||||
import io
|
||||
|
||||
state_dict = paddle.load(io.BytesIO(file_bytes))
|
||||
|
||||
weight_keys = ["weight", "linear.weight", "dense.weight"]
|
||||
|
||||
for weight_key in weight_keys:
|
||||
if weight_key in state_dict:
|
||||
weight_loader = getattr(linear.weight, "weight_loader", default_weight_loader)
|
||||
weight_loader(linear.weight, state_dict[weight_key].astype(paddle.float32))
|
||||
bias_key = weight_key.replace("weight", "bias")
|
||||
if linear.bias is not None and bias_key in state_dict:
|
||||
bias_loader = getattr(linear.bias, "weight_loader", default_weight_loader)
|
||||
bias_loader(linear.bias, state_dict[bias_key].astype(paddle.float32))
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to load :{e}")
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Layer]:
|
||||
try:
|
||||
modules = get_hf_file_to_dict("modules.json", model_config.model, model_config.revision)
|
||||
if not modules:
|
||||
return None
|
||||
|
||||
if isinstance(modules, dict):
|
||||
modules = modules.get("modules", [])
|
||||
|
||||
dense_modules = [m for m in modules if m.get("type") == "sentence_transformers.models.Dense"]
|
||||
if not dense_modules:
|
||||
return None
|
||||
|
||||
layers = []
|
||||
for module in dense_modules:
|
||||
folder = module.get("path", "")
|
||||
config_path = f"{folder}/config.json" if folder else "config.json"
|
||||
layer_config = get_hf_file_to_dict(config_path, model_config.model, model_config.revision)
|
||||
if not layer_config:
|
||||
continue
|
||||
linear = nn.Linear(
|
||||
layer_config.get("in_features", 768),
|
||||
layer_config.get("out_features", 768),
|
||||
bias=layer_config.get("bias", True),
|
||||
)
|
||||
linear = linear.astype(paddle.float32)
|
||||
|
||||
if not _load_dense_weights(linear, folder, model_config):
|
||||
continue
|
||||
|
||||
layers.append(linear)
|
||||
if act_name := layer_config.get("activation_function"):
|
||||
layers.append(get_act_fn(act_name))
|
||||
return nn.Sequential(*layers).astype(paddle.float32)
|
||||
except Exception as e:
|
||||
print(f"ST projector loading failed:{e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _create_pooling_model_cls(orig_cls: _T) -> _T:
|
||||
|
||||
class ModelForPooling(orig_cls):
|
||||
|
||||
def __init__(self, fd_config, *args, **kwargs):
|
||||
super().__init__(fd_config, *args, **kwargs)
|
||||
self.fd_config = fd_config
|
||||
self.is_pooling_model = True
|
||||
|
||||
# These are not used in pooling models
|
||||
for attr in ("lm_head", "logits_processor"):
|
||||
if hasattr(self, attr):
|
||||
delattr(self, attr)
|
||||
|
||||
# If the model already defines a pooler instance, don't overwrite it
|
||||
if not getattr(self, "pooler", None):
|
||||
self._init_pooler(fd_config)
|
||||
|
||||
def _init_pooler(self, fd_config):
|
||||
raise NotImplementedError
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, paddle.Tensor]]):
|
||||
# TODO: Support uninitialized params tracking
|
||||
|
||||
# We have deleted this attribute, so don't load it
|
||||
weights = ((name, data) for name, data in weights if not name.startswith("lm_head."))
|
||||
|
||||
# If `*ForCausalLM` defines `load_weights` on the inner model
|
||||
# and there are no other inner modules with parameters,
|
||||
# we support loading from both `*Model` and `*ForCausalLM`
|
||||
|
||||
if hasattr(self, "model") and hasattr(self.model, "load_weights"):
|
||||
# Whether only `self.model` contains parameters
|
||||
model_is_only_param = all(
|
||||
name == "model" or not any(child.parameters()) for name, child in self.named_children()
|
||||
)
|
||||
if model_is_only_param:
|
||||
weights = ((name[6:], data) for name, data in weights if name.startswith("model."))
|
||||
loaded_params = self.model.load_weights(weights)
|
||||
loaded_params = {f"model.{name}" for name in loaded_params}
|
||||
return loaded_params
|
||||
|
||||
# For most other models
|
||||
if hasattr(orig_cls, "load_weights"):
|
||||
return orig_cls.load_weights(self, weights) # type: ignore
|
||||
# Fallback
|
||||
else:
|
||||
raise ValueError("No load_weights method found in the model.")
|
||||
|
||||
return ModelForPooling
|
||||
|
||||
|
||||
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
|
||||
model_name = orig_model_name
|
||||
|
||||
for generate_suffix in _GENERATE_SUFFIXES:
|
||||
model_name = model_name.removesuffix(generate_suffix)
|
||||
return model_name + pooling_suffix
|
||||
|
||||
|
||||
def as_embedding_model(cls: _T) -> _T:
|
||||
"""
|
||||
Subclass an existing vLLM model to support embeddings.
|
||||
|
||||
By default, the embeddings of the whole prompt are extracted from the
|
||||
normalized hidden state corresponding to the last token.
|
||||
|
||||
Note:
|
||||
We assume that no extra layers are added to the original model;
|
||||
please implement your own model if this is not the case.
|
||||
"""
|
||||
# Avoid modifying existing embedding models
|
||||
if is_pooling_model(cls):
|
||||
return cls
|
||||
|
||||
from fastdeploy.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
|
||||
class ModelForEmbedding(_create_pooling_model_cls(cls)):
|
||||
|
||||
def _init_pooler(self, fd_config, prefix: str = ""):
|
||||
pooler_config = fd_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config, fd_config.model_config),
|
||||
"embed": Pooler.for_embed(pooler_config, fd_config.model_config),
|
||||
},
|
||||
)
|
||||
|
||||
ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")
|
||||
|
||||
return ModelForEmbedding
|
@@ -48,7 +48,11 @@ from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding,
|
||||
)
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
from fastdeploy.model_executor.models.model_base import (
|
||||
ModelCategory,
|
||||
ModelForCasualLM,
|
||||
ModelRegistry,
|
||||
)
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
@@ -588,6 +592,12 @@ class DeepSeekV3Model(nn.Layer):
|
||||
return out
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
architecture="DeepseekV3ForCausalLM",
|
||||
module_path="deepseek_v3",
|
||||
category=ModelCategory.TEXT_GENERATION,
|
||||
primary_use=ModelCategory.TEXT_GENERATION,
|
||||
)
|
||||
class DeepseekV3ForCausalLM(ModelForCasualLM):
|
||||
"""
|
||||
DeepseekV3ForCausalLM
|
||||
|
@@ -45,7 +45,11 @@ from fastdeploy.model_executor.layers.linear import (
|
||||
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
from fastdeploy.model_executor.models.model_base import (
|
||||
ModelCategory,
|
||||
ModelForCasualLM,
|
||||
ModelRegistry,
|
||||
)
|
||||
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
|
||||
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
|
||||
from fastdeploy.model_executor.models.utils import WeightMeta
|
||||
@@ -471,6 +475,12 @@ class Ernie4_5_Model(nn.Layer):
|
||||
return out
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
architecture="Ernie4_5_MoeForCausalLM",
|
||||
module_path="ernie4_5_moe",
|
||||
category=ModelCategory.TEXT_GENERATION,
|
||||
primary_use=ModelCategory.TEXT_GENERATION,
|
||||
)
|
||||
class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
|
||||
"""
|
||||
Ernie4_5_MoeForCausalLM
|
||||
@@ -646,6 +656,12 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
|
||||
self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
architecture="Ernie4_5_ForCausalLM",
|
||||
module_path="ernie4_5_moe",
|
||||
category=ModelCategory.TEXT_GENERATION,
|
||||
primary_use=ModelCategory.TEXT_GENERATION,
|
||||
)
|
||||
class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
|
||||
"""
|
||||
Ernie4_5_ForCausalLM
|
||||
@@ -659,6 +675,12 @@ class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
|
||||
return "Ernie4_5_ForCausalLM"
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
architecture="Ernie4_5ForCausalLM",
|
||||
module_path="ernie4_5_moe",
|
||||
category=ModelCategory.TEXT_GENERATION,
|
||||
primary_use=ModelCategory.TEXT_GENERATION,
|
||||
)
|
||||
class Ernie4_5ForCausalLM(Ernie4_5_ForCausalLM):
|
||||
"""
|
||||
Ernie4_5ForCausalLM 0.3B-PT
|
||||
|
@@ -31,7 +31,11 @@ from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
from fastdeploy.model_executor.models.model_base import (
|
||||
ModelCategory,
|
||||
ModelForCasualLM,
|
||||
ModelRegistry,
|
||||
)
|
||||
|
||||
|
||||
class Ernie4_5_MTPPretrainedModel(PretrainedModel):
|
||||
@@ -325,6 +329,12 @@ class Ernie4_5_MTPModel(nn.Layer):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
architecture="Ernie4_5_MTPForCausalLM",
|
||||
module_path="ernie4_5_mtp",
|
||||
category=ModelCategory.TEXT_GENERATION,
|
||||
primary_use=ModelCategory.TEXT_GENERATION,
|
||||
)
|
||||
class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
|
||||
"""
|
||||
Ernie4_5_MTPForCausalLM
|
||||
|
@@ -44,7 +44,11 @@ from fastdeploy.model_executor.models.ernie4_5_moe import (
|
||||
Ernie4_5_Attention,
|
||||
Ernie4_5_MLP,
|
||||
)
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
from fastdeploy.model_executor.models.model_base import (
|
||||
ModelCategory,
|
||||
ModelForCasualLM,
|
||||
ModelRegistry,
|
||||
)
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
@@ -792,6 +796,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
|
||||
self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
architecture="Ernie4_5_VLMoeForConditionalGeneration",
|
||||
module_path="ernie4_5_vl.ernie4_5_vl_moe",
|
||||
category=ModelCategory.MULTIMODAL,
|
||||
primary_use=ModelCategory.MULTIMODAL,
|
||||
)
|
||||
class Ernie4_5_VLPretrainedModel(PretrainedModel):
|
||||
"""
|
||||
Ernie4_5_MoePretrainedModel
|
||||
|
@@ -39,7 +39,11 @@ from fastdeploy.model_executor.layers.linear import (
|
||||
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
from fastdeploy.model_executor.models.model_base import (
|
||||
ModelCategory,
|
||||
ModelForCasualLM,
|
||||
ModelRegistry,
|
||||
)
|
||||
|
||||
|
||||
class Glm4MoeMLP(nn.Layer):
|
||||
@@ -363,6 +367,12 @@ class Glm4MoeModel(nn.Layer):
|
||||
return out
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
architecture="Glm4MoeForCausalLM",
|
||||
module_path="glm4_moe",
|
||||
category=ModelCategory.TEXT_GENERATION,
|
||||
primary_use=ModelCategory.TEXT_GENERATION,
|
||||
)
|
||||
class Glm4MoeForCausalLM(ModelForCasualLM):
|
||||
"""
|
||||
Glm4MoeForCausalLM
|
||||
|
54
fastdeploy/model_executor/models/interfaces_base.py
Normal file
54
fastdeploy/model_executor/models/interfaces_base.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Type
|
||||
|
||||
from paddle import nn
|
||||
|
||||
|
||||
def is_text_generation_model(model_cls: Type[nn.Layer]) -> bool:
|
||||
from .model_base import ModelForCasualLM
|
||||
|
||||
return issubclass(model_cls, ModelForCasualLM)
|
||||
|
||||
|
||||
def is_pooling_model(model_cls: Type[nn.Layer]) -> bool:
|
||||
class_name = model_cls.__name__
|
||||
pooling_indicators = ["Embedding", "ForSequenceClassification"]
|
||||
return (
|
||||
any(indicator in class_name for indicator in pooling_indicators)
|
||||
or hasattr(model_cls, "is_embedding_model")
|
||||
and model_cls.is_embedding_model
|
||||
)
|
||||
|
||||
|
||||
def is_multimodal_model(class_name: str) -> bool:
|
||||
multimodal_indicators = ["VL", "Vision", "ConditionalGeneration"]
|
||||
return any(indicator in class_name for indicator in multimodal_indicators)
|
||||
|
||||
|
||||
def determine_model_category(class_name: str):
|
||||
from fastdeploy.model_executor.models.model_base import ModelCategory
|
||||
|
||||
if any(pattern in class_name for pattern in ["VL", "Vision", "ConditionalGeneration"]):
|
||||
return ModelCategory.MULTIMODAL
|
||||
elif any(pattern in class_name for pattern in ["Embedding", "ForSequenceClassification"]):
|
||||
return ModelCategory.EMBEDDING
|
||||
return ModelCategory.TEXT_GENERATION
|
||||
|
||||
|
||||
def get_default_pooling_type(model_cls: Type[nn.Layer] = None) -> str:
|
||||
if model_cls is not None:
|
||||
return getattr(model_cls, "default_pooling_type", "LAST")
|
||||
return "LAST"
|
@@ -3,40 +3,269 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -12,31 +11,265 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Union
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
|
||||
from fastdeploy.config import (
|
||||
ModelConfig,
|
||||
iter_architecture_defaults,
|
||||
try_match_architecture_defaults,
|
||||
)
|
||||
from fastdeploy.model_executor.models.interfaces_base import (
|
||||
determine_model_category,
|
||||
get_default_pooling_type,
|
||||
is_multimodal_model,
|
||||
is_pooling_model,
|
||||
is_text_generation_model,
|
||||
)
|
||||
|
||||
|
||||
class ModelCategory(Enum):
|
||||
TEXT_GENERATION = "text_generation"
|
||||
MULTIMODAL = "multimodal"
|
||||
EMBEDDING = "embedding"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelInfo:
|
||||
architecture: str
|
||||
category: ModelCategory
|
||||
is_text_generation: bool
|
||||
is_multimodal: bool
|
||||
is_pooling: bool
|
||||
module_path: str
|
||||
default_pooling_type: str
|
||||
|
||||
@staticmethod
|
||||
def from_model_cls(model_cls: Type[nn.Layer], module_path: str = "") -> "ModelInfo":
|
||||
return ModelInfo(
|
||||
architecture=model_cls.__name__,
|
||||
category=determine_model_category(model_cls.__name__),
|
||||
is_text_generation=is_text_generation_model(model_cls),
|
||||
is_multimodal=is_multimodal_model(model_cls.__name__),
|
||||
is_pooling=is_pooling_model(model_cls),
|
||||
default_pooling_type=get_default_pooling_type(model_cls),
|
||||
module_path=module_path,
|
||||
)
|
||||
|
||||
|
||||
class BaseRegisteredModel(ABC):
|
||||
"""Base class for registered models"""
|
||||
|
||||
@abstractmethod
|
||||
def load_model_cls(self) -> Type[nn.Layer]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def inspect_model_cls(self) -> ModelInfo:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LazyRegisteredModel(BaseRegisteredModel):
|
||||
"""Lazy loaded model"""
|
||||
|
||||
module_name: str
|
||||
class_name: str
|
||||
|
||||
def load_model_cls(self) -> Type[nn.Layer]:
|
||||
try:
|
||||
full_module = f"fastdeploy.model_executor.models.{self.module_name}"
|
||||
module = importlib.import_module(full_module)
|
||||
return getattr(module, self.class_name)
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ImportError(f"Failed to load {self.class_name}: {e}")
|
||||
|
||||
def inspect_model_cls(self) -> ModelInfo:
|
||||
model_cls = self.load_model_cls()
|
||||
return ModelInfo.from_model_cls(model_cls, self.module_name)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegisteredModel(BaseRegisteredModel):
|
||||
|
||||
model_cls: Type[nn.Layer]
|
||||
|
||||
def load_model_cls(self) -> Type[nn.Layer]:
|
||||
return self.model_cls
|
||||
|
||||
def inspect_model_cls(self) -> ModelInfo:
|
||||
return ModelInfo.from_model_cls(self.model_cls)
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _try_inspect_model_cls(
|
||||
model_arch: str,
|
||||
model: BaseRegisteredModel,
|
||||
) -> Optional[ModelInfo]:
|
||||
try:
|
||||
return model.inspect_model_cls()
|
||||
except Exception:
|
||||
print("Error in inspecting model architecture '%s'", model_arch)
|
||||
return None
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
"""
|
||||
Used to register and retrieve model classes.
|
||||
"""
|
||||
|
||||
_arch_to_model_cls = {}
|
||||
_arch_to_pretrained_model_cls = {}
|
||||
_enhanced_models: Dict[str, Dict] = {}
|
||||
|
||||
def __init__(self):
|
||||
self.models: Dict[str, BaseRegisteredModel] = {}
|
||||
self.pretrained_models: Dict[str, Type[PretrainedModel]] = {}
|
||||
self._registered_models: Dict[str, BaseRegisteredModel] = {}
|
||||
self._register_enhanced_models()
|
||||
|
||||
def _register_enhanced_models(self):
|
||||
for arch, model_info in self._enhanced_models.items():
|
||||
model = LazyRegisteredModel(module_name=model_info["module_path"], class_name=model_info["class_name"])
|
||||
self.models[arch] = model
|
||||
self._registered_models[arch] = model
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _try_load_model_cls(self, architecture: str) -> Optional[Type[nn.Layer]]:
|
||||
if architecture not in self.models:
|
||||
return None
|
||||
try:
|
||||
return self.models[architecture].load_model_cls()
|
||||
except Exception as e:
|
||||
print(f"Failed to load model {architecture}: {e}")
|
||||
return None
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _try_inspect_model_cls(self, model_arch: str) -> Optional[ModelInfo]:
|
||||
if model_arch not in self.models:
|
||||
return None
|
||||
try:
|
||||
return self.models[model_arch].inspect_model_cls()
|
||||
except Exception as e:
|
||||
print(f"Failed to inspect model {model_arch}: {e}")
|
||||
return None
|
||||
|
||||
def _normalize_arch(self, architecture: str, model_config: ModelConfig) -> str:
|
||||
if architecture in self.models:
|
||||
return architecture
|
||||
|
||||
match = try_match_architecture_defaults(
|
||||
architecture,
|
||||
runner_type=getattr(model_config, "runner_type", None),
|
||||
convert_type=getattr(model_config, "convert_type", None),
|
||||
)
|
||||
if match:
|
||||
suffix, _ = match
|
||||
for repl_suffix, _ in iter_architecture_defaults():
|
||||
base_arch = architecture.replace(suffix, repl_suffix)
|
||||
if base_arch in self.models:
|
||||
return base_arch
|
||||
|
||||
return architecture
|
||||
|
||||
def _raise_for_unsupported(self, architectures: list[str]):
|
||||
all_supported_archs = self.get_supported_archs()
|
||||
|
||||
if any(arch in all_supported_archs for arch in architectures):
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} failed to be inspected. "
|
||||
"Please check the logs for more details."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {all_supported_archs}"
|
||||
)
|
||||
|
||||
def inspect_model_cls(
|
||||
self, architectures: Union[str, List[str]], model_config: ModelConfig = None
|
||||
) -> Tuple[ModelInfo, str]:
|
||||
if isinstance(architectures, str):
|
||||
architectures = [architectures]
|
||||
|
||||
if not architectures:
|
||||
raise ValueError("No model architectures are specified")
|
||||
|
||||
for arch in architectures:
|
||||
normalized_arch = self._normalize_arch(arch, model_config)
|
||||
model_info = self._try_inspect_model_cls(normalized_arch)
|
||||
if model_info is not None:
|
||||
return (model_info, arch)
|
||||
|
||||
return self._raise_for_unsupported(architectures)
|
||||
|
||||
@classmethod
|
||||
def register_model_class(cls, model_class):
|
||||
"""register model class"""
|
||||
if issubclass(model_class, ModelForCasualLM) and model_class is not ModelForCasualLM:
|
||||
cls._arch_to_model_cls[model_class.name()] = model_class
|
||||
return model_class
|
||||
def register_model_class(
|
||||
cls,
|
||||
model_class=None,
|
||||
*,
|
||||
architecture: str = None,
|
||||
module_path: str = None,
|
||||
category: Union[ModelCategory, List[ModelCategory]] = ModelCategory.TEXT_GENERATION,
|
||||
primary_use: ModelCategory = None,
|
||||
):
|
||||
"""
|
||||
Enhanced model class registration supporting both traditional and decorator-style registration.
|
||||
|
||||
Can be used as:
|
||||
1. Traditional decorator: @ModelRegistry.register_model_class
|
||||
2. Enhanced decorator with metadata: @ModelRegistry.register_model_class(architecture="...", module_path="...")
|
||||
|
||||
Args:
|
||||
model_class: The model class (when used as simple decorator)
|
||||
architecture (str): Unique identifier for the model architecture
|
||||
module_path (str): Relative path to the module containing the model
|
||||
category: Model category or list of categories
|
||||
primary_use: Primary category for multi-category models
|
||||
"""
|
||||
|
||||
def _register(model_cls):
|
||||
# Traditional registration for ModelForCasualLM subclasses
|
||||
if issubclass(model_cls, ModelForCasualLM) and model_cls is not ModelForCasualLM:
|
||||
cls._arch_to_model_cls[model_cls.name()] = model_cls
|
||||
|
||||
# Enhanced decorator-style registration
|
||||
if architecture and module_path:
|
||||
categories = category if isinstance(category, list) else [category]
|
||||
|
||||
# Register main entry
|
||||
arch_key = architecture
|
||||
cls._enhanced_models[arch_key] = {
|
||||
"class_name": model_cls.__name__,
|
||||
"module_path": module_path,
|
||||
"category": primary_use or categories[0],
|
||||
"class": model_cls,
|
||||
}
|
||||
|
||||
# Register category-specific entries for multi-category models
|
||||
if len(categories) > 1:
|
||||
for cat in categories:
|
||||
key = f"{arch_key}_{cat.value}"
|
||||
cls._enhanced_models[key] = {
|
||||
"class_name": model_cls.__name__,
|
||||
"module_path": module_path,
|
||||
"category": cat,
|
||||
"primary_use": primary_use or categories[0],
|
||||
"class": model_cls,
|
||||
}
|
||||
return model_cls
|
||||
|
||||
if model_class is not None:
|
||||
return _register(model_class)
|
||||
else:
|
||||
return _register
|
||||
|
||||
@classmethod
|
||||
def register_pretrained_model(cls, pretrained_model):
|
||||
@@ -50,11 +279,6 @@ class ModelRegistry:
|
||||
|
||||
return pretrained_model
|
||||
|
||||
@classmethod
|
||||
def get_pretrain_cls(cls, architectures: str):
|
||||
"""get_pretrain_cls"""
|
||||
return cls._arch_to_pretrained_model_cls[architectures]
|
||||
|
||||
@classmethod
|
||||
def get_class(cls, name):
|
||||
"""get model class"""
|
||||
@@ -62,12 +286,61 @@ class ModelRegistry:
|
||||
raise ValueError(f"Model '{name}' is not registered!")
|
||||
return cls._arch_to_model_cls[name]
|
||||
|
||||
@classmethod
|
||||
def get_pretrain_cls(cls, architectures: str):
|
||||
"""get_pretrain_cls"""
|
||||
return cls._arch_to_pretrained_model_cls[architectures]
|
||||
|
||||
@classmethod
|
||||
def get_supported_archs(cls):
|
||||
assert len(cls._arch_to_model_cls) >= len(
|
||||
cls._arch_to_pretrained_model_cls
|
||||
), "model class num is more than pretrained model registry num"
|
||||
return [key for key in cls._arch_to_model_cls.keys()]
|
||||
traditional_archs = list(cls._arch_to_model_cls.keys())
|
||||
enhanced_archs = list(cls._enhanced_models.keys())
|
||||
return traditional_archs + enhanced_archs
|
||||
|
||||
def resolve_model_cls(self, architectures: Union[str, List[str]]) -> Tuple[Type[nn.Layer], str]:
|
||||
"""Resolve model class"""
|
||||
if isinstance(architectures, str):
|
||||
architectures = [architectures]
|
||||
|
||||
for arch in architectures:
|
||||
model_cls = self._try_load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return model_cls, arch
|
||||
|
||||
raise ValueError(f"Cannot find supported model: {architectures}")
|
||||
|
||||
def is_multimodal_model(self, architectures: Union[str, List[str]], model_config: ModelConfig = None) -> bool:
|
||||
"""Check if it's a multimodal model"""
|
||||
if isinstance(architectures, str):
|
||||
architectures = [architectures]
|
||||
|
||||
for arch in architectures:
|
||||
model_info = self._try_inspect_model_cls(arch)
|
||||
if model_info is not None:
|
||||
return model_info.is_multimodal
|
||||
return False
|
||||
|
||||
def is_text_generation_model(self, architectures: Union[str, List[str]], model_config: ModelConfig = None) -> bool:
|
||||
"""Check if it's a text generation model"""
|
||||
if isinstance(architectures, str):
|
||||
architectures = [architectures]
|
||||
|
||||
for arch in architectures:
|
||||
model_info = self._try_inspect_model_cls(arch)
|
||||
if model_info is not None:
|
||||
return model_info.is_text_generation
|
||||
return False
|
||||
|
||||
def is_pooling_model(self, architectures: Union[str, List[str]], model_config: ModelConfig = None) -> bool:
|
||||
"""Check if it's a pooling model"""
|
||||
if isinstance(architectures, str):
|
||||
architectures = [architectures]
|
||||
|
||||
for arch in architectures:
|
||||
model_info = self._try_inspect_model_cls(arch)
|
||||
if model_info is not None:
|
||||
return model_info.is_pooling
|
||||
return False
|
||||
|
||||
|
||||
class ModelForCasualLM(nn.Layer, ABC):
|
||||
@@ -88,7 +361,6 @@ class ModelForCasualLM(nn.Layer, ABC):
|
||||
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
|
||||
"""
|
||||
Load model parameters from a given state dictionary.
|
||||
|
||||
Args:
|
||||
state_dict (dict[str, np.ndarray | paddle.Tensor]):
|
||||
A dictionary containing model parameters, where keys are parameter names
|
||||
@@ -105,12 +377,10 @@ class ModelForCasualLM(nn.Layer, ABC):
|
||||
):
|
||||
"""
|
||||
Defines the forward pass of the model for generating text.
|
||||
|
||||
Args:
|
||||
input_ids (Tensor, optional): The input token ids to the model.
|
||||
pos_emb (Tensor, optional): position Embeddings for model.
|
||||
**model_kwargs: Additional keyword arguments for the model.
|
||||
|
||||
Returns:
|
||||
Tensor or list of Tensors: Generated tokens or decoded outputs.
|
||||
"""
|
||||
|
@@ -39,7 +39,11 @@ from fastdeploy.model_executor.layers.linear import (
|
||||
)
|
||||
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
from fastdeploy.model_executor.models.model_base import (
|
||||
ModelCategory,
|
||||
ModelForCasualLM,
|
||||
ModelRegistry,
|
||||
)
|
||||
|
||||
|
||||
class Qwen2MLP(nn.Layer):
|
||||
@@ -282,6 +286,12 @@ class Qwen2Model(nn.Layer):
|
||||
return out
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
architecture="Qwen2ForCausalLM",
|
||||
module_path="qwen2",
|
||||
category=[ModelCategory.TEXT_GENERATION, ModelCategory.EMBEDDING],
|
||||
primary_use=ModelCategory.TEXT_GENERATION,
|
||||
)
|
||||
class Qwen2ForCausalLM(ModelForCasualLM):
|
||||
"""
|
||||
Qwen2ForCausalLM
|
||||
|
@@ -33,7 +33,11 @@ from fastdeploy.model_executor.graph_optimization.decorator import (
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
from fastdeploy.model_executor.models.model_base import (
|
||||
ModelCategory,
|
||||
ModelForCasualLM,
|
||||
ModelRegistry,
|
||||
)
|
||||
from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
@@ -157,6 +161,12 @@ class Qwen2_5_VLModel(nn.Layer):
|
||||
return out
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
architecture="Qwen2_5_VLForConditionalGeneration",
|
||||
module_path="qwen2_5_vl.qwen2_5_vl",
|
||||
category=ModelCategory.MULTIMODAL,
|
||||
primary_use=ModelCategory.MULTIMODAL,
|
||||
)
|
||||
class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
|
||||
"""
|
||||
Qwen2_5_VLForConditionalGeneration
|
||||
|
@@ -34,8 +34,13 @@ from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
|
||||
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
from fastdeploy.model_executor.models.model_base import (
|
||||
ModelCategory,
|
||||
ModelForCasualLM,
|
||||
ModelRegistry,
|
||||
)
|
||||
from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2MLP
|
||||
from fastdeploy.transformer_utils.config import get_pooling_config
|
||||
|
||||
|
||||
class Qwen3MLP(Qwen2MLP):
|
||||
@@ -218,6 +223,12 @@ class Qwen3Model(nn.Layer):
|
||||
return out
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
architecture="Qwen3ForCausalLM",
|
||||
module_path="qwen3",
|
||||
category=[ModelCategory.TEXT_GENERATION],
|
||||
primary_use=ModelCategory.TEXT_GENERATION,
|
||||
)
|
||||
class Qwen3ForCausalLM(ModelForCasualLM):
|
||||
"""
|
||||
Qwen3ForCausalLM
|
||||
@@ -260,6 +271,8 @@ class Qwen3ForCausalLM(ModelForCasualLM):
|
||||
process_weights_after_loading,
|
||||
)
|
||||
|
||||
is_pooling_model = hasattr(self, "is_pooling_model") and self.is_pooling_model
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@@ -270,8 +283,18 @@ class Qwen3ForCausalLM(ModelForCasualLM):
|
||||
("embed_tokens.embeddings", "embed_tokens", None),
|
||||
("lm_head.linear", "lm_head", None),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
model_path = self.fd_config.model_config.model
|
||||
revision = self.fd_config.model_config.revision
|
||||
if is_pooling_model and get_pooling_config(model_path, revision):
|
||||
params_dict = {
|
||||
param_name[6:] if param_name.startswith("model.") else param_name: param
|
||||
for param_name, param in params_dict.items()
|
||||
}
|
||||
|
||||
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
|
||||
|
||||
for loaded_weight_name, loaded_weight in weights_iterator:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in loaded_weight_name:
|
||||
@@ -282,6 +305,7 @@ class Qwen3ForCausalLM(ModelForCasualLM):
|
||||
param = params_dict[model_param_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
|
||||
break
|
||||
else:
|
||||
model_param_name = loaded_weight_name
|
||||
@@ -290,10 +314,11 @@ class Qwen3ForCausalLM(ModelForCasualLM):
|
||||
param = params_dict[model_param_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
|
||||
process_weights_after_loading_fn(model_sublayer_name, param)
|
||||
|
||||
if self.tie_word_embeddings:
|
||||
if self.tie_word_embeddings and not is_pooling_model:
|
||||
self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})
|
||||
|
||||
@paddle.no_grad()
|
||||
|
@@ -39,7 +39,11 @@ from fastdeploy.model_executor.layers.linear import (
|
||||
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
from fastdeploy.model_executor.models.model_base import (
|
||||
ModelCategory,
|
||||
ModelForCasualLM,
|
||||
ModelRegistry,
|
||||
)
|
||||
from fastdeploy.model_executor.models.qwen3 import Qwen3Attention
|
||||
|
||||
|
||||
@@ -316,6 +320,12 @@ class Qwen3MoeModel(nn.Layer):
|
||||
return out
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
architecture="Qwen3MoeForCausalLM",
|
||||
module_path="qwen3moe",
|
||||
category=ModelCategory.TEXT_GENERATION,
|
||||
primary_use=ModelCategory.TEXT_GENERATION,
|
||||
)
|
||||
class Qwen3MoeForCausalLM(ModelForCasualLM):
|
||||
"""
|
||||
Qwen3MoeForCausalLM
|
||||
|
Reference in New Issue
Block a user