【Feature】add fd plugins && rm model_classes (#3123)

* add fd plugins && rm model_classes

* fix reviews

* add docs

* fix

* fix unittest CI
gaoziyuan
2025-08-04 10:53:20 +08:00
committed by GitHub
parent 1582814905
commit 4021d66ea5
25 changed files with 524 additions and 59 deletions

View File

@@ -103,6 +103,13 @@ jobs:
python -m pip install coverage
python -m pip install diff-cover
python -m pip install ${fd_wheel_url}
+if [ -d "test/plugins" ]; then
+    cd test/plugins
+    python setup.py install
+    cd ../..
+else
+    echo "Warning: test/plugins directory not found, skipping setup.py install"
+fi
export COVERAGE_FILE=/workspace/FastDeploy/coveragedata/.coverage
export COVERAGE_RCFILE=/workspace/FastDeploy/scripts/.coveragerc
TEST_EXIT_CODE=0

docs/features/plugins.md Normal file
View File

@@ -0,0 +1,85 @@
# FastDeploy Plugin Mechanism Documentation
FastDeploy supports a plugin mechanism that allows users to extend functionality without modifying the core code. Plugins are automatically discovered and loaded through Python's `entry_points` mechanism.
## How Plugins Work
Plugins are essentially registration functions that are automatically called when FastDeploy starts. The system uses the `load_plugins_by_group` function to ensure that all processes (including child processes in distributed training scenarios) have loaded the required plugins before official operations begin.
## Plugin Discovery Mechanism
FastDeploy uses Python's `entry_points` mechanism to discover and load plugins. Developers need to register their plugins in the specified entry point group in their project.
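Concretely, loading a plugin group boils down to discovering entry points and invoking whatever each one resolves to. A minimal sketch of that flow using only the standard `importlib.metadata` API (Python 3.10+); the production logic, with logging and `FD_PLUGINS` filtering, lives in `fastdeploy/plugins/utils.py`:
```python
from importlib.metadata import entry_points

# Find every plugin registered under the group and run its hook.
for ep in entry_points(group="fastdeploy.model_register_plugins"):
    register_func = ep.load()  # resolves the "module_name:function_name" target
    register_func()            # executes the plugin's registration logic
```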
### Example: Creating a Plugin
#### 1. Write the Plugin Logic
Assume you have a custom model class `MyModelForCasualLM` and a pretrained class `MyPretrainedModel`; you can then write the following registration function:
```python
# File: fd_add_dummy_model/__init__.py or fd_add_dummy_model/register.py
from fastdeploy import ModelRegistry

from my_custom_model import MyModelForCasualLM, MyPretrainedModel


def register():
    if "MyModelForCasualLM" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model_class(MyModelForCasualLM)
        ModelRegistry.register_pretrained_model(MyPretrainedModel)
```
#### 2. Register Plugin in `setup.py`
```python
# setup.py
from setuptools import setup

setup(
    name="fastdeploy-plugins",
    version="0.1",
    packages=["fd_add_dummy_model"],
    entry_points={
        "fastdeploy.model_register_plugins": [
            "fd_add_dummy_model = fd_add_dummy_model:register",
        ],
    },
)
```
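After installing the plugin package (e.g., `pip install .` in the plugin project), you can verify the entry point is discoverable before launching FastDeploy. A quick check using only the standard library:
```python
from importlib.metadata import entry_points

eps = entry_points(group="fastdeploy.model_register_plugins")
print([(ep.name, ep.value) for ep in eps])
# Should include: ('fd_add_dummy_model', 'fd_add_dummy_model:register')
```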
## Plugin Structure
Plugins consist of three components:

| Component | Description |
|-----------|-------------|
| **Plugin Group** | The functional group to which the plugin belongs, for example:<br> - `fastdeploy.model_register_plugins`: for model registration<br> - `fastdeploy.model_runner_plugins`: for model runner registration<br> Users can define custom groups as needed. |
| **Plugin Name** | The unique identifier of each plugin (e.g., `fd_add_dummy_model`); the `FD_PLUGINS` environment variable matches against this name to decide whether the plugin is loaded. |
| **Plugin Value** | Formatted as `module_name:function_name`, pointing to the entry function that performs the registration logic. |
## Controlling Plugin Loading Behavior
By default, FastDeploy loads all registered plugins. To load only specific plugins, you can set the environment variable:
```bash
export FD_PLUGINS=fd_add_dummy_model
```
Multiple plugin names can be separated by commas:
```bash
export FD_PLUGINS=plugin_a,plugin_b
```
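The names are matched against entry-point names (the plugin name column above), not pip package names. A sketch of the filtering rule, condensed from the `envs.py` and `plugins/utils.py` changes in this commit:
```python
import os

# FD_PLUGINS unset -> None -> no filter; otherwise a list of allowed names.
allowed = os.environ["FD_PLUGINS"].split(",") if "FD_PLUGINS" in os.environ else None

def should_load(plugin_name: str) -> bool:
    return allowed is None or plugin_name in allowed
```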
## Reference Example
Please refer to the example plugin implementation in the project directory:
```
./test/plugins/
```
It contains a complete plugin structure and `setup.py` configuration example.
## Summary
Through the plugin mechanism, users can easily add custom models or functional modules to FastDeploy without modifying the core source code. This not only enhances system extensibility but also facilitates third-party developers in extending functionality.
For further plugin development, please refer to the `model_registry` and `plugin_loader` modules in the FastDeploy source code.

View File

@@ -0,0 +1,85 @@
# FastDeploy Plugin Mechanism Documentation
FastDeploy supports a plugin mechanism that allows users to extend functionality without modifying the core code. Plugins are discovered and loaded automatically through Python's `entry_points` mechanism.
## How Plugins Work
Plugins are essentially registration functions that are called automatically when FastDeploy starts. The system uses the `load_plugins_by_group` function to ensure that all processes (including child processes in distributed training scenarios) have loaded the required plugins before execution formally begins.
## Plugin Discovery Mechanism
FastDeploy uses Python's `entry_points` mechanism to discover and load plugins. Developers need to register their plugins under the designated entry point group in their own project.
### Example: Creating a Plugin
#### 1. Write the Plugin Logic
Suppose you have a custom model class `MyModelForCasualLM` and a pretrained class `MyPretrainedModel`; you can write the following registration function:
```python
# File: fd_add_dummy_model/__init__.py
from fastdeploy import ModelRegistry

from my_custom_model import MyModelForCasualLM, MyPretrainedModel


def register():
    if "MyModelForCasualLM" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model_class(MyModelForCasualLM)
        ModelRegistry.register_pretrained_model(MyPretrainedModel)
```
#### 2. Register the Plugin in `setup.py`
```python
# setup.py
from setuptools import setup

setup(
    name="fastdeploy-plugins",
    version="0.1",
    packages=["fd_add_dummy_model"],
    entry_points={
        "fastdeploy.model_register_plugins": [
            "fd_add_dummy_model = fd_add_dummy_model:register",
        ],
    },
)
```
## Plugin Structure
Plugins consist of three components:

| Component | Description |
|-----------|-------------|
| **Plugin Group** | The functional group the plugin belongs to, for example:<br> - `fastdeploy.model_register_plugins`: for registering models<br> - `fastdeploy.model_runner_plugins`: for registering model runners<br> Users can define custom groups as needed. |
| **Plugin Name** | The unique identifier of each plugin (e.g., `fd_add_dummy_model`); the `FD_PLUGINS` environment variable matches against this name to decide whether the plugin is loaded. |
| **Plugin Value** | Formatted as `module_name:function_name`, pointing to the entry function that actually performs the registration logic. |
## Controlling Plugin Loading Behavior
By default, FastDeploy loads all registered plugins. To load only specific plugins, set the environment variable:
```bash
export FD_PLUGINS=fd_add_dummy_model
```
Multiple plugin names can be separated by commas:
```bash
export FD_PLUGINS=plugin_a,plugin_b
```
## Reference Example
See the example plugin implementation in the project directory:
```
./test/plugins/
```
It contains a complete plugin structure and a `setup.py` configuration example.
## Summary
Through the plugin mechanism, users can easily add custom models or functional modules to FastDeploy without modifying the core source code. This not only improves the system's extensibility but also makes it easier for third-party developers to extend functionality.
For further plugin development, refer to the `model_registry` and `plugin_loader` modules in the FastDeploy source code.

View File

@@ -22,11 +22,10 @@ import sys
os.environ["GLOG_minloglevel"] = "2" os.environ["GLOG_minloglevel"] = "2"
# suppress log from aistudio # suppress log from aistudio
os.environ["AISTUDIO_LOG"] = "critical" os.environ["AISTUDIO_LOG"] = "critical"
import typing
from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM from fastdeploy.entrypoints.llm import LLM
from fastdeploy.utils import version
__all__ = ["LLM", "SamplingParams", "version"]
try: try:
import use_triton_in_paddle import use_triton_in_paddle
@@ -86,3 +85,27 @@ def _patch_fastsafetensors():
_patch_fastsafetensors()
+MODULE_ATTRS = {"ModelRegistry": ".model_executor.models.model_base:ModelRegistry", "version": ".utils:version"}
+if typing.TYPE_CHECKING:
+    from fastdeploy.model_executor.models.model_base import ModelRegistry
+else:
+
+    def __getattr__(name: str) -> typing.Any:
+        from importlib import import_module
+
+        if name in MODULE_ATTRS:
+            try:
+                module_name, attr_name = MODULE_ATTRS[name].split(":")
+                module = import_module(module_name, __package__)
+                return getattr(module, attr_name)
+            except ModuleNotFoundError:
+                print(f"Module {MODULE_ATTRS[name]} not found.")
+        else:
+            print(f"module {__package__} has no attribute {name}")
+
+__all__ = ["LLM", "SamplingParams", "ModelRegistry", "version"]
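With this hook, `fastdeploy.ModelRegistry` is resolved through `__getattr__` on first access rather than being imported eagerly at package-import time. A short illustration (sketch):
```python
import fastdeploy

# The attribute lookup below triggers __getattr__, which imports
# fastdeploy.model_executor.models.model_base lazily and returns the class.
ModelRegistry = fastdeploy.ModelRegistry
print(ModelRegistry.get_supported_archs())
```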

View File

@@ -80,6 +80,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"), "EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
# enable kv cache block scheduler v1 (no need for kv_cache_ratio) # enable kv cache block scheduler v1 (no need for kv_cache_ratio)
"ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")), "ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")),
# Whether to use PLUGINS.
"FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","),
} }

View File

@@ -24,7 +24,6 @@ from fastdeploy.model_executor.load_weight_utils import (
    measure_time,
)
from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
-from fastdeploy.model_executor.model_loader.utils import get_pretrain_cls
from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.platforms import current_platform
@@ -52,7 +51,7 @@ class DefaultModelLoader(BaseModelLoader):
    @measure_time
    def load_weights(self, model, fd_config: FDConfig, architectures: str) -> None:
-        model_class = get_pretrain_cls(architectures)
+        model_class = ModelRegistry.get_pretrain_cls(architectures)
        state_dict = load_composite_checkpoint(
            fd_config.model_config.model,
            model_class,

View File

@@ -1,43 +0,0 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from paddleformers.transformers import PretrainedModel

from fastdeploy.model_executor.models.deepseek_v3 import DeepSeekV3PretrainedModel
from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_PretrainedModel
from fastdeploy.model_executor.models.ernie4_5_mtp import Ernie4_5_MTPPretrainedModel
from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import (
    Ernie4_5_VLPretrainedModel,
)
from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel
from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel

MODEL_CLASSES = {
    "Ernie4_5_MoeForCausalLM": Ernie4_5_PretrainedModel,
    "Ernie4_5_MTPForCausalLM": Ernie4_5_MTPPretrainedModel,
    "Qwen2ForCausalLM": Qwen2PretrainedModel,
    "Qwen3ForCausalLM": Qwen3PretrainedModel,
    "Qwen3MoeForCausalLM": Qwen3MoePretrainedModel,
    "Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel,
    "DeepseekV3ForCausalLM": DeepSeekV3PretrainedModel,
    "Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLPretrainedModel,
}


def get_pretrain_cls(architectures: str) -> PretrainedModel:
    """get_pretrain_cls"""
    return MODEL_CLASSES[architectures]

View File

@@ -19,6 +19,8 @@ import inspect
import os
from pathlib import Path

+from paddleformers.transformers import PretrainedModel
+
from .model_base import ModelForCasualLM, ModelRegistry
@@ -44,7 +46,14 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode
        for attr_name in dir(module):
            attr = getattr(module, attr_name)
            if inspect.isclass(attr) and issubclass(attr, ModelForCasualLM) and attr is not ModelForCasualLM:
-                ModelRegistry.register(attr)
+                ModelRegistry.register_model_class(attr)
+            if (
+                inspect.isclass(attr)
+                and issubclass(attr, PretrainedModel)
+                and attr is not PretrainedModel
+                and hasattr(attr, "arch_name")
+            ):
+                ModelRegistry.register_pretrained_model(attr)
    except ImportError:
        raise ImportError(f"{module_file=} import error")

View File

@@ -673,6 +673,10 @@ class DeepSeekV3PretrainedModel(PretrainedModel):
""" """
return None return None
@classmethod
def arch_name(self):
return "DeepseekV3ForCausalLM"
@classmethod @classmethod
def _get_tensor_parallel_mappings(cls, config, is_split=True): def _get_tensor_parallel_mappings(cls, config, is_split=True):

View File

@@ -460,9 +460,9 @@ class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
return "Ernie4_5_ForCausalLM" return "Ernie4_5_ForCausalLM"
class Ernie4_5_PretrainedModel(PretrainedModel): class Ernie4_5_MoePretrainedModel(PretrainedModel):
""" """
Ernie4_5_PretrainedModel Ernie4_5_MoePretrainedModel
""" """
config_class = FDConfig config_class = FDConfig
@@ -473,6 +473,10 @@ class Ernie4_5_PretrainedModel(PretrainedModel):
""" """
return None return None
@classmethod
def arch_name(self):
return "Ernie4_5_MoeForCausalLM"
weight_infos = [ weight_infos = [
WeightMeta( WeightMeta(
f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.weight", f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.weight",
@@ -594,3 +598,16 @@ class Ernie4_5_PretrainedModel(PretrainedModel):
                config.prefix_name,
            )
        return mappings
+
+
+class Ernie4_5_PretrainedModel(Ernie4_5_MoePretrainedModel):
+    """
+    Ernie4_5_PretrainedModel
+    """
+
+    @classmethod
+    def arch_name(cls):
+        """
+        Model Architecture Name
+        """
+        return "Ernie4_5_ForCausalLM"

View File

@@ -46,6 +46,10 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel):
""" """
return None return None
@classmethod
def arch_name(self):
return "Ernie4_5_MTPForCausalLM"
@classmethod @classmethod
def _get_tensor_parallel_mappings(cls, config, is_split=True): def _get_tensor_parallel_mappings(cls, config, is_split=True):
""" """

View File

@@ -605,7 +605,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
class Ernie4_5_VLPretrainedModel(PretrainedModel):
    """
-    Ernie4_5_PretrainedModel
+    Ernie4_5_MoePretrainedModel
    """

    config_class = FDConfig
@@ -616,6 +616,10 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel):
""" """
return None return None
@classmethod
def arch_name(self):
return "Ernie4_5_VLMoeForConditionalGeneration"
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
from fastdeploy.model_executor.models.utils import WeightMeta from fastdeploy.model_executor.models.utils import WeightMeta

View File

@@ -20,6 +20,7 @@ from typing import Dict, Union
import numpy as np
import paddle
from paddle import nn
+from paddleformers.transformers import PretrainedModel


class ModelRegistry:
@@ -27,21 +28,46 @@ class ModelRegistry:
    Used to register and retrieve model classes.
    """

-    _registry = {}
+    _arch_to_model_cls = {}
+    _arch_to_pretrained_model_cls = {}

    @classmethod
-    def register(cls, model_class):
+    def register_model_class(cls, model_class):
        """register model class"""
        if issubclass(model_class, ModelForCasualLM) and model_class is not ModelForCasualLM:
-            cls._registry[model_class.name()] = model_class
+            cls._arch_to_model_cls[model_class.name()] = model_class
        return model_class

+    @classmethod
+    def register_pretrained_model(cls, pretrained_model):
+        """register pretrained model class"""
+        if (
+            issubclass(pretrained_model, PretrainedModel)
+            and pretrained_model is not PretrainedModel
+            and hasattr(pretrained_model, "arch_name")
+        ):
+            cls._arch_to_pretrained_model_cls[pretrained_model.arch_name()] = pretrained_model
+        return pretrained_model
+
+    @classmethod
+    def get_pretrain_cls(cls, architectures: str):
+        """get_pretrain_cls"""
+        return cls._arch_to_pretrained_model_cls[architectures]
+
    @classmethod
    def get_class(cls, name):
        """get model class"""
-        if name not in cls._registry:
+        if name not in cls._arch_to_model_cls:
            raise ValueError(f"Model '{name}' is not registered!")
-        return cls._registry[name]
+        return cls._arch_to_model_cls[name]

+    @classmethod
+    def get_supported_archs(cls):
+        assert len(cls._arch_to_model_cls) == len(
+            cls._arch_to_pretrained_model_cls
+        ), "model class / pretrained model registry num is not same"
+        return list(cls._arch_to_model_cls.keys())


class ModelForCasualLM(nn.Layer, ABC):
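The renamed registry now keeps two parallel maps keyed by architecture name. A minimal usage sketch of the new API, reusing the dummy classes from `test/plugins` further below:
```python
from fastdeploy.model_executor.models.model_base import ModelRegistry
from fd_add_dummy_model import MyModelForCasualLM, MyPretrainedModel  # test/plugins package

ModelRegistry.register_model_class(MyModelForCasualLM)      # keyed by MyModelForCasualLM.name()
ModelRegistry.register_pretrained_model(MyPretrainedModel)  # keyed by MyPretrainedModel.arch_name()

assert ModelRegistry.get_class("MyModelForCasualLM") is MyModelForCasualLM
assert ModelRegistry.get_pretrain_cls("MyModelForCasualLM") is MyPretrainedModel
```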

View File

@@ -355,6 +355,10 @@ class Qwen2PretrainedModel(PretrainedModel):
""" """
return None return None
@classmethod
def arch_name(self):
return "Qwen2ForCausalLM"
@classmethod @classmethod
def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True): def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True):

View File

@@ -334,6 +334,10 @@ class Qwen3PretrainedModel(PretrainedModel):
""" """
return None return None
@classmethod
def arch_name(self):
return "Qwen3ForCausalLM"
@classmethod @classmethod
def _get_tensor_parallel_mappings(cls, config, is_split=True): def _get_tensor_parallel_mappings(cls, config, is_split=True):

View File

@@ -324,6 +324,10 @@ class Qwen3MoePretrainedModel(PretrainedModel):
""" """
return None return None
@classmethod
def arch_name(self):
return "Qwen3MoeForCausalLM"
@classmethod @classmethod
def _get_tensor_parallel_mappings(cls, config, is_split=True): def _get_tensor_parallel_mappings(cls, config, is_split=True):
# TODO not support TP split now, next PR will support TP. # TODO not support TP split now, next PR will support TP.

View File

@@ -0,0 +1,20 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .model_register import load_model_register_plugins
from .model_runner import load_model_runner_plugins
__all__ = ["load_model_register_plugins", "load_model_runner_plugins"]

View File

@@ -0,0 +1,33 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from fastdeploy.plugins.utils import load_plugins_by_group, plugins_loaded

# make sure one process only loads plugins once
PLUGINS_GROUP = "fastdeploy.model_register_plugins"


def load_model_register_plugins():
    """load_model_register_plugins"""
    global plugins_loaded
    if plugins_loaded:
        return
    plugins_loaded = True

    plugins = load_plugins_by_group(group=PLUGINS_GROUP)
    # general plugins, we only need to execute the loaded functions
    for func in plugins.values():
        func()

View File

@@ -0,0 +1,32 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from fastdeploy.plugins.utils import load_plugins_by_group, plugins_loaded

# used for the model runner
PLUGINS_GROUP = "fastdeploy.model_runner_plugins"


def load_model_runner_plugins():
    """load_model_runner_plugins"""
    global plugins_loaded
    if plugins_loaded:
        return
    plugins_loaded = True

    plugins = load_plugins_by_group(group=PLUGINS_GROUP)
    assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
    return next(iter(plugins.values()))
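Unlike the register group, this group must resolve to exactly one provider, and the provider's return value (rather than a side effect) is what the caller consumes. A hypothetical runner plugin matching the commented-out `model_runner = model_runner:get_runner` entry in `test/plugins/setup.py`:
```python
# Hypothetical module "model_runner"; the names below are illustrative only,
# not part of this commit.
def get_runner():
    from my_project.runner import MyModelRunner  # assumed custom runner class
    return MyModelRunner
```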

View File

@@ -0,0 +1,61 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Any, Callable

from fastdeploy import envs
from fastdeploy.utils import llm_logger as logger

plugins_loaded = False


def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]:
    import sys

    if sys.version_info < (3, 10):
        from importlib_metadata import entry_points
    else:
        from importlib.metadata import entry_points

    allowed_plugins = envs.FD_PLUGINS

    discovered_plugins = entry_points(group=group)
    if len(discovered_plugins) == 0:
        logger.info("No plugins for group %s found.", group)
        return {}

    logger.info("Available plugins for group %s:", group)
    for plugin in discovered_plugins:
        logger.info("- %s -> %s", plugin.name, plugin.value)

    if allowed_plugins is None:
        logger.info(
            "All plugins in this group will be loaded. " "You can set `FD_PLUGINS` to control which plugins to load."
        )

    plugins = dict[str, Callable[[], Any]]()
    for plugin in discovered_plugins:
        if allowed_plugins is None or plugin.name in allowed_plugins:
            if allowed_plugins is not None:
                logger.info("Loading plugin %s", plugin.name)
            try:
                func = plugin.load()
                plugins[plugin.name] = func
            except Exception:
                logger.exception("Failed to load plugin %s", plugin.name)
    return plugins
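For a registration-style group the returned callables are simply executed, which is exactly what `load_model_register_plugins` above does. A condensed usage sketch:
```python
from fastdeploy.plugins.utils import load_plugins_by_group

# Run every registration hook discovered under the group.
for name, register_func in load_plugins_by_group("fastdeploy.model_register_plugins").items():
    register_func()
```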

View File

@@ -22,7 +22,7 @@ from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.models.ernie4_5_moe import (
    Ernie4_5_MoeForCausalLM,
-    Ernie4_5_PretrainedModel,
+    Ernie4_5_MoePretrainedModel,
)
from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import (
    Ernie4_5_VLMoeForConditionalGeneration,
@@ -126,7 +126,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel):
    Ernie4_5_MoeForCausalLMRL
    """

-    _get_tensor_parallel_mappings = Ernie4_5_PretrainedModel._get_tensor_parallel_mappings
+    _get_tensor_parallel_mappings = Ernie4_5_MoePretrainedModel._get_tensor_parallel_mappings

    def __init__(self, fd_config: FDConfig):
        """

View File

@@ -748,4 +748,7 @@ def run_worker_proc() -> None:
if __name__ == "__main__":
+    from fastdeploy.plugins.model_register import load_model_register_plugins
+
+    load_model_register_plugins()
    run_worker_proc()

View File

@@ -0,0 +1,35 @@
from paddleformers.transformers import PretrainedModel

from fastdeploy import ModelRegistry
from fastdeploy.model_executor.models.model_base import ModelForCasualLM


class MyPretrainedModel(PretrainedModel):
    @classmethod
    def arch_name(cls):
        return "MyModelForCasualLM"


class MyModelForCasualLM(ModelForCasualLM):
    def __init__(self, fd_config):
        """
        Args:
            fd_config: Configurations for the LLM model.
        """
        super().__init__(fd_config)
        print("init done")

    @classmethod
    def name(cls):
        return "MyModelForCasualLM"

    def compute_logits(self, logits):
        logits[:, 0] += 1.0
        return logits


def register():
    if "MyModelForCasualLM" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model_class(MyModelForCasualLM)
        ModelRegistry.register_pretrained_model(MyPretrainedModel)

test/plugins/setup.py Normal file
View File

@@ -0,0 +1,15 @@
from setuptools import setup

setup(
    name="fastdeploy-plugins",
    version="0.1",
    packages=["fd_add_dummy_model"],
    entry_points={
        "fastdeploy.model_register_plugins": [
            "fd_add_dummy_model = fd_add_dummy_model:register",
        ],
        # "fastdeploy.model_runner_plugins": [
        #     "model_runner = model_runner:get_runner"
        # ],
    },
)

View File

@@ -0,0 +1,32 @@
import unittest

from fastdeploy import ModelRegistry
from fastdeploy.plugins import load_model_register_plugins


class TestModelRegistryPlugins(unittest.TestCase):
    def test_plugin_registers_one_architecture(self):
        """Test that loading plugins registers exactly one new architecture."""
        initial_archs = set(ModelRegistry.get_supported_archs())
        print("Supported architectures before loading plugins:", sorted(initial_archs))

        # Load plugins
        load_model_register_plugins()

        final_archs = set(ModelRegistry.get_supported_archs())
        print("Supported architectures after loading plugins:", sorted(final_archs))

        added_archs = final_archs - initial_archs
        added_count = len(added_archs)

        # verify
        self.assertEqual(
            added_count,
            1,
            f"Expected exactly 1 new architecture to be registered by plugins, "
            f"but {added_count} were added: {sorted(added_archs)}",
        )


if __name__ == "__main__":
    unittest.main()