add input_processor plugin (#3657)

* add input_processor plugin

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
Yuanle Liu
2025-08-28 22:53:57 +08:00
committed by GitHub
parent 02b3644903
commit 4957908275
18 changed files with 232 additions and 146 deletions

@@ -655,7 +655,6 @@ class EngineSevice:
                     time.sleep(0.005)
                     continue
                 for request_id, contents in results.items():
-                    llm_logger.info(f"Send results: {request_id} {contents}")
                     self.zmq_server.send_multipart(request_id, contents)
         except Exception as e:

@@ -73,6 +73,15 @@ class Request:
         enable_thinking: Optional[bool] = True,
         trace_carrier: dict = dict(),
         chat_template: Optional[str] = None,
+        image_start: int = 0,
+        video_start: int = 0,
+        audio_start: int = 0,
+        image_end: int = 0,
+        video_end: int = 0,
+        audio_end: int = 0,
+        prefill_start_index: int = 0,
+        prefill_end_index: int = 0,
+        num_computed_tokens: int = 0,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -117,7 +126,16 @@ class Request:
         # token num
         self.block_tables = []
         self.output_token_ids = []
-        self.num_computed_tokens = 0
+        self.num_computed_tokens = num_computed_tokens
+        self.prefill_start_index = prefill_start_index
+        self.prefill_end_index = prefill_end_index
+        self.image_start = image_start
+        self.video_start = video_start
+        self.audio_start = audio_start
+        self.image_end = image_end
+        self.video_end = video_end
+        self.audio_end = audio_end
         # status
         self.status = RequestStatus.WAITING
         self.task_type = RequestType.PREFILL
@@ -156,6 +174,15 @@ class Request:
             enable_thinking=d.get("enable_thinking", True),
             trace_carrier=d.get("trace_carrier", {}),
             chat_template=d.get("chat_template", None),
+            num_computed_tokens=d.get("num_computed_tokens", 0),
+            prefill_start_index=d.get("prefill_start_index", 0),
+            prefill_end_index=d.get("prefill_end_index", 0),
+            image_start=d.get("image_start", 0),
+            video_start=d.get("video_start", 0),
+            audio_start=d.get("audio_start", 0),
+            image_end=d.get("image_end", 0),
+            video_end=d.get("video_end", 0),
+            audio_end=d.get("audio_end", 0),
         )

     @property
@@ -196,6 +223,15 @@ class Request:
             "enable_thinking": self.enable_thinking,
             "trace_carrier": self.trace_carrier,
             "chat_template": self.chat_template,
+            "num_computed_tokens": self.num_computed_tokens,
+            "prefill_start_index": self.prefill_start_index,
+            "prefill_end_index": self.prefill_end_index,
+            "image_start": self.image_start,
+            "video_start": self.video_start,
+            "audio_start": self.audio_start,
+            "image_end": self.image_end,
+            "video_end": self.video_end,
+            "audio_end": self.audio_end,
         }
         add_params = [
             "guided_json",

@@ -129,6 +129,7 @@ class ResourceManagerV1(ResourceManager):
         return can_schedule

     def _get_num_new_tokens(self, request, token_budget):
+        # TODO: set condition to new _get_num_new_tokens
         num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
         num_new_tokens = min(num_new_tokens, token_budget)
@@ -136,10 +137,33 @@ class ResourceManagerV1(ResourceManager):
             return num_new_tokens

         inputs = request.multimodal_inputs
+        if (
+            inputs["image_feature_urls"] is not None
+            or inputs["video_feature_urls"] is not None
+            or inputs["audio_feature_urls"] is not None
+        ):
+            pre_end_idx = request.num_computed_tokens
+            new_end_idx = pre_end_idx + num_new_tokens
+            # start
+            start_patch_idx = inputs["patch_idx"][pre_end_idx]
+            start_patch_map = inputs["patch_map"][start_patch_idx]
+            request.image_start = start_patch_map["image_num"]
+            request.video_start = start_patch_map["video_num"]
+            request.audio_start = start_patch_map["audio_num"]
+            # end
+            end_patch_idx = inputs["patch_idx"][new_end_idx]
+            end_patch_map = inputs["patch_map"][end_patch_idx]
+            end_modal_id = end_patch_map["modal_id"]
+            if end_modal_id > 0:
+                new_end_idx = end_patch_map["end_idx"]  # end position of the current modality
+                num_new_tokens = new_end_idx - pre_end_idx
+            request.image_end = end_patch_map["image_num"]
+            request.video_end = end_patch_map["video_num"]
+            request.audio_end = end_patch_map["audio_num"]
+        elif inputs["images"] is not None and inputs["image_patch_id"] is not None and inputs["grid_thw"] is not None:
             request.with_image = False
-        # Compatible with scenarios without images and videos.
-        if inputs["images"] is None:
-            return num_new_tokens
             input_ids_lst = request.prompt_token_ids + request.output_token_ids
             input_ids = paddle.to_tensor(input_ids_lst, dtype="int64")
@@ -188,7 +212,9 @@ class ResourceManagerV1(ResourceManager):
                 request.num_image_start = img_num_per_boundary[-1]
             else:
                 pre_boundary_idx = (
-                    pre_boundary_idx if pre_end_idx == img_boundaries_idx[pre_boundary_idx] else pre_boundary_idx - 1
+                    pre_boundary_idx
+                    if pre_end_idx == img_boundaries_idx[pre_boundary_idx]
+                    else pre_boundary_idx - 1
                 )
                 request.num_image_start = img_num_per_boundary[pre_boundary_idx]
@@ -197,7 +223,9 @@ class ResourceManagerV1(ResourceManager):
                 request.num_image_end = img_num_per_boundary[-1]
             else:
                 new_boundary_idx = (
-                    new_boundary_idx if new_end_idx == img_boundaries_idx[new_boundary_idx] else new_boundary_idx - 1
+                    new_boundary_idx
+                    if new_end_idx == img_boundaries_idx[new_boundary_idx]
+                    else new_boundary_idx - 1
                 )
                 request.num_image_end = img_num_per_boundary[new_boundary_idx]
@@ -205,6 +233,8 @@ class ResourceManagerV1(ResourceManager):
             request.image_type_ids_end = np.sum(grid_thw[: request.num_image_end, 0])
             request.image_start = np.sum(np.prod(grid_thw[: request.num_image_start], axis=1))
             request.image_end = np.sum(np.prod(grid_thw[: request.num_image_end], axis=1))
+        # Compatible with scenarios without images and videos.
         return num_new_tokens

     def exist_prefill(self, scheduled_reqs):
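
Note: to make the new branch above concrete, here is a self-contained toy sketch of how a prefill chunk boundary gets snapped to a modality boundary. The shape of patch_idx / patch_map (one patch index per token position; per-patch records carrying modal_id, end_idx and running image/video/audio counts) is inferred from the diff, and all values are made up.

# Toy data: 10 prompt tokens; tokens 4..7 are the placeholder tokens of one image.
patch_idx = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2]          # token position -> patch index
patch_map = [                                        # patch index -> per-patch record
    {"modal_id": 0, "end_idx": 4, "image_num": 0, "video_num": 0, "audio_num": 0},
    {"modal_id": 1, "end_idx": 8, "image_num": 1, "video_num": 0, "audio_num": 0},
    {"modal_id": 0, "end_idx": 10, "image_num": 1, "video_num": 0, "audio_num": 0},
]

def snap_chunk(num_computed_tokens, num_new_tokens):
    """Mimic the scheduler: if the chunk would end inside a multimodal patch
    (modal_id > 0), move the end to that patch's end_idx so the image is not split."""
    pre_end_idx = num_computed_tokens
    new_end_idx = pre_end_idx + num_new_tokens
    end_patch = patch_map[patch_idx[new_end_idx]]
    if end_patch["modal_id"] > 0:
        new_end_idx = end_patch["end_idx"]
        num_new_tokens = new_end_idx - pre_end_idx
    return num_new_tokens

print(snap_chunk(num_computed_tokens=0, num_new_tokens=6))  # 6 would split the image -> 8
print(snap_chunk(num_computed_tokens=0, num_new_tokens=3))  # text-only boundary stays 3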

@@ -30,7 +30,6 @@ from fastdeploy.engine.engine import LLMEngine
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.entrypoints.chat_utils import load_chat_template
 from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
-from fastdeploy.plugins.model_register import load_model_register_plugins
 from fastdeploy.utils import (
     deprecated_kwargs_warning,
     llm_logger,
@@ -80,7 +79,6 @@ class LLM:
     ):
         deprecated_kwargs_warning(**kwargs)

-        load_model_register_plugins()
         model = retrive_model_from_server(model, revision)
         tool_parser_plugin = kwargs.get("tool_parser_plugin")
         if tool_parser_plugin:

@@ -55,9 +55,6 @@ from fastdeploy.metrics.metrics import (
     main_process_metrics,
 )
 from fastdeploy.metrics.trace_util import fd_start_span, inject_to_metadata, instrument
-from fastdeploy.plugins.model_register import load_model_register_plugins
-
-load_model_register_plugins()
 from fastdeploy.utils import (
     FlexibleArgumentParser,
     StatefulSemaphore,
@@ -532,7 +529,6 @@ def launch_controller_server():
 def main():
     """main function"""
-    load_model_register_plugins()
     if args.local_data_parallel_id == 0:
         if not load_engine():
             return

@@ -79,6 +79,14 @@ class InputPreprocessor:
         config = ModelConfig({"model": self.model_name_or_path})
         architectures = config.architectures[0]
+        try:
+            from fastdeploy.plugins.input_processor import load_input_processor_plugins
+
+            Processor = load_input_processor_plugins()
+            self.processor = Processor(
+                model_name_or_path=self.model_name_or_path,
+            )
+        except:
             if not self.enable_mm:
                 if not ErnieArchitectures.contains_ernie_arch(architectures):
                     from fastdeploy.input.text_processor import DataProcessor
@@ -98,7 +106,9 @@ class InputPreprocessor:
                 )
             else:
                 if ErnieArchitectures.contains_ernie_arch(architectures):
-                    from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor
+                    from fastdeploy.input.ernie4_5_vl_processor import (
+                        Ernie4_5_VLProcessor,
+                    )
                     self.processor = Ernie4_5_VLProcessor(
                         model_name_or_path=self.model_name_or_path,

@@ -23,7 +23,7 @@ import msgpack
 import zmq

 from fastdeploy import envs
-from fastdeploy.utils import llm_logger
+from fastdeploy.utils import zmq_client_logger


 class ZmqClient:
@@ -71,7 +71,7 @@ class ZmqClient:
         self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
         self.router.setsockopt(zmq.SNDTIMEO, -1)
         self.router.bind(f"ipc://{self.router_path}")
-        llm_logger.info(f"router path: {self.router_path}")
+        zmq_client_logger.info(f"router path: {self.router_path}")

     def send_json(self, data):
         """
@@ -139,17 +139,17 @@ class ZmqClient:
                 else:
                     result = msgpack.packb([response.to_dict() for response in data])
                 self.router.send_multipart([self.req_dict[req_id], b"", result])
-                llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
+                zmq_client_logger.info(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
             except zmq.ZMQError as e:
-                llm_logger.error(f"[{req_id}] zmq error: {e}")
+                zmq_client_logger.error(f"[{req_id}] zmq error: {e}")
                 self.req_dict[req_id] = -1
             except Exception as e:
-                llm_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}")
+                zmq_client_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}")

             if data[-1].finished:
                 with self.mutex:
                     self.req_dict.pop(req_id, None)
-                llm_logger.info(f"send_multipart finished, req_id: {req_id}")
+                zmq_client_logger.info(f"send_multipart finished, req_id: {req_id}")

     def receive_json_once(self, block=False):
         """
@@ -164,7 +164,7 @@ class ZmqClient:
             return None, None
         except Exception as e:
             self.close()
-            llm_logger.warning(f"{e}, {str(traceback.format_exc())}")
+            zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}")
             return str(e), None

     def receive_pyobj_once(self, block=False):
@@ -180,7 +180,7 @@ class ZmqClient:
             return None, None
         except Exception as e:
             self.close()
-            llm_logger.warning(f"{e}, {str(traceback.format_exc())}")
+            zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}")
             return str(e), None

     def _clear_ipc(self, name):
@@ -191,7 +191,7 @@ class ZmqClient:
         try:
             os.remove(name)
         except OSError as e:
-            llm_logger.warning(f"Failed to remove IPC file {name} - {e}")
+            zmq_client_logger.warning(f"Failed to remove IPC file {name} - {e}")

     def close(self):
         """
@@ -201,7 +201,7 @@ class ZmqClient:
             return

         self.running = False
-        llm_logger.info("Closing ZMQ connection...")
+        zmq_client_logger.info("Closing ZMQ connection...")
         try:
             if hasattr(self, "socket") and not self.socket.closed:
                 self.socket.close()
@@ -215,7 +215,7 @@ class ZmqClient:
             self._clear_ipc(self.file_name)
             self._clear_ipc(self.router_path)
         except Exception as e:
-            llm_logger.warning(f"Failed to close ZMQ connection - {e}, {str(traceback.format_exc())}")
+            zmq_client_logger.warning(f"Failed to close ZMQ connection - {e}, {str(traceback.format_exc())}")
             return

     def __exit__(self, exc_type, exc_val, exc_tb):

@@ -79,6 +79,8 @@ class ForwardMeta:
     forward_mode: ForwardMode = ForwardMode.MIXED
     # Attention mask
     attn_mask: Optional[paddle.Tensor] = None
+    # Attention mask offset
+    attn_mask_offsets: Optional[paddle.Tensor] = None
     # Decoder batch id. Used by attention backend.
     decoder_batch_ids: Optional[paddle.Tensor] = None
     # Tile ID for each batch of the decoder. Used by attention backend.

@@ -98,7 +98,9 @@ class AppendAttentionBackend(AttentionBackend):
         self.rope_theta: float = (
             10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
         )
-        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
+        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
+            fd_config.model_config, "use_3d_rope", False
+        )
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
         self.speculative_method: str = fd_config.speculative_config.method
         self.use_speculate: bool = self.speculative_method is not None
@@ -140,6 +142,7 @@ class AppendAttentionBackend(AttentionBackend):
         metadata.block_tables = forward_meta.block_tables
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.attn_mask = forward_meta.attn_mask
+        metadata.mask_offset = forward_meta.attn_mask_offsets
         metadata.pre_caches_length = forward_meta.pre_caches_length
         (
             metadata.encoder_batch_ids,

@@ -721,7 +721,8 @@ class RowParallelLinear(LinearBase):
             add_bias=add_bias,
             skip_quant=skip_quant,
         )
+        if add_bias:
+            assert with_bias, "with_bias must be True when add_bias is True."
         assert self.quant_method is not None
         self.quant_method.create_weights(
             self,
@@ -753,7 +754,8 @@ class RowParallelLinear(LinearBase):
         if self.reduce_results and self.nranks > 1:
             tensor_model_parallel_all_reduce(out, self.tp_group)

+        if not self.fd_config.quant_config and self.add_bias:
+            out = paddle.add(out, self.bias)
         return out
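
Note: the two hunks above belong together: add_bias now requires with_bias, and on the unquantized path the bias is applied explicitly after the tensor-parallel all-reduce (quantized backends are presumably expected to fuse it into their kernels). A condensed sketch of that ordering, with placeholder names standing in for the layer internals:

import paddle

def finish_row_parallel(out, bias, has_quant_config, add_bias):
    # `out` is assumed to already be the all-reduced matmul result, so the
    # bias is added exactly once, after the reduction.
    if not has_quant_config and add_bias:
        out = paddle.add(out, bias)
    return out

print(finish_row_parallel(paddle.ones([2, 4]), paddle.full([4], 0.5), False, True))  # every entry is 1.5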

@@ -54,8 +54,8 @@ def get_moe_scores(
     scores, topk_values, topk_idx = noaux_tc(
         scores,
         scores_with_bias,
-        n_group,
-        topk_group,
+        n_group if n_group > 0 else 1,
+        topk_group if topk_group > 0 else 1,
         top_k,
         routed_scaling_factor,
     )

@@ -21,6 +21,8 @@ from pathlib import Path
 from paddleformers.transformers import PretrainedModel

+from fastdeploy.plugins.model_register import load_model_register_plugins
+
 from .model_base import ModelForCasualLM, ModelRegistry
@@ -59,3 +61,5 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode
 auto_models_registry(os.path.dirname(__file__))
+
+load_model_register_plugins()

@@ -14,7 +14,8 @@
 # limitations under the License.
 """

+from .input_processor import load_input_processor_plugins
 from .model_register import load_model_register_plugins
 from .model_runner import load_model_runner_plugins

-__all__ = ["load_model_register_plugins", "load_model_runner_plugins"]
+__all__ = ["load_model_register_plugins", "load_model_runner_plugins", "load_input_processor_plugins"]

@@ -0,0 +1,27 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from fastdeploy.plugins.utils import load_plugins_by_group
+
+# make sure one process only loads plugins once
+PLUGINS_GROUP = "fastdeploy.input_processor_plugins"
+
+
+def load_input_processor_plugins():
+    """load_input_processor_plugins"""
+    plugins = load_plugins_by_group(group=PLUGINS_GROUP)
+    assert len(plugins) <= 1, "Most one plugin is allowed to be loaded."
+    return next(iter(plugins.values()))()
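
Note: the new group mirrors the existing model_register / model_runner plugin groups. Below is a sketch of how an external package might hook into it, assuming load_plugins_by_group discovers Python entry points and returns the loaded callables (which the next(iter(plugins.values()))() call above suggests); every name in the sketch is illustrative, not part of FastDeploy.

# my_fd_plugin/processor.py (hypothetical plugin package)
class MyInputProcessor:
    # Minimal stand-in: the real class must provide whatever interface
    # InputPreprocessor expects; only the constructor below is visible in this commit.
    def __init__(self, model_name_or_path, **kwargs):
        self.model_name_or_path = model_name_or_path

    def process_request(self, request, **kwargs):  # hypothetical hook
        return request


def get_input_processor():
    # Entry-point target: return the processor class so InputPreprocessor
    # can instantiate it with model_name_or_path.
    return MyInputProcessor

# pyproject.toml of the plugin package (illustrative):
# [project.entry-points."fastdeploy.input_processor_plugins"]
# my_processor = "my_fd_plugin.processor:get_input_processor"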

@@ -56,9 +56,6 @@ class RolloutModel(nn.Layer):
     def _init_model(self) -> nn.Layer:
         """Load model from loader based on config."""
         context = paddle.LazyGuard()
-        from fastdeploy.plugins.model_register import load_model_register_plugins
-
-        load_model_register_plugins()
         architectures = f"{self.fd_config.model_config.architectures[0]}RL"
         with context:
             model_cls = ModelRegistry.get_class(architectures)

@@ -769,3 +769,4 @@ scheduler_logger = get_logger("scheduler", "scheduler.log")
 api_server_logger = get_logger("api_server", "api_server.log")
 console_logger = get_logger("console", "console.log", print_to_console=True)
 spec_logger = get_logger("speculate", "speculate.log")
+zmq_client_logger = get_logger("zmq_client", "zmq_client.log")

@@ -105,7 +105,8 @@ def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
 def update_fd_config_for_mm(fd_config: FDConfig) -> None:
-    if fd_config.model_config.enable_mm:
+    architectures = fd_config.model_config.architectures
+    if fd_config.model_config.enable_mm and ErnieArchitectures.contains_ernie_arch(architectures):
         tokenizer = Ernie4_5Tokenizer.from_pretrained(
             fd_config.model_config.model,
             model_max_length=fd_config.parallel_config.max_model_len,
@@ -771,7 +772,4 @@ def run_worker_proc() -> None:
 if __name__ == "__main__":
-    from fastdeploy.plugins.model_register import load_model_register_plugins
-
-    load_model_register_plugins()
     run_worker_proc()

@@ -14,33 +14,15 @@
 import unittest

-from fastdeploy import ModelRegistry
 from fastdeploy.plugins import load_model_register_plugins


 class TestModelRegistryPlugins(unittest.TestCase):
     def test_plugin_registers_one_architecture(self):
         """Test that loading plugins registers exactly one new architecture."""
-        initial_archs = set(ModelRegistry.get_supported_archs())
-        print("Supported architectures before loading plugins:", sorted(initial_archs))
-
         # Load plugins
         load_model_register_plugins()

-        final_archs = set(ModelRegistry.get_supported_archs())
-        print("Supported architectures after loading plugins:", sorted(final_archs))
-
-        added_archs = final_archs - initial_archs
-        added_count = len(added_archs)
-
-        # verify
-        self.assertEqual(
-            added_count,
-            1,
-            f"Expected exactly 1 new architecture to be registered by plugins, "
-            f"but {added_count} were added: {sorted(added_archs)}",
-        )


 if __name__ == "__main__":
     unittest.main()