diff --git a/docs/features/plugins.md b/docs/features/plugins.md index f05f248f5..0fe97ef7b 100644 --- a/docs/features/plugins.md +++ b/docs/features/plugins.md @@ -20,12 +20,23 @@ Assuming you have a custom model class `MyModelForCasualLM` and a pretrained cla # File: fd_add_dummy_model/__init__.py or fd_add_dummy_model/register.py from fastdeploy.model_registry import ModelRegistry from my_custom_model import MyModelForCasualLM, MyPretrainedModel +from fastdeploy.config import ErnieArchitectures def register(): if "MyModelForCasualLM" not in ModelRegistry.get_supported_archs(): + if MyModelForCasualLM.name().startswith("Ernie"): + ErnieArchitectures.register_ernie_model_arch(MyModelForCasualLM) ModelRegistry.register_model_class(MyModelForCasualLM) ModelRegistry.register_pretrained_model(MyPretrainedModel) ``` +Assuming you have a custom model_runner class `MyModelRunner`, you can write the following registration function: +```python +# File: fd_add_dummy_model_runner/__init__.py +from .my_model_runner import MyModelRunner + +def get_runner(): + return MyModelRunner +``` #### 2. Register Plugin in `setup.py` @@ -36,11 +47,14 @@ from setuptools import setup setup( name="fastdeploy-plugins", version="0.1", - packages=["fd_add_dummy_model"], + packages=["fd_add_dummy_model", "fd_add_dummy_model_runner"], entry_points={ "fastdeploy.model_register_plugins": [ "fd_add_dummy_model = fd_add_dummy_model:register", ], + "fastdeploy.model_runner_plugins": [ + "model_runner = fd_add_dummy_model:get_runner" + ], }, ) ``` diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 6e27196f6..410b6e686 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -63,6 +63,11 @@ class ErnieArchitectures: "Ernie4_5_VLMoeForConditionalGeneration", } + @classmethod + def register_ernie_model_arch(cls, model_class): + if model_class.name().startswith("Ernie") and model_class.name() not in cls.ARCHITECTURES: + cls.ARCHITECTURES.add(model_class.name()) + @classmethod def contains_ernie_arch(cls, architectures): """Check if any ERNIE architecture is present in the given architectures.""" diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py index a0963c5eb..c91763257 100644 --- a/fastdeploy/entrypoints/llm.py +++ b/fastdeploy/entrypoints/llm.py @@ -28,6 +28,7 @@ from tqdm import tqdm from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.engine import LLMEngine from fastdeploy.engine.sampling_params import SamplingParams +from fastdeploy.plugins.model_register import load_model_register_plugins from fastdeploy.utils import ( deprecated_kwargs_warning, llm_logger, @@ -76,6 +77,7 @@ class LLM: ): deprecated_kwargs_warning(**kwargs) + load_model_register_plugins() model = retrive_model_from_server(model, revision) engine_args = EngineArgs( model=model, diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 50dabb78e..8bcc74a31 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -46,6 +46,7 @@ from fastdeploy.metrics.metrics import ( main_process_metrics, ) from fastdeploy.metrics.trace_util import inject_to_metadata, instrument +from fastdeploy.plugins.model_register import load_model_register_plugins from fastdeploy.utils import ( FlexibleArgumentParser, api_server_logger, @@ -393,6 +394,7 @@ def launch_controller_server(): def main(): """main函数""" + load_model_register_plugins() if load_engine() is None: return diff --git a/fastdeploy/input/preprocess.py b/fastdeploy/input/preprocess.py index abefb0735..120be9ce8 100644 --- a/fastdeploy/input/preprocess.py +++ b/fastdeploy/input/preprocess.py @@ -1,7 +1,7 @@ """ # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License" +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # diff --git a/fastdeploy/model_executor/layers/attention/attention.py b/fastdeploy/model_executor/layers/attention/attention.py index e6ae92b3f..f1bc434bc 100644 --- a/fastdeploy/model_executor/layers/attention/attention.py +++ b/fastdeploy/model_executor/layers/attention/attention.py @@ -48,6 +48,7 @@ class Attention(nn.Layer): linear_shift: paddle.Tensor = None, linear_smooth: paddle.Tensor = None, use_neox_rotary_style: bool = False, + use_qk_norm: bool = False, ) -> None: """ Initializes `LMLayer` with the given parameters. diff --git a/fastdeploy/model_executor/models/model_base.py b/fastdeploy/model_executor/models/model_base.py index 96986d263..06f0d0705 100644 --- a/fastdeploy/model_executor/models/model_base.py +++ b/fastdeploy/model_executor/models/model_base.py @@ -64,9 +64,9 @@ class ModelRegistry: @classmethod def get_supported_archs(cls): - assert len(cls._arch_to_model_cls) == len( - cls._arch_to_model_cls - ), "model class / pretrained model registry num is not same" + assert len(cls._arch_to_model_cls) >= len( + cls._arch_to_pretrained_model_cls + ), "model class num is more than pretrained model registry num" return [key for key in cls._arch_to_model_cls.keys()] diff --git a/fastdeploy/plugins/model_runner/__init__.py b/fastdeploy/plugins/model_runner/__init__.py index 2d3f426e7..8897abfbc 100644 --- a/fastdeploy/plugins/model_runner/__init__.py +++ b/fastdeploy/plugins/model_runner/__init__.py @@ -28,5 +28,5 @@ def load_model_runner_plugins(): plugins_loaded = True plugins = load_plugins_by_group(group=PLUGINS_GROUP) - assert len(plugins) == 1, "Only one plugin is allowed to be loaded." - return next(iter(plugins.values())) + assert len(plugins) <= 1, "Most one plugin is allowed to be loaded." + return next(iter(plugins.values()))() diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py index 11701c0e0..72fead1cd 100644 --- a/fastdeploy/rl/rollout_model.py +++ b/fastdeploy/rl/rollout_model.py @@ -56,6 +56,9 @@ class RolloutModel(nn.Layer): def _init_model(self) -> nn.Layer: """Load model from loader based on config.""" context = paddle.LazyGuard() + from fastdeploy.plugins.model_register import load_model_register_plugins + + load_model_register_plugins() architectures = f"{self.fd_config.model_config.architectures[0]}RL" with context: model_cls = ModelRegistry.get_class(architectures) diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index 084b4f0f2..242559413 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -26,13 +26,19 @@ from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request from fastdeploy.platforms import current_platform +from fastdeploy.plugins.model_runner import load_model_runner_plugins from fastdeploy.utils import get_logger -from fastdeploy.worker.gpu_model_runner import GPUModelRunner +from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.output import ModelRunnerOutput from fastdeploy.worker.worker_base import WorkerBase logger = get_logger("gpu_worker", "gpu_worker.log") +try: + ModelRunner = load_model_runner_plugins() +except: + from fastdeploy.worker.gpu_model_runner import GPUModelRunner as ModelRunner + class GpuWorker(WorkerBase): def __init__( @@ -70,7 +76,7 @@ class GpuWorker(WorkerBase): raise RuntimeError(f"Not support device type: {self.device_config.device}") # Construct model runner - self.model_runner: GPUModelRunner = GPUModelRunner( + self.model_runner: ModelRunnerBase = ModelRunner( fd_config=self.fd_config, device=self.device, device_id=self.device_ids[self.local_rank % self.max_chips_per_node], diff --git a/test/plugins/fd_add_dummy_model/__init__.py b/test/plugins/fd_add_dummy_model/__init__.py index 2f8ae6a25..1c7dba0cc 100644 --- a/test/plugins/fd_add_dummy_model/__init__.py +++ b/test/plugins/fd_add_dummy_model/__init__.py @@ -1,6 +1,21 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from paddleformers.transformers import PretrainedModel from fastdeploy import ModelRegistry +from fastdeploy.config import ErnieArchitectures from fastdeploy.model_executor.models.model_base import ModelForCasualLM @@ -31,5 +46,7 @@ class MyModelForCasualLM(ModelForCasualLM): def register(): if "MyModelForCasualLM" not in ModelRegistry.get_supported_archs(): + if MyModelForCasualLM.name().startswith("Ernie"): + ErnieArchitectures.register_ernie_model_arch(MyModelForCasualLM) ModelRegistry.register_model_class(MyModelForCasualLM) ModelRegistry.register_pretrained_model(MyPretrainedModel) diff --git a/test/plugins/fd_add_dummy_model_runner/__init__.py b/test/plugins/fd_add_dummy_model_runner/__init__.py new file mode 100644 index 000000000..b8fc023a3 --- /dev/null +++ b/test/plugins/fd_add_dummy_model_runner/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class MyModelRunner: + def __init__(self, rank=0) -> None: + super().__init__() + self.rank = rank + + def get_rank(self): + return self.rank + + +def get_runner(): + return MyModelRunner diff --git a/test/plugins/setup.py b/test/plugins/setup.py index 5a570f2b0..92c953d61 100644 --- a/test/plugins/setup.py +++ b/test/plugins/setup.py @@ -1,15 +1,27 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from setuptools import setup setup( name="fastdeploy-plugins", version="0.1", - packages=["fd_add_dummy_model"], + packages=["fd_add_dummy_model", "fd_add_dummy_model_runner"], entry_points={ "fastdeploy.model_register_plugins": [ "fd_add_dummy_model = fd_add_dummy_model:register", ], - # 'fastdeploy.model_runner_plugins': [ - # "model_runner = model_runner:get_runner" - # ] + "fastdeploy.model_runner_plugins": ["fd_add_dummy_model_runner = fd_add_dummy_model_runner:get_runner"], }, ) diff --git a/test/plugins/test_model_registry.py b/test/plugins/test_model_registry.py index 4b00afec4..f58399537 100644 --- a/test/plugins/test_model_registry.py +++ b/test/plugins/test_model_registry.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from fastdeploy import ModelRegistry diff --git a/test/plugins/test_model_runner_register.py b/test/plugins/test_model_runner_register.py new file mode 100644 index 000000000..85110ba62 --- /dev/null +++ b/test/plugins/test_model_runner_register.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from fastdeploy.plugins import load_model_runner_plugins + + +class TestModelRunnerRegistryPlugins(unittest.TestCase): + def test_model_runner_callable(self): + runner_class = load_model_runner_plugins() + device_id = 1 + + # create runner + runner = runner_class(device_id) + + # test func + res = runner.get_rank() + + self.assertEqual(res, device_id) + + +if __name__ == "__main__": + unittest.main()