mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-27 04:46:16 +08:00
【Feature】add fd plugins && rm model_classes (#3123)
* add fd plugins && rm model_classed * fix reviews * add docs * fix * fix unitest ci
This commit is contained in:
7
.github/workflows/_unit_test_coverage.yml
vendored
7
.github/workflows/_unit_test_coverage.yml
vendored
@@ -103,6 +103,13 @@ jobs:
|
||||
python -m pip install coverage
|
||||
python -m pip install diff-cover
|
||||
python -m pip install ${fd_wheel_url}
|
||||
if [ -d "test/plugins" ]; then
|
||||
cd test/plugins
|
||||
python setup.py install
|
||||
cd ../..
|
||||
else
|
||||
echo "Warning: test/plugins directory not found, skipping setup.py install"
|
||||
fi
|
||||
export COVERAGE_FILE=/workspace/FastDeploy/coveragedata/.coverage
|
||||
export COVERAGE_RCFILE=/workspace/FastDeploy/scripts/.coveragerc
|
||||
TEST_EXIT_CODE=0
|
||||
|
85
docs/features/plugins.md
Normal file
85
docs/features/plugins.md
Normal file
@@ -0,0 +1,85 @@
|
||||
# FastDeploy Plugin Mechanism Documentation
|
||||
|
||||
FastDeploy supports a plugin mechanism that allows users to extend functionality without modifying the core code. Plugins are automatically discovered and loaded through Python's `entry_points` mechanism.
|
||||
|
||||
## How Plugins Work
|
||||
|
||||
Plugins are essentially registration functions that are automatically called when FastDeploy starts. The system uses the `load_plugins_by_group` function to ensure that all processes (including child processes in distributed training scenarios) have loaded the required plugins before official operations begin.
|
||||
|
||||
## Plugin Discovery Mechanism
|
||||
|
||||
FastDeploy uses Python's `entry_points` mechanism to discover and load plugins. Developers need to register their plugins in the specified entry point group in their project.
|
||||
|
||||
### Example: Creating a Plugin
|
||||
|
||||
#### 1. How Plugin Work
|
||||
|
||||
Assuming you have a custom model class `MyModelForCasualLM` and a pretrained class `MyPretrainedModel`, you can write the following registration function:
|
||||
|
||||
```python
|
||||
# File: fd_add_dummy_model/__init__.py or fd_add_dummy_model/register.py
|
||||
from fastdeploy.model_registry import ModelRegistry
|
||||
from my_custom_model import MyModelForCasualLM, MyPretrainedModel
|
||||
|
||||
def register():
|
||||
if "MyModelForCasualLM" not in ModelRegistry.get_supported_archs():
|
||||
ModelRegistry.register_model_class(MyModelForCasualLM)
|
||||
ModelRegistry.register_pretrained_model(MyPretrainedModel)
|
||||
```
|
||||
|
||||
#### 2. Register Plugin in `setup.py`
|
||||
|
||||
```python
|
||||
# setup.py
|
||||
from setuptools import setup
|
||||
|
||||
setup(
|
||||
name="fastdeploy-plugins",
|
||||
version="0.1",
|
||||
packages=["fd_add_dummy_model"],
|
||||
entry_points={
|
||||
"fastdeploy.model_register_plugins": [
|
||||
"fd_add_dummy_model = fd_add_dummy_model:register",
|
||||
],
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
## Plugin Structure
|
||||
|
||||
Plugins consist of three components:
|
||||
|
||||
| Component | Description |
|
||||
|-----------|-------------|
|
||||
| **Plugin Group** | The functional group to which the plugin belongs, for example:<br> - `fastdeploy.model_register_plugins`: for model registration<br> - `fastdeploy.model_runner_plugins`: for model runner registration<br> Users can customize groups as needed. |
|
||||
| **Plugin Name** | The unique identifier for each plugin (e.g., `fd_add_dummy_model`), which can be controlled via the `FD_PLUGINS` environment variable to determine whether to load the plugin. |
|
||||
| **Plugin Value** | Format is `module_name:function_name`, pointing to the entry function that executes the registration logic. |
|
||||
|
||||
## Controlling Plugin Loading Behavior
|
||||
|
||||
By default, FastDeploy loads all registered plugins. To load only specific plugins, you can set the environment variable:
|
||||
|
||||
```bash
|
||||
export FD_PLUGINS=fastdeploy-plugins
|
||||
```
|
||||
|
||||
Multiple plugin names can be separated by commas:
|
||||
|
||||
```bash
|
||||
export FD_PLUGINS=plugin_a,plugin_b
|
||||
```
|
||||
|
||||
## Reference Example
|
||||
|
||||
Please refer to the example plugin implementation in the project directory:
|
||||
```
|
||||
./test/plugins/
|
||||
```
|
||||
|
||||
It contains a complete plugin structure and `setup.py` configuration example.
|
||||
|
||||
## Summary
|
||||
|
||||
Through the plugin mechanism, users can easily add custom models or functional modules to FastDeploy without modifying the core source code. This not only enhances system extensibility but also facilitates third-party developers in extending functionality.
|
||||
|
||||
For further plugin development, please refer to the `model_registry` and `plugin_loader` modules in the FastDeploy source code.
|
85
docs/zh/features/plugins.md
Normal file
85
docs/zh/features/plugins.md
Normal file
@@ -0,0 +1,85 @@
|
||||
# FastDeploy 插件机制说明文档
|
||||
|
||||
FastDeploy 支持插件机制,允许用户在不修改核心代码的前提下扩展功能。插件通过 Python 的 `entry_points` 机制实现自动发现与加载。
|
||||
|
||||
## 插件工作原理
|
||||
|
||||
插件本质上是在 FastDeploy 启动时被自动调用的注册函数。系统使用 `load_plugins_by_group` 函数确保所有进程(包括分布式训练场景下的子进程)在正式运行前都已加载所需的插件。
|
||||
|
||||
## 插件发现机制
|
||||
|
||||
FastDeploy 利用 Python 的 `entry_points` 机制来发现并加载插件。开发者需在自己的项目中将插件注册到指定的 entry point 组中。
|
||||
|
||||
### 示例:创建一个插件
|
||||
|
||||
#### 1. 编写插件逻辑
|
||||
|
||||
假设你有一个自定义模型类 `MyModelForCasualLM` 和预训练类 `MyPretrainedModel`,你可以编写如下注册函数:
|
||||
|
||||
```python
|
||||
# 文件:fd_add_dummy_model/__init__.py
|
||||
from fastdeploy.model_registry import ModelRegistry
|
||||
from my_custom_model import MyModelForCasualLM, MyPretrainedModel
|
||||
|
||||
def register():
|
||||
if "MyModelForCasualLM" not in ModelRegistry.get_supported_archs():
|
||||
ModelRegistry.register_model_class(MyModelForCasualLM)
|
||||
ModelRegistry.register_pretrained_model(MyPretrainedModel)
|
||||
```
|
||||
|
||||
#### 2. 注册插件到 `setup.py`
|
||||
|
||||
```python
|
||||
# setup.py
|
||||
from setuptools import setup
|
||||
|
||||
setup(
|
||||
name="fastdeploy-plugins",
|
||||
version="0.1",
|
||||
packages=["fd_add_dummy_model"],
|
||||
entry_points={
|
||||
"fastdeploy.model_register_plugins": [
|
||||
"fd_add_dummy_model = fd_add_dummy_model:register",
|
||||
],
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
## 插件结构说明
|
||||
|
||||
插件由三部分组成:
|
||||
|
||||
| 组件 | 说明 |
|
||||
|------|------|
|
||||
| **插件组(Group)** | 插件所属的功能分组,例如:<br> - `fastdeploy.model_register_plugins`: 用于注册模型<br> - `fastdeploy.model_runner_plugins`: 用于注册模型运行器<br> 用户可根据需要自定义分组。 |
|
||||
| **插件名(Name)** | 每个插件的唯一标识名(如 `fd_add_dummy_model`),可通过环境变量 `FD_PLUGINS` 控制是否加载该插件。 |
|
||||
| **插件值(Value)** | 格式为 `模块名:函数名`,指向实际执行注册逻辑的入口函数。 |
|
||||
|
||||
## 控制插件加载行为
|
||||
|
||||
默认情况下,FastDeploy 会加载所有已注册的插件。若只想加载特定插件,可以设置环境变量:
|
||||
|
||||
```bash
|
||||
export FD_PLUGINS=fastdeploy-plugins
|
||||
```
|
||||
|
||||
多个插件名之间可以用逗号分隔:
|
||||
|
||||
```bash
|
||||
export FD_PLUGINS=plugin_a,plugin_b
|
||||
```
|
||||
|
||||
## 参考示例
|
||||
|
||||
请参见项目目录下的示例插件实现:
|
||||
```
|
||||
./test/plugins/
|
||||
```
|
||||
|
||||
其中包含完整的插件结构和 `setup.py` 配置示例。
|
||||
|
||||
## 总结
|
||||
|
||||
通过插件机制,用户可以轻松地为 FastDeploy 添加自定义模型或功能模块,而无需修改核心源码。这不仅提升了系统的可扩展性,也方便了第三方开发者进行功能拓展。
|
||||
|
||||
如需进一步开发插件,请参考 FastDeploy 源码中的 `model_registry` 和 `plugin_loader` 模块。
|
@@ -22,11 +22,10 @@ import sys
|
||||
os.environ["GLOG_minloglevel"] = "2"
|
||||
# suppress log from aistudio
|
||||
os.environ["AISTUDIO_LOG"] = "critical"
|
||||
import typing
|
||||
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
from fastdeploy.entrypoints.llm import LLM
|
||||
from fastdeploy.utils import version
|
||||
|
||||
__all__ = ["LLM", "SamplingParams", "version"]
|
||||
|
||||
try:
|
||||
import use_triton_in_paddle
|
||||
@@ -86,3 +85,27 @@ def _patch_fastsafetensors():
|
||||
|
||||
|
||||
_patch_fastsafetensors()
|
||||
|
||||
|
||||
MODULE_ATTRS = {"ModelRegistry": ".model_executor.models.model_base:ModelRegistry", "version": ".utils:version"}
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from fastdeploy.model_executor.models.model_base import ModelRegistry
|
||||
else:
|
||||
|
||||
def __getattr__(name: str) -> typing.Any:
|
||||
from importlib import import_module
|
||||
|
||||
if name in MODULE_ATTRS:
|
||||
try:
|
||||
module_name, attr_name = MODULE_ATTRS[name].split(":")
|
||||
module = import_module(module_name, __package__)
|
||||
return getattr(module, attr_name)
|
||||
except ModuleNotFoundError:
|
||||
print(f"Module {MODULE_ATTRS[name]} not found.")
|
||||
else:
|
||||
print(f"module {__package__} has no attribute {name}")
|
||||
|
||||
|
||||
__all__ = ["LLM", "SamplingParams", "ModelRegistry", "version"]
|
||||
|
@@ -80,6 +80,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
|
||||
# enable kv cache block scheduler v1 (no need for kv_cache_ratio)
|
||||
"ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")),
|
||||
# Whether to use PLUGINS.
|
||||
"FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","),
|
||||
}
|
||||
|
||||
|
||||
|
@@ -24,7 +24,6 @@ from fastdeploy.model_executor.load_weight_utils import (
|
||||
measure_time,
|
||||
)
|
||||
from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from fastdeploy.model_executor.model_loader.utils import get_pretrain_cls
|
||||
from fastdeploy.model_executor.models.model_base import ModelRegistry
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
@@ -52,7 +51,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
|
||||
@measure_time
|
||||
def load_weights(self, model, fd_config: FDConfig, architectures: str) -> None:
|
||||
model_class = get_pretrain_cls(architectures)
|
||||
model_class = ModelRegistry.get_pretrain_cls(architectures)
|
||||
state_dict = load_composite_checkpoint(
|
||||
fd_config.model_config.model,
|
||||
model_class,
|
||||
|
@@ -1,43 +0,0 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
|
||||
from fastdeploy.model_executor.models.deepseek_v3 import DeepSeekV3PretrainedModel
|
||||
from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_PretrainedModel
|
||||
from fastdeploy.model_executor.models.ernie4_5_mtp import Ernie4_5_MTPPretrainedModel
|
||||
from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import (
|
||||
Ernie4_5_VLPretrainedModel,
|
||||
)
|
||||
from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel
|
||||
from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel
|
||||
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"Ernie4_5_MoeForCausalLM": Ernie4_5_PretrainedModel,
|
||||
"Ernie4_5_MTPForCausalLM": Ernie4_5_MTPPretrainedModel,
|
||||
"Qwen2ForCausalLM": Qwen2PretrainedModel,
|
||||
"Qwen3ForCausalLM": Qwen3PretrainedModel,
|
||||
"Qwen3MoeForCausalLM": Qwen3MoePretrainedModel,
|
||||
"Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel,
|
||||
"DeepseekV3ForCausalLM": DeepSeekV3PretrainedModel,
|
||||
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLPretrainedModel,
|
||||
}
|
||||
|
||||
|
||||
def get_pretrain_cls(architectures: str) -> PretrainedModel:
|
||||
"""get_pretrain_cls"""
|
||||
return MODEL_CLASSES[architectures]
|
@@ -19,6 +19,8 @@ import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
|
||||
from .model_base import ModelForCasualLM, ModelRegistry
|
||||
|
||||
|
||||
@@ -44,7 +46,14 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode
|
||||
for attr_name in dir(module):
|
||||
attr = getattr(module, attr_name)
|
||||
if inspect.isclass(attr) and issubclass(attr, ModelForCasualLM) and attr is not ModelForCasualLM:
|
||||
ModelRegistry.register(attr)
|
||||
ModelRegistry.register_model_class(attr)
|
||||
if (
|
||||
inspect.isclass(attr)
|
||||
and issubclass(attr, PretrainedModel)
|
||||
and attr is not PretrainedModel
|
||||
and hasattr(attr, "arch_name")
|
||||
):
|
||||
ModelRegistry.register_pretrained_model(attr)
|
||||
except ImportError:
|
||||
raise ImportError(f"{module_file=} import error")
|
||||
|
||||
|
@@ -673,6 +673,10 @@ class DeepSeekV3PretrainedModel(PretrainedModel):
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def arch_name(self):
|
||||
return "DeepseekV3ForCausalLM"
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config, is_split=True):
|
||||
|
||||
|
@@ -460,9 +460,9 @@ class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
|
||||
return "Ernie4_5_ForCausalLM"
|
||||
|
||||
|
||||
class Ernie4_5_PretrainedModel(PretrainedModel):
|
||||
class Ernie4_5_MoePretrainedModel(PretrainedModel):
|
||||
"""
|
||||
Ernie4_5_PretrainedModel
|
||||
Ernie4_5_MoePretrainedModel
|
||||
"""
|
||||
|
||||
config_class = FDConfig
|
||||
@@ -473,6 +473,10 @@ class Ernie4_5_PretrainedModel(PretrainedModel):
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def arch_name(self):
|
||||
return "Ernie4_5_MoeForCausalLM"
|
||||
|
||||
weight_infos = [
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.weight",
|
||||
@@ -594,3 +598,16 @@ class Ernie4_5_PretrainedModel(PretrainedModel):
|
||||
config.prefix_name,
|
||||
)
|
||||
return mappings
|
||||
|
||||
|
||||
class Ernie4_5_PretrainedModel(Ernie4_5_MoePretrainedModel):
|
||||
"""
|
||||
Ernie4_5_PretrainedModel
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def arch_name(self):
|
||||
"""
|
||||
Model Architecture Name
|
||||
"""
|
||||
return "Ernie4_5_ForCausalLM"
|
||||
|
@@ -46,6 +46,10 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel):
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def arch_name(self):
|
||||
return "Ernie4_5_MTPForCausalLM"
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config, is_split=True):
|
||||
"""
|
||||
|
@@ -605,7 +605,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
|
||||
|
||||
class Ernie4_5_VLPretrainedModel(PretrainedModel):
|
||||
"""
|
||||
Ernie4_5_PretrainedModel
|
||||
Ernie4_5_MoePretrainedModel
|
||||
"""
|
||||
|
||||
config_class = FDConfig
|
||||
@@ -616,6 +616,10 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel):
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def arch_name(self):
|
||||
return "Ernie4_5_VLMoeForConditionalGeneration"
|
||||
|
||||
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
|
||||
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
|
||||
from fastdeploy.model_executor.models.utils import WeightMeta
|
||||
|
@@ -20,6 +20,7 @@ from typing import Dict, Union
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
@@ -27,21 +28,46 @@ class ModelRegistry:
|
||||
Used to register and retrieve model classes.
|
||||
"""
|
||||
|
||||
_registry = {}
|
||||
_arch_to_model_cls = {}
|
||||
_arch_to_pretrained_model_cls = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class):
|
||||
def register_model_class(cls, model_class):
|
||||
"""register model class"""
|
||||
if issubclass(model_class, ModelForCasualLM) and model_class is not ModelForCasualLM:
|
||||
cls._registry[model_class.name()] = model_class
|
||||
cls._arch_to_model_cls[model_class.name()] = model_class
|
||||
return model_class
|
||||
|
||||
@classmethod
|
||||
def register_pretrained_model(cls, pretrained_model):
|
||||
"""register pretrained model class"""
|
||||
if (
|
||||
issubclass(pretrained_model, PretrainedModel)
|
||||
and pretrained_model is not PretrainedModel
|
||||
and hasattr(pretrained_model, "arch_name")
|
||||
):
|
||||
cls._arch_to_pretrained_model_cls[pretrained_model.arch_name()] = pretrained_model
|
||||
|
||||
return pretrained_model
|
||||
|
||||
@classmethod
|
||||
def get_pretrain_cls(cls, architectures: str):
|
||||
"""get_pretrain_cls"""
|
||||
return cls._arch_to_pretrained_model_cls[architectures]
|
||||
|
||||
@classmethod
|
||||
def get_class(cls, name):
|
||||
"""get model class"""
|
||||
if name not in cls._registry:
|
||||
if name not in cls._arch_to_model_cls:
|
||||
raise ValueError(f"Model '{name}' is not registered!")
|
||||
return cls._registry[name]
|
||||
return cls._arch_to_model_cls[name]
|
||||
|
||||
@classmethod
|
||||
def get_supported_archs(cls):
|
||||
assert len(cls._arch_to_model_cls) == len(
|
||||
cls._arch_to_model_cls
|
||||
), "model class / pretrained model registry num is not same"
|
||||
return [key for key in cls._arch_to_model_cls.keys()]
|
||||
|
||||
|
||||
class ModelForCasualLM(nn.Layer, ABC):
|
||||
|
@@ -355,6 +355,10 @@ class Qwen2PretrainedModel(PretrainedModel):
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def arch_name(self):
|
||||
return "Qwen2ForCausalLM"
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True):
|
||||
|
||||
|
@@ -334,6 +334,10 @@ class Qwen3PretrainedModel(PretrainedModel):
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def arch_name(self):
|
||||
return "Qwen3ForCausalLM"
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config, is_split=True):
|
||||
|
||||
|
@@ -324,6 +324,10 @@ class Qwen3MoePretrainedModel(PretrainedModel):
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def arch_name(self):
|
||||
return "Qwen3MoeForCausalLM"
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config, is_split=True):
|
||||
# TODO not support TP split now, next PR will support TP.
|
||||
|
20
fastdeploy/plugins/__init__.py
Normal file
20
fastdeploy/plugins/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from .model_register import load_model_register_plugins
|
||||
from .model_runner import load_model_runner_plugins
|
||||
|
||||
__all__ = ["load_model_register_plugins", "load_model_runner_plugins"]
|
33
fastdeploy/plugins/model_register/__init__.py
Normal file
33
fastdeploy/plugins/model_register/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from fastdeploy.plugins.utils import load_plugins_by_group, plugins_loaded
|
||||
|
||||
# make sure one process only loads plugins once
|
||||
PLUGINS_GROUP = "fastdeploy.model_register_plugins"
|
||||
|
||||
|
||||
def load_model_register_plugins():
|
||||
"""load_model_runner_plugins"""
|
||||
global plugins_loaded
|
||||
if plugins_loaded:
|
||||
return
|
||||
plugins_loaded = True
|
||||
|
||||
plugins = load_plugins_by_group(group=PLUGINS_GROUP)
|
||||
# general plugins, we only need to execute the loaded functions
|
||||
for func in plugins.values():
|
||||
func()
|
32
fastdeploy/plugins/model_runner/__init__.py
Normal file
32
fastdeploy/plugins/model_runner/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from fastdeploy.plugins.utils import load_plugins_by_group, plugins_loaded
|
||||
|
||||
# use for modle runner
|
||||
PLUGINS_GROUP = "fastdeploy.model_runner_plugins"
|
||||
|
||||
|
||||
def load_model_runner_plugins():
|
||||
"""load_model_runner_plugins"""
|
||||
global plugins_loaded
|
||||
if plugins_loaded:
|
||||
return
|
||||
plugins_loaded = True
|
||||
|
||||
plugins = load_plugins_by_group(group=PLUGINS_GROUP)
|
||||
assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
|
||||
return next(iter(plugins.values()))
|
61
fastdeploy/plugins/utils.py
Normal file
61
fastdeploy/plugins/utils.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.utils import llm_logger as logger
|
||||
|
||||
plugins_loaded = False
|
||||
|
||||
|
||||
def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]:
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
from importlib_metadata import entry_points
|
||||
else:
|
||||
from importlib.metadata import entry_points
|
||||
|
||||
allowed_plugins = envs.FD_PLUGINS
|
||||
|
||||
discovered_plugins = entry_points(group=group)
|
||||
if len(discovered_plugins) == 0:
|
||||
logger.info("No plugins for group %s found.", group)
|
||||
return {}
|
||||
|
||||
logger.info("Available plugins for group %s:", group)
|
||||
for plugin in discovered_plugins:
|
||||
logger.info("- %s -> %s", plugin.name, plugin.value)
|
||||
|
||||
if allowed_plugins is None:
|
||||
logger.info(
|
||||
"All plugins in this group will be loaded. " "You can set `FD_PLUGINS` to control which plugins to load."
|
||||
)
|
||||
|
||||
plugins = dict[str, Callable[[], Any]]()
|
||||
for plugin in discovered_plugins:
|
||||
if allowed_plugins is None or plugin.name in allowed_plugins:
|
||||
if allowed_plugins is not None:
|
||||
logger.info("Loading plugin %s", plugin.name)
|
||||
|
||||
try:
|
||||
func = plugin.load()
|
||||
plugins[plugin.name] = func
|
||||
except Exception:
|
||||
logger.exception("Failed to load plugin %s", plugin.name)
|
||||
|
||||
return plugins
|
@@ -22,7 +22,7 @@ from paddle import nn
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.models.ernie4_5_moe import (
|
||||
Ernie4_5_MoeForCausalLM,
|
||||
Ernie4_5_PretrainedModel,
|
||||
Ernie4_5_MoePretrainedModel,
|
||||
)
|
||||
from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import (
|
||||
Ernie4_5_VLMoeForConditionalGeneration,
|
||||
@@ -126,7 +126,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel):
|
||||
Ernie4_5_MoeForCausalLMRL
|
||||
"""
|
||||
|
||||
_get_tensor_parallel_mappings = Ernie4_5_PretrainedModel._get_tensor_parallel_mappings
|
||||
_get_tensor_parallel_mappings = Ernie4_5_MoePretrainedModel._get_tensor_parallel_mappings
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
"""
|
||||
|
@@ -748,4 +748,7 @@ def run_worker_proc() -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from fastdeploy.plugins.model_register import load_model_register_plugins
|
||||
|
||||
load_model_register_plugins()
|
||||
run_worker_proc()
|
||||
|
35
test/plugins/fd_add_dummy_model/__init__.py
Normal file
35
test/plugins/fd_add_dummy_model/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
|
||||
from fastdeploy import ModelRegistry
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
|
||||
|
||||
class MyPretrainedModel(PretrainedModel):
|
||||
@classmethod
|
||||
def arch_names(cls):
|
||||
return "MyModelForCasualLM"
|
||||
|
||||
|
||||
class MyModelForCasualLM(ModelForCasualLM):
|
||||
|
||||
def __init__(self, fd_config):
|
||||
"""
|
||||
Args:
|
||||
fd_config : Configurations for the LLM model.
|
||||
"""
|
||||
super().__init__(fd_config)
|
||||
print("init done")
|
||||
|
||||
@classmethod
|
||||
def name(cls):
|
||||
return "MyModelForCasualLM"
|
||||
|
||||
def compute_logits(self, logits):
|
||||
logits[:, 0] += 1.0
|
||||
return logits
|
||||
|
||||
|
||||
def register():
|
||||
if "MyModelForCasualLM" not in ModelRegistry.get_supported_archs():
|
||||
ModelRegistry.register_model_class(MyModelForCasualLM)
|
||||
ModelRegistry.register_pretrained_model(MyPretrainedModel)
|
15
test/plugins/setup.py
Normal file
15
test/plugins/setup.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from setuptools import setup
|
||||
|
||||
setup(
|
||||
name="fastdeploy-plugins",
|
||||
version="0.1",
|
||||
packages=["fd_add_dummy_model"],
|
||||
entry_points={
|
||||
"fastdeploy.model_register_plugins": [
|
||||
"fd_add_dummy_model = fd_add_dummy_model:register",
|
||||
],
|
||||
# 'fastdeploy.model_runner_plugins': [
|
||||
# "model_runner = model_runner:get_runner"
|
||||
# ]
|
||||
},
|
||||
)
|
32
test/plugins/test_model_registry.py
Normal file
32
test/plugins/test_model_registry.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import unittest
|
||||
|
||||
from fastdeploy import ModelRegistry
|
||||
from fastdeploy.plugins import load_model_register_plugins
|
||||
|
||||
|
||||
class TestModelRegistryPlugins(unittest.TestCase):
|
||||
def test_plugin_registers_one_architecture(self):
|
||||
"""Test that loading plugins registers exactly one new architecture."""
|
||||
initial_archs = set(ModelRegistry.get_supported_archs())
|
||||
print("Supported architectures before loading plugins:", sorted(initial_archs))
|
||||
|
||||
# Load plugins
|
||||
load_model_register_plugins()
|
||||
|
||||
final_archs = set(ModelRegistry.get_supported_archs())
|
||||
print("Supported architectures after loading plugins:", sorted(final_archs))
|
||||
|
||||
added_archs = final_archs - initial_archs
|
||||
added_count = len(added_archs)
|
||||
|
||||
# verify
|
||||
self.assertEqual(
|
||||
added_count,
|
||||
1,
|
||||
f"Expected exactly 1 new architecture to be registered by plugins, "
|
||||
f"but {added_count} were added: {sorted(added_archs)}",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user