From b71cbb466d4fc496b11a56df16de5692a633c1aa Mon Sep 17 00:00:00 2001 From: ApplEOFDiscord <31272106+ApplEOFDiscord@users.noreply.github.com> Date: Fri, 1 Aug 2025 20:01:18 +0800 Subject: [PATCH] [Feature] remove dependency on enable_mm and refine multimodal code (#3014) * remove dependency on enable_mm * fix codestyle check error * fix codestyle check error * update docs * resolve conflicts on model config * fix unit test error * fix code style check error --------- Co-authored-by: shige <1021937542@qq.com> Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> --- docs/get_started/ernie-4.5-vl.md | 1 - docs/get_started/quick_start_vl.md | 3 +- docs/offline_inference.md | 4 +- docs/parameters.md | 2 +- docs/zh/get_started/ernie-4.5-vl.md | 1 - docs/zh/get_started/quick_start_vl.md | 3 +- docs/zh/offline_inference.md | 4 +- docs/zh/parameters.md | 2 +- fastdeploy/engine/args_utils.py | 6 +-- fastdeploy/engine/config.py | 15 ++++-- fastdeploy/entrypoints/chat_utils.py | 4 +- fastdeploy/entrypoints/engine_client.py | 16 ++++-- fastdeploy/entrypoints/llm.py | 10 ++-- fastdeploy/entrypoints/openai/api_server.py | 3 +- fastdeploy/input/preprocess.py | 4 +- .../models/ernie4_5_vl/ernie4_5_vl_moe.py | 2 + fastdeploy/{input => }/multimodal/__init__.py | 0 fastdeploy/{input => }/multimodal/audio.py | 0 fastdeploy/{input => }/multimodal/base.py | 0 fastdeploy/{input => }/multimodal/image.py | 0 fastdeploy/multimodal/registry.py | 49 +++++++++++++++++++ fastdeploy/{input => }/multimodal/utils.py | 0 fastdeploy/{input => }/multimodal/video.py | 0 fastdeploy/utils.py | 18 +++++++ 24 files changed, 118 insertions(+), 29 deletions(-) rename fastdeploy/{input => }/multimodal/__init__.py (100%) rename fastdeploy/{input => }/multimodal/audio.py (100%) rename fastdeploy/{input => }/multimodal/base.py (100%) rename fastdeploy/{input => }/multimodal/image.py (100%) create mode 100644 fastdeploy/multimodal/registry.py rename fastdeploy/{input => }/multimodal/utils.py (100%) rename fastdeploy/{input => }/multimodal/video.py (100%) diff --git a/docs/get_started/ernie-4.5-vl.md b/docs/get_started/ernie-4.5-vl.md index 71b0626ae..1092ed19f 100644 --- a/docs/get_started/ernie-4.5-vl.md +++ b/docs/get_started/ernie-4.5-vl.md @@ -31,7 +31,6 @@ python -m fastdeploy.entrypoints.openai.api_server \ --quantization wint4 \ --max-model-len 32768 \ --max-num-seqs 32 \ - --enable-mm \ --mm-processor-kwargs '{"video_max_frames": 30}' \ --limit-mm-per-prompt '{"image": 10, "video": 3}' \ --reasoning-parser ernie-45-vl diff --git a/docs/get_started/quick_start_vl.md b/docs/get_started/quick_start_vl.md index 83b1b97d7..6e0b9a780 100644 --- a/docs/get_started/quick_start_vl.md +++ b/docs/get_started/quick_start_vl.md @@ -26,8 +26,7 @@ python -m fastdeploy.entrypoints.openai.api_server \ --engine-worker-queue-port 8182 \ --max-model-len 32768 \ --max-num-seqs 32 \ - --reasoning-parser ernie-45-vl \ - --enable-mm + --reasoning-parser ernie-45-vl ``` > 💡 Note: If the path given to ```--model``` does not exist as a subdirectory of the current directory, FastDeploy checks whether AIStudio provides a preset model under the specified name (e.g. ```baidu/ERNIE-4.5-0.3B-Base-Paddle```) and, if so, downloads it automatically. The default download path is ```~/xx```. For instructions and configuration of automatic model download, see [Model Download](../supported_models.md).
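Note that `--enable-mm` has been dropped from the launch commands above: with this patch, multimodal support is inferred from the model architecture rather than passed as a flag. Below is a minimal sketch of the detection pattern, mirroring the new code in `fastdeploy/engine/config.py` and `fastdeploy/entrypoints/engine_client.py` further down in this patch (the model name is only an example):

```python
# Sketch only: how enable_mm is now derived instead of being passed as a flag.
# Importing the models package populates MultimodalRegistry as a side effect.
import fastdeploy.model_executor.models  # noqa: F401

from fastdeploy.engine.config import ModelConfig
from fastdeploy.multimodal.registry import MultimodalRegistry

arch = ModelConfig({"model": "baidu/ERNIE-4.5-VL-28B-A3B-Paddle"}).architectures[0]
enable_mm = MultimodalRegistry.contains_model(arch)  # True only for registered multimodal models
```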
diff --git a/docs/offline_inference.md b/docs/offline_inference.md index 45a77615a..2da2286b8 100644 --- a/docs/offline_inference.md +++ b/docs/offline_inference.md @@ -39,7 +39,7 @@ Documentation for `SamplingParams`, `LLM.generate`, `LLM.chat`, and output struc ```python from fastdeploy.entrypoints.llm import LLM # Load the model -llm = LLM(model="baidu/ERNIE-4.5-VL-28B-A3B-Paddle", tensor_parallel_size=1, max_model_len=32768, enable_mm=True, limit_mm_per_prompt={"image": 100}, reasoning_parser="ernie-45-vl") +llm = LLM(model="baidu/ERNIE-4.5-VL-28B-A3B-Paddle", tensor_parallel_size=1, max_model_len=32768, limit_mm_per_prompt={"image": 100}, reasoning_parser="ernie-45-vl") outputs = llm.chat( messages=[ @@ -127,7 +127,7 @@ for message in messages: }) sampling_params = SamplingParams(temperature=0.1, max_tokens=6400) -llm = LLM(model=PATH, tensor_parallel_size=1, max_model_len=32768, enable_mm=True, limit_mm_per_prompt={"image": 100}, reasoning_parser="ernie-45-vl") +llm = LLM(model=PATH, tensor_parallel_size=1, max_model_len=32768, limit_mm_per_prompt={"image": 100}, reasoning_parser="ernie-45-vl") outputs = llm.generate(prompts={ "prompt": prompt, "multimodal_data": { diff --git a/docs/parameters.md b/docs/parameters.md index c52fc9ac6..245eec83f 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -19,7 +19,7 @@ When using FastDeploy to deploy models (including offline inference and service | ```tokenizer``` | `str` | Tokenizer name or path, defaults to model path | | ```use_warmup``` | `int` | Whether to perform warmup at startup, will automatically generate maximum length data for warmup, enabled by default when automatically calculating KV Cache | | ```limit_mm_per_prompt``` | `dict[str]` | Limit the amount of multimodal data per prompt, e.g.: {"image": 10, "video": 3}, default: 1 for all | -| ```enable_mm``` | `bool` | Whether to support multimodal data (for multimodal models only), default: False | +| ```enable_mm``` | `bool` | __[DEPRECATED]__ Whether to support multimodal data (for multimodal models only), default: False | | ```quantization``` | `str` | Model quantization strategy, when loading BF16 CKPT, specifying wint4 or wint8 supports lossless online 4bit/8bit quantization | | ```gpu_memory_utilization``` | `float` | GPU memory utilization, default: 0.9 | | ```num_gpu_blocks_override``` | `int` | Preallocated KVCache blocks, this parameter can be automatically calculated by FastDeploy based on memory situation, no need for user configuration, default: None | diff --git a/docs/zh/get_started/ernie-4.5-vl.md b/docs/zh/get_started/ernie-4.5-vl.md index 3922c899f..3f12904c5 100644 --- a/docs/zh/get_started/ernie-4.5-vl.md +++ b/docs/zh/get_started/ernie-4.5-vl.md @@ -31,7 +31,6 @@ python -m fastdeploy.entrypoints.openai.api_server \ --quantization wint4 \ --max-model-len 32768 \ --max-num-seqs 32 \ - --enable-mm \ --mm-processor-kwargs '{"video_max_frames": 30}' \ --limit-mm-per-prompt '{"image": 10, "video": 3}' \ --reasoning-parser ernie-45-vl diff --git a/docs/zh/get_started/quick_start_vl.md b/docs/zh/get_started/quick_start_vl.md index 0f4c88cc1..b3b153817 100644 --- a/docs/zh/get_started/quick_start_vl.md +++ b/docs/zh/get_started/quick_start_vl.md @@ -26,8 +26,7 @@ python -m fastdeploy.entrypoints.openai.api_server \ --engine-worker-queue-port 8182 \ --max-model-len 32768 \ --max-num-seqs 32 \ - --reasoning-parser ernie-45-vl \ - --enable-mm + --reasoning-parser ernie-45-vl ``` >💡 注意:在 ```--model``` 指定的路径中,若当前目录下不存在该路径对应的子目录,则会尝试根据指定的模型名称(如
```baidu/ERNIE-4.5-0.3B-Base-Paddle```)查询AIStudio是否存在预置模型,若存在,则自动启动下载。默认的下载路径为:```~/xx```。关于模型自动下载的说明和配置参阅[模型下载](../supported_models.md)。 diff --git a/docs/zh/offline_inference.md b/docs/zh/offline_inference.md index 015fc7b72..855116484 100644 --- a/docs/zh/offline_inference.md +++ b/docs/zh/offline_inference.md @@ -39,7 +39,7 @@ for output in outputs: ```python from fastdeploy.entrypoints.llm import LLM # 加载模型 -llm = LLM(model="baidu/ERNIE-4.5-VL-28B-A3B-Paddle", tensor_parallel_size=1, max_model_len=32768, enable_mm=True, limit_mm_per_prompt={"image": 100}, reasoning_parser="ernie-45-vl") +llm = LLM(model="baidu/ERNIE-4.5-VL-28B-A3B-Paddle", tensor_parallel_size=1, max_model_len=32768, limit_mm_per_prompt={"image": 100}, reasoning_parser="ernie-45-vl") outputs = llm.chat( messages=[ @@ -127,7 +127,7 @@ for message in messages: }) sampling_params = SamplingParams(temperature=0.1, max_tokens=6400) -llm = LLM(model=PATH, tensor_parallel_size=1, max_model_len=32768, enable_mm=True, limit_mm_per_prompt={"image": 100}, reasoning_parser="ernie-45-vl") +llm = LLM(model=PATH, tensor_parallel_size=1, max_model_len=32768, limit_mm_per_prompt={"image": 100}, reasoning_parser="ernie-45-vl") outputs = llm.generate(prompts={ "prompt": prompt, "multimodal_data": { diff --git a/docs/zh/parameters.md b/docs/zh/parameters.md index fbf57a971..244d78ab7 100644 --- a/docs/zh/parameters.md +++ b/docs/zh/parameters.md @@ -17,7 +17,7 @@ | ```tokenizer``` | `str` | tokenizer 名或路径,默认为模型路径 | | ```use_warmup``` | `int` | 是否在启动时进行warmup,会自动生成极限长度数据进行warmup,默认自动计算KV Cache时会使用 | | ```limit_mm_per_prompt``` | `dict[str]` | 限制每个prompt中多模态数据的数量,如:{"image": 10, "video": 3},默认都为1 | -| ```enable_mm``` | `bool` | 是否支持多模态数据(仅针对多模模型),默认False | +| ```enable_mm``` | `bool` | __[已废弃]__ 是否支持多模态数据(仅针对多模模型),默认False | | ```quantization``` | `str` | 模型量化策略,当在加载BF16 CKPT时,指定wint4或wint8时,支持无损在线4bit/8bit量化 | | ```gpu_memory_utilization``` | `float` | GPU显存利用率,默认0.9 | | ```num_gpu_blocks_override``` | `int` | 预分配KVCache块数,此参数可由FastDeploy自动根据显存情况计算,无需用户配置,默认为None | diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 4a2414304..d8c57ae45 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -31,7 +31,7 @@ from fastdeploy.config import ( ) from fastdeploy.engine.config import Config from fastdeploy.scheduler.config import SchedulerConfig -from fastdeploy.utils import FlexibleArgumentParser +from fastdeploy.utils import DeprecatedOptionWarning, FlexibleArgumentParser def nullable_str(x: str) -> Optional[str]: @@ -409,7 +409,7 @@ class EngineArgs: ) model_group.add_argument( "--enable-mm", - action="store_true", + action=DeprecatedOptionWarning, default=EngineArgs.enable_mm, help="Flag to enable multi-modal model.", ) @@ -902,7 +902,7 @@ class EngineArgs: engine_worker_queue_port=self.engine_worker_queue_port, limit_mm_per_prompt=self.limit_mm_per_prompt, mm_processor_kwargs=self.mm_processor_kwargs, - enable_mm=self.enable_mm, + # enable_mm=self.enable_mm, reasoning_parser=self.reasoning_parser, splitwise_role=self.splitwise_role, innode_prefill_ports=self.innode_prefill_ports, diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py index dc7993500..035cea96c 100644 --- a/fastdeploy/engine/config.py +++ b/fastdeploy/engine/config.py @@ -25,6 +25,7 @@ from fastdeploy.config import ( ModelConfig, ParallelConfig, ) +from fastdeploy.multimodal.registry import MultimodalRegistry from fastdeploy.platforms import current_platform from fastdeploy.scheduler import 
SchedulerConfig from fastdeploy.utils import ceil_div, get_host_ip, is_port_available, llm_logger @@ -78,7 +79,7 @@ class Config: engine_worker_queue_port: int = 8002, limit_mm_per_prompt: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, - enable_mm: bool = False, + # enable_mm: bool = False, splitwise_role: str = "mixed", innode_prefill_ports: Optional[List[int]] = None, max_num_partial_prefills: int = 1, @@ -156,7 +157,7 @@ class Config: self.max_num_seqs = max_num_seqs self.limit_mm_per_prompt = limit_mm_per_prompt self.mm_processor_kwargs = mm_processor_kwargs - self.enable_mm = enable_mm + # self.enable_mm = enable_mm self.speculative_config = speculative_config self.use_warmup = use_warmup self.splitwise_role = splitwise_role @@ -174,11 +175,19 @@ class Config: assert self.splitwise_role in ["mixed", "prefill", "decode"] + import fastdeploy.model_executor.models # noqa: F401 + + architectures = self.model_config.architectures[0] + if MultimodalRegistry.contains_model(architectures): + self.enable_mm = True + else: + self.enable_mm = False + # TODO self.max_prefill_batch = 3 if current_platform.is_xpu(): self.max_prefill_batch = 1 - if enable_mm: + if self.enable_mm: self.max_prefill_batch = 1 # TODO: currently the multimodal prefill stage only supports a parallelism degree of 1; to be optimized # TODO(@wufeisheng): TP and EP need to be supported simultaneously. diff --git a/fastdeploy/entrypoints/chat_utils.py b/fastdeploy/entrypoints/chat_utils.py index 4f7357e11..4fe6f9db6 100644 --- a/fastdeploy/entrypoints/chat_utils.py +++ b/fastdeploy/entrypoints/chat_utils.py @@ -27,8 +27,8 @@ from openai.types.chat import ( ) from typing_extensions import Required, TypeAlias, TypedDict -from fastdeploy.input.multimodal.image import ImageMediaIO -from fastdeploy.input.multimodal.video import VideoMediaIO +from fastdeploy.multimodal.image import ImageMediaIO +from fastdeploy.multimodal.video import VideoMediaIO class VideoURL(TypedDict, total=False): diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 9be9eccb4..09d6e8ff9 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -19,9 +19,11 @@ import uuid import numpy as np +from fastdeploy.engine.config import ModelConfig from fastdeploy.input.preprocess import InputPreprocessor from fastdeploy.inter_communicator import IPCSignal, ZmqClient from fastdeploy.metrics.work_metrics import work_process_metrics +from fastdeploy.multimodal.registry import MultimodalRegistry from fastdeploy.platforms import current_platform from fastdeploy.utils import EngineError, api_server_logger @@ -33,26 +35,34 @@ class EngineClient: def __init__( self, + model_name_or_path, tokenizer, max_model_len, tensor_parallel_size, pid, limit_mm_per_prompt, mm_processor_kwargs, - enable_mm=False, + # enable_mm=False, reasoning_parser=None, data_parallel_size=1, enable_logprob=False, ): + import fastdeploy.model_executor.models # noqa: F401 + + architectures = ModelConfig({"model": model_name_or_path}).architectures[0] + if MultimodalRegistry.contains_model(architectures): + self.enable_mm = True + else: + self.enable_mm = False + input_processor = InputPreprocessor( tokenizer, reasoning_parser, limit_mm_per_prompt, mm_processor_kwargs, - enable_mm, + self.enable_mm, ) self.enable_logprob = enable_logprob - self.enable_mm = enable_mm self.reasoning_parser = reasoning_parser self.data_processor = input_processor.create_processor() self.max_model_len = max_model_len diff --git a/fastdeploy/entrypoints/llm.py
b/fastdeploy/entrypoints/llm.py index 8365c6985..66833428b 100644 --- a/fastdeploy/entrypoints/llm.py +++ b/fastdeploy/entrypoints/llm.py @@ -28,9 +28,11 @@ from tqdm import tqdm from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.engine import LLMEngine from fastdeploy.engine.sampling_params import SamplingParams - -# from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam -from fastdeploy.utils import llm_logger, retrive_model_from_server +from fastdeploy.utils import ( + deprecated_kwargs_warning, + llm_logger, + retrive_model_from_server, +) from fastdeploy.worker.output import Logprob, LogprobsLists root_logger = logging.getLogger() @@ -72,6 +74,8 @@ class LLM: enable_logprob: Optional[bool] = False, **kwargs, ): + deprecated_kwargs_warning(**kwargs) + model = retrive_model_from_server(model, revision) engine_args = EngineArgs( model=model, diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 5161d2d25..50dabb78e 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -105,13 +105,14 @@ async def lifespan(app: FastAPI): pid = os.getpid() api_server_logger.info(f"{pid}") engine_client = EngineClient( + args.model, args.tokenizer, args.max_model_len, args.tensor_parallel_size, pid, args.limit_mm_per_prompt, args.mm_processor_kwargs, - args.enable_mm, + # args.enable_mm, args.reasoning_parser, args.data_parallel_size, args.enable_logprob, diff --git a/fastdeploy/input/preprocess.py b/fastdeploy/input/preprocess.py index 8edd4eb4b..abefb0735 100644 --- a/fastdeploy/input/preprocess.py +++ b/fastdeploy/input/preprocess.py @@ -87,8 +87,8 @@ class InputPreprocessor: reasoning_parser_obj=reasoning_parser_obj, ) else: - if not architectures.startswith("Ernie4_5_VLMoeForConditionalGeneration"): - raise ValueError(f"Model {self.model_name_or_path} is not a valid Ernie4_5_VLMoe model.") + if not ErnieArchitectures.contains_ernie_arch(architectures): + raise ValueError(f"Model {self.model_name_or_path} is not a valid Ernie4_5_VL model.") else: from fastdeploy.input.ernie_vl_processor import ErnieMoEVLProcessor diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index 6016b06fd..ec6b21b5a 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -42,6 +42,7 @@ from fastdeploy.model_executor.models.ernie4_5_moe import ( Ernie4_5_MLP, ) from fastdeploy.model_executor.models.model_base import ModelForCasualLM +from fastdeploy.multimodal.registry import MultimodalRegistry from fastdeploy.platforms import current_platform if current_platform.is_cuda(): @@ -487,6 +488,7 @@ class Ernie4_5_VLModel(nn.Layer): return out +@MultimodalRegistry.register_model() class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): """ Ernie4_5_VLMoeForConditionalGeneration diff --git a/fastdeploy/input/multimodal/__init__.py b/fastdeploy/multimodal/__init__.py similarity index 100% rename from fastdeploy/input/multimodal/__init__.py rename to fastdeploy/multimodal/__init__.py diff --git a/fastdeploy/input/multimodal/audio.py b/fastdeploy/multimodal/audio.py similarity index 100% rename from fastdeploy/input/multimodal/audio.py rename to fastdeploy/multimodal/audio.py diff --git a/fastdeploy/input/multimodal/base.py b/fastdeploy/multimodal/base.py similarity index 100% rename from 
fastdeploy/input/multimodal/base.py rename to fastdeploy/multimodal/base.py diff --git a/fastdeploy/input/multimodal/image.py b/fastdeploy/multimodal/image.py similarity index 100% rename from fastdeploy/input/multimodal/image.py rename to fastdeploy/multimodal/image.py diff --git a/fastdeploy/multimodal/registry.py b/fastdeploy/multimodal/registry.py new file mode 100644 index 000000000..402e8d204 --- /dev/null +++ b/fastdeploy/multimodal/registry.py @@ -0,0 +1,49 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Callable + + +class MultimodalRegistry: + """ + A registry for multimodal models. + """ + + mm_models: set[str] = set() + + @classmethod + def register_model(cls, name: str = "") -> Callable: + """ + Register a model under the given name; if no name is given, the class name is used. + """ + + def _register(model): + nonlocal name + if len(name) == 0: + name = model.__name__ + if name in cls.mm_models: + raise ValueError(f"multimodal model {name} is already registered") + cls.mm_models.add(name) + return model + + return _register + + @classmethod + def contains_model(cls, name: str) -> bool: + """ + Check whether the given name exists in the registry. + """ + return name in cls.mm_models diff --git a/fastdeploy/input/multimodal/utils.py b/fastdeploy/multimodal/utils.py similarity index 100% rename from fastdeploy/input/multimodal/utils.py rename to fastdeploy/multimodal/utils.py diff --git a/fastdeploy/input/multimodal/video.py b/fastdeploy/multimodal/video.py similarity index 100% rename from fastdeploy/input/multimodal/video.py rename to fastdeploy/multimodal/video.py diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 9ea25000c..db2b560ca 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -596,6 +596,24 @@ def version(): return content +class DeprecatedOptionWarning(argparse.Action): + def __init__(self, option_strings, dest, **kwargs): + super().__init__(option_strings, dest, nargs=0, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + console_logger.warning(f"Deprecated option detected: {option_string}; it may be removed in a future release") + setattr(namespace, self.dest, True) + + +DEPRECATED_ARGS = ["enable_mm"] + + +def deprecated_kwargs_warning(**kwargs): + for arg in DEPRECATED_ARGS: + if arg in kwargs: + console_logger.warning(f"Deprecated argument detected: {arg}; it may be removed in a future release") + + llm_logger = get_logger("fastdeploy", "fastdeploy.log") data_processor_logger = get_logger("data_processor", "data_processor.log") scheduler_logger = get_logger("scheduler", "scheduler.log")
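For reference, a minimal usage sketch of the new registry (`MyVLModel` is a hypothetical class used purely for illustration; real models register the way `Ernie4_5_VLMoeForConditionalGeneration` does above):

```python
from fastdeploy.multimodal.registry import MultimodalRegistry

@MultimodalRegistry.register_model()  # no name argument, so the class name is used
class MyVLModel:  # hypothetical model class, for illustration only
    pass

assert MultimodalRegistry.contains_model("MyVLModel")
# Registering the same name a second time raises ValueError.
```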