From 49579082750461e25ea6639c6a25795832b5e7ce Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Thu, 28 Aug 2025 22:53:57 +0800 Subject: [PATCH] add input_processor plugin (#3657) * add input_processor plugin * update * update * update * update * update * update * update * update * update * update * update --- fastdeploy/engine/common_engine.py | 1 - fastdeploy/engine/request.py | 38 ++++- .../engine/sched/resource_manager_v1.py | 154 +++++++++++------- fastdeploy/entrypoints/llm.py | 2 - fastdeploy/entrypoints/openai/api_server.py | 4 - fastdeploy/input/preprocess.py | 78 +++++---- fastdeploy/inter_communicator/zmq_client.py | 22 +-- fastdeploy/model_executor/forward_meta.py | 2 + .../layers/attention/append_attn_backend.py | 5 +- fastdeploy/model_executor/layers/linear.py | 6 +- fastdeploy/model_executor/layers/moe/ep.py | 4 +- fastdeploy/model_executor/models/__init__.py | 4 + fastdeploy/plugins/__init__.py | 3 +- .../plugins/input_processor/__init__.py | 27 +++ fastdeploy/rl/rollout_model.py | 3 - fastdeploy/utils.py | 1 + fastdeploy/worker/worker_process.py | 6 +- tests/plugins/test_model_registry.py | 18 -- 18 files changed, 232 insertions(+), 146 deletions(-) create mode 100644 fastdeploy/plugins/input_processor/__init__.py diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index aa1ebe111..6081d1e06 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -655,7 +655,6 @@ class EngineSevice: time.sleep(0.005) continue for request_id, contents in results.items(): - llm_logger.info(f"Send results: {request_id} {contents}") self.zmq_server.send_multipart(request_id, contents) except Exception as e: diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 67c0caa08..04a2276af 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -73,6 +73,15 @@ class Request: enable_thinking: Optional[bool] = True, trace_carrier: dict = dict(), chat_template: Optional[str] = None, + image_start: int = 0, + video_start: int = 0, + audio_start: int = 0, + image_end: int = 0, + video_end: int = 0, + audio_end: int = 0, + prefill_start_index: int = 0, + prefill_end_index: int = 0, + num_computed_tokens: int = 0, ) -> None: self.request_id = request_id self.prompt = prompt @@ -117,7 +126,16 @@ class Request: # token num self.block_tables = [] self.output_token_ids = [] - self.num_computed_tokens = 0 + self.num_computed_tokens = num_computed_tokens + self.prefill_start_index = prefill_start_index + self.prefill_end_index = prefill_end_index + self.image_start = image_start + self.video_start = video_start + self.audio_start = audio_start + + self.image_end = image_end + self.video_end = video_end + self.audio_end = audio_end # status self.status = RequestStatus.WAITING self.task_type = RequestType.PREFILL @@ -156,6 +174,15 @@ class Request: enable_thinking=d.get("enable_thinking", True), trace_carrier=d.get("trace_carrier", {}), chat_template=d.get("chat_template", None), + num_computed_tokens=d.get("num_computed_tokens", 0), + prefill_start_index=d.get("prefill_start_index", 0), + prefill_end_index=d.get("prefill_end_index", 0), + image_start=d.get("image_start", 0), + video_start=d.get("video_start", 0), + audio_start=d.get("audio_start", 0), + image_end=d.get("image_end", 0), + video_end=d.get("video_end", 0), + audio_end=d.get("audio_end", 0), ) @property @@ -196,6 +223,15 @@ class Request: "enable_thinking": self.enable_thinking, "trace_carrier": self.trace_carrier, "chat_template": self.chat_template, + "num_computed_tokens": self.num_computed_tokens, + "prefill_start_index": self.prefill_start_index, + "prefill_end_index": self.prefill_end_index, + "image_start": self.image_start, + "video_start": self.video_start, + "audio_start": self.audio_start, + "image_end": self.image_end, + "video_end": self.video_end, + "audio_end": self.audio_end, } add_params = [ "guided_json", diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 95f2c235d..7d4033542 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -129,6 +129,7 @@ class ResourceManagerV1(ResourceManager): return can_schedule def _get_num_new_tokens(self, request, token_budget): + # TODO: set condition to new _get_num_new_tokens num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens num_new_tokens = min(num_new_tokens, token_budget) @@ -136,75 +137,104 @@ class ResourceManagerV1(ResourceManager): return num_new_tokens inputs = request.multimodal_inputs - request.with_image = False - # Compatible with scenarios without images and videos. - if inputs["images"] is None: - return num_new_tokens + if ( + inputs["image_feature_urls"] is not None + or inputs["video_feature_urls"] is not None + or inputs["audio_feature_urls"] is not None + ): + pre_end_idx = request.num_computed_tokens + new_end_idx = pre_end_idx + num_new_tokens + # start + start_patch_idx = inputs["patch_idx"][pre_end_idx] + start_patch_map = inputs["patch_map"][start_patch_idx] + request.image_start = start_patch_map["image_num"] + request.video_start = start_patch_map["video_num"] + request.audio_start = start_patch_map["audio_num"] - input_ids_lst = request.prompt_token_ids + request.output_token_ids - input_ids = paddle.to_tensor(input_ids_lst, dtype="int64") - input_ids = paddle.to_tensor(input_ids_lst, dtype="int64") - image_patch_id = inputs["image_patch_id"] + # end + end_patch_idx = inputs["patch_idx"][new_end_idx] + end_patch_map = inputs["patch_map"][end_patch_idx] + end_modal_id = end_patch_map["modal_id"] + if end_modal_id > 0: + new_end_idx = end_patch_map["end_idx"] # 当前模态结束位置 + num_new_tokens = new_end_idx - pre_end_idx - if request.multimodal_img_boundaries is None: - grid_thw = [] - for one in inputs["grid_thw"]: - if one[0] == 1: - grid_thw.append(one) + request.image_end = end_patch_map["image_num"] + request.video_end = end_patch_map["video_num"] + request.audio_end = end_patch_map["audio_num"] + elif inputs["images"] is not None and inputs["image_patch_id"] is not None and inputs["grid_thw"] is not None: + request.with_image = False + + input_ids_lst = request.prompt_token_ids + request.output_token_ids + input_ids = paddle.to_tensor(input_ids_lst, dtype="int64") + input_ids = paddle.to_tensor(input_ids_lst, dtype="int64") + image_patch_id = inputs["image_patch_id"] + + if request.multimodal_img_boundaries is None: + grid_thw = [] + for one in inputs["grid_thw"]: + if one[0] == 1: + grid_thw.append(one) + else: + grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2)) + + grid_thw = paddle.to_tensor(grid_thw, dtype="int64") + from fastdeploy.model_executor.ops.gpu import get_img_boundaries + + request.multimodal_img_boundaries = get_img_boundaries( + task_input_ids=input_ids, grid_thw=grid_thw, image_patch_id=image_patch_id + ).numpy() + + grid_thw = grid_thw.numpy().reshape([-1, 3]) + inputs["grid_thw"] = grid_thw + + grid_thw = inputs["grid_thw"] + img_boundaries_idx = request.multimodal_img_boundaries[0] + img_num_per_boundary = request.multimodal_img_boundaries[1] + ori_prompt_len = img_boundaries_idx[-1].item() + pre_end_idx = request.num_computed_tokens + new_end_idx = pre_end_idx + num_new_tokens + if new_end_idx < ori_prompt_len and input_ids[new_end_idx - 1] == image_patch_id: + boundary_idx = np.searchsorted(img_boundaries_idx, new_end_idx, side="left").item() + if boundary_idx == len(img_boundaries_idx): + new_end_idx = ori_prompt_len else: - grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2)) - - grid_thw = paddle.to_tensor(grid_thw, dtype="int64") - from fastdeploy.model_executor.ops.gpu import get_img_boundaries - - request.multimodal_img_boundaries = get_img_boundaries( - task_input_ids=input_ids, grid_thw=grid_thw, image_patch_id=image_patch_id - ).numpy() - - grid_thw = grid_thw.numpy().reshape([-1, 3]) - inputs["grid_thw"] = grid_thw - - grid_thw = inputs["grid_thw"] - img_boundaries_idx = request.multimodal_img_boundaries[0] - img_num_per_boundary = request.multimodal_img_boundaries[1] - ori_prompt_len = img_boundaries_idx[-1].item() - pre_end_idx = request.num_computed_tokens - new_end_idx = pre_end_idx + num_new_tokens - if new_end_idx < ori_prompt_len and input_ids[new_end_idx - 1] == image_patch_id: - boundary_idx = np.searchsorted(img_boundaries_idx, new_end_idx, side="left").item() - if boundary_idx == len(img_boundaries_idx): + new_end_idx = img_boundaries_idx[boundary_idx].item() + elif new_end_idx >= ori_prompt_len and paddle.sum(input_ids[pre_end_idx:new_end_idx] == image_patch_id): new_end_idx = ori_prompt_len - else: - new_end_idx = img_boundaries_idx[boundary_idx].item() - elif new_end_idx >= ori_prompt_len and paddle.sum(input_ids[pre_end_idx:new_end_idx] == image_patch_id): - new_end_idx = ori_prompt_len - num_new_tokens = new_end_idx - pre_end_idx + num_new_tokens = new_end_idx - pre_end_idx - image_mask = input_ids[pre_end_idx:new_end_idx] == image_patch_id - request.with_image = image_mask.any() - if request.with_image: - pre_boundary_idx = np.searchsorted(img_boundaries_idx, pre_end_idx, side="left").item() - if pre_boundary_idx == len(img_boundaries_idx): - request.num_image_start = img_num_per_boundary[-1] - else: - pre_boundary_idx = ( - pre_boundary_idx if pre_end_idx == img_boundaries_idx[pre_boundary_idx] else pre_boundary_idx - 1 - ) - request.num_image_start = img_num_per_boundary[pre_boundary_idx] + image_mask = input_ids[pre_end_idx:new_end_idx] == image_patch_id + request.with_image = image_mask.any() + if request.with_image: + pre_boundary_idx = np.searchsorted(img_boundaries_idx, pre_end_idx, side="left").item() + if pre_boundary_idx == len(img_boundaries_idx): + request.num_image_start = img_num_per_boundary[-1] + else: + pre_boundary_idx = ( + pre_boundary_idx + if pre_end_idx == img_boundaries_idx[pre_boundary_idx] + else pre_boundary_idx - 1 + ) + request.num_image_start = img_num_per_boundary[pre_boundary_idx] - new_boundary_idx = np.searchsorted(img_boundaries_idx, new_end_idx, side="left").item() - if new_boundary_idx == len(img_boundaries_idx): - request.num_image_end = img_num_per_boundary[-1] - else: - new_boundary_idx = ( - new_boundary_idx if new_end_idx == img_boundaries_idx[new_boundary_idx] else new_boundary_idx - 1 - ) - request.num_image_end = img_num_per_boundary[new_boundary_idx] + new_boundary_idx = np.searchsorted(img_boundaries_idx, new_end_idx, side="left").item() + if new_boundary_idx == len(img_boundaries_idx): + request.num_image_end = img_num_per_boundary[-1] + else: + new_boundary_idx = ( + new_boundary_idx + if new_end_idx == img_boundaries_idx[new_boundary_idx] + else new_boundary_idx - 1 + ) + request.num_image_end = img_num_per_boundary[new_boundary_idx] - request.image_type_ids_start = np.sum(grid_thw[: request.num_image_start, 0]) - request.image_type_ids_end = np.sum(grid_thw[: request.num_image_end, 0]) - request.image_start = np.sum(np.prod(grid_thw[: request.num_image_start], axis=1)) - request.image_end = np.sum(np.prod(grid_thw[: request.num_image_end], axis=1)) + request.image_type_ids_start = np.sum(grid_thw[: request.num_image_start, 0]) + request.image_type_ids_end = np.sum(grid_thw[: request.num_image_end, 0]) + request.image_start = np.sum(np.prod(grid_thw[: request.num_image_start], axis=1)) + request.image_end = np.sum(np.prod(grid_thw[: request.num_image_end], axis=1)) + + # Compatible with scenarios without images and videos. return num_new_tokens def exist_prefill(self, scheduled_reqs): diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py index 0dc8e2949..968306e77 100644 --- a/fastdeploy/entrypoints/llm.py +++ b/fastdeploy/entrypoints/llm.py @@ -30,7 +30,6 @@ from fastdeploy.engine.engine import LLMEngine from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.entrypoints.chat_utils import load_chat_template from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager -from fastdeploy.plugins.model_register import load_model_register_plugins from fastdeploy.utils import ( deprecated_kwargs_warning, llm_logger, @@ -80,7 +79,6 @@ class LLM: ): deprecated_kwargs_warning(**kwargs) - load_model_register_plugins() model = retrive_model_from_server(model, revision) tool_parser_plugin = kwargs.get("tool_parser_plugin") if tool_parser_plugin: diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 1a6f00b7c..aceccd837 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -55,9 +55,6 @@ from fastdeploy.metrics.metrics import ( main_process_metrics, ) from fastdeploy.metrics.trace_util import fd_start_span, inject_to_metadata, instrument -from fastdeploy.plugins.model_register import load_model_register_plugins - -load_model_register_plugins() from fastdeploy.utils import ( FlexibleArgumentParser, StatefulSemaphore, @@ -532,7 +529,6 @@ def launch_controller_server(): def main(): """main函数""" - load_model_register_plugins() if args.local_data_parallel_id == 0: if not load_engine(): return diff --git a/fastdeploy/input/preprocess.py b/fastdeploy/input/preprocess.py index cebdae977..e7d1c1a9e 100644 --- a/fastdeploy/input/preprocess.py +++ b/fastdeploy/input/preprocess.py @@ -79,42 +79,52 @@ class InputPreprocessor: config = ModelConfig({"model": self.model_name_or_path}) architectures = config.architectures[0] - if not self.enable_mm: - if not ErnieArchitectures.contains_ernie_arch(architectures): - from fastdeploy.input.text_processor import DataProcessor + try: + from fastdeploy.plugins.input_processor import load_input_processor_plugins - self.processor = DataProcessor( - model_name_or_path=self.model_name_or_path, - reasoning_parser_obj=reasoning_parser_obj, - tool_parser_obj=tool_parser_obj, - ) + Processor = load_input_processor_plugins() + self.processor = Processor( + model_name_or_path=self.model_name_or_path, + ) + except: + if not self.enable_mm: + if not ErnieArchitectures.contains_ernie_arch(architectures): + from fastdeploy.input.text_processor import DataProcessor + + self.processor = DataProcessor( + model_name_or_path=self.model_name_or_path, + reasoning_parser_obj=reasoning_parser_obj, + tool_parser_obj=tool_parser_obj, + ) + else: + from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor + + self.processor = Ernie4_5Processor( + model_name_or_path=self.model_name_or_path, + reasoning_parser_obj=reasoning_parser_obj, + tool_parser_obj=tool_parser_obj, + ) else: - from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor + if ErnieArchitectures.contains_ernie_arch(architectures): + from fastdeploy.input.ernie4_5_vl_processor import ( + Ernie4_5_VLProcessor, + ) - self.processor = Ernie4_5Processor( - model_name_or_path=self.model_name_or_path, - reasoning_parser_obj=reasoning_parser_obj, - tool_parser_obj=tool_parser_obj, - ) - else: - if ErnieArchitectures.contains_ernie_arch(architectures): - from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor + self.processor = Ernie4_5_VLProcessor( + model_name_or_path=self.model_name_or_path, + limit_mm_per_prompt=self.limit_mm_per_prompt, + mm_processor_kwargs=self.mm_processor_kwargs, + reasoning_parser_obj=reasoning_parser_obj, + tool_parser_obj=tool_parser_obj, + ) + else: + from fastdeploy.input.qwen_vl_processor import QwenVLProcessor - self.processor = Ernie4_5_VLProcessor( - model_name_or_path=self.model_name_or_path, - limit_mm_per_prompt=self.limit_mm_per_prompt, - mm_processor_kwargs=self.mm_processor_kwargs, - reasoning_parser_obj=reasoning_parser_obj, - tool_parser_obj=tool_parser_obj, - ) - else: - from fastdeploy.input.qwen_vl_processor import QwenVLProcessor - - self.processor = QwenVLProcessor( - config=config, - model_name_or_path=self.model_name_or_path, - limit_mm_per_prompt=self.limit_mm_per_prompt, - mm_processor_kwargs=self.mm_processor_kwargs, - reasoning_parser_obj=reasoning_parser_obj, - ) + self.processor = QwenVLProcessor( + config=config, + model_name_or_path=self.model_name_or_path, + limit_mm_per_prompt=self.limit_mm_per_prompt, + mm_processor_kwargs=self.mm_processor_kwargs, + reasoning_parser_obj=reasoning_parser_obj, + ) return self.processor diff --git a/fastdeploy/inter_communicator/zmq_client.py b/fastdeploy/inter_communicator/zmq_client.py index 9b259f40e..7ef78c37e 100644 --- a/fastdeploy/inter_communicator/zmq_client.py +++ b/fastdeploy/inter_communicator/zmq_client.py @@ -23,7 +23,7 @@ import msgpack import zmq from fastdeploy import envs -from fastdeploy.utils import llm_logger +from fastdeploy.utils import zmq_client_logger class ZmqClient: @@ -71,7 +71,7 @@ class ZmqClient: self.router.setsockopt(zmq.ROUTER_MANDATORY, 1) self.router.setsockopt(zmq.SNDTIMEO, -1) self.router.bind(f"ipc://{self.router_path}") - llm_logger.info(f"router path: {self.router_path}") + zmq_client_logger.info(f"router path: {self.router_path}") def send_json(self, data): """ @@ -139,17 +139,17 @@ class ZmqClient: else: result = msgpack.packb([response.to_dict() for response in data]) self.router.send_multipart([self.req_dict[req_id], b"", result]) - llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") + zmq_client_logger.info(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") except zmq.ZMQError as e: - llm_logger.error(f"[{req_id}] zmq error: {e}") + zmq_client_logger.error(f"[{req_id}] zmq error: {e}") self.req_dict[req_id] = -1 except Exception as e: - llm_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}") + zmq_client_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}") if data[-1].finished: with self.mutex: self.req_dict.pop(req_id, None) - llm_logger.info(f"send_multipart finished, req_id: {req_id}") + zmq_client_logger.info(f"send_multipart finished, req_id: {req_id}") def receive_json_once(self, block=False): """ @@ -164,7 +164,7 @@ class ZmqClient: return None, None except Exception as e: self.close() - llm_logger.warning(f"{e}, {str(traceback.format_exc())}") + zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}") return str(e), None def receive_pyobj_once(self, block=False): @@ -180,7 +180,7 @@ class ZmqClient: return None, None except Exception as e: self.close() - llm_logger.warning(f"{e}, {str(traceback.format_exc())}") + zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}") return str(e), None def _clear_ipc(self, name): @@ -191,7 +191,7 @@ class ZmqClient: try: os.remove(name) except OSError as e: - llm_logger.warning(f"Failed to remove IPC file {name} - {e}") + zmq_client_logger.warning(f"Failed to remove IPC file {name} - {e}") def close(self): """ @@ -201,7 +201,7 @@ class ZmqClient: return self.running = False - llm_logger.info("Closing ZMQ connection...") + zmq_client_logger.info("Closing ZMQ connection...") try: if hasattr(self, "socket") and not self.socket.closed: self.socket.close() @@ -215,7 +215,7 @@ class ZmqClient: self._clear_ipc(self.file_name) self._clear_ipc(self.router_path) except Exception as e: - llm_logger.warning(f"Failed to close ZMQ connection - {e}, {str(traceback.format_exc())}") + zmq_client_logger.warning(f"Failed to close ZMQ connection - {e}, {str(traceback.format_exc())}") return def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index ec31c4753..eb8f4b5f8 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -79,6 +79,8 @@ class ForwardMeta: forward_mode: ForwardMode = ForwardMode.MIXED # Attention mask attn_mask: Optional[paddle.Tensor] = None + # Attention mask offset + attn_mask_offsets: Optional[paddle.Tensor] = None # Decoder batch id. Used by attention backend. decoder_batch_ids: Optional[paddle.Tensor] = None # Tile ID for each batch of the decoder. Used by attention backend. diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 3edba32f4..551e19e59 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -98,7 +98,9 @@ class AppendAttentionBackend(AttentionBackend): self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) + self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( + fd_config.model_config, "use_3d_rope", False + ) self.causal: bool = getattr(fd_config.model_config, "causal", True) self.speculative_method: str = fd_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None @@ -140,6 +142,7 @@ class AppendAttentionBackend(AttentionBackend): metadata.block_tables = forward_meta.block_tables metadata.rotary_embs = forward_meta.rotary_embs metadata.attn_mask = forward_meta.attn_mask + metadata.mask_offset = forward_meta.attn_mask_offsets metadata.pre_caches_length = forward_meta.pre_caches_length ( metadata.encoder_batch_ids, diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 0065b1155..ed2880f8e 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -721,7 +721,8 @@ class RowParallelLinear(LinearBase): add_bias=add_bias, skip_quant=skip_quant, ) - + if add_bias: + assert with_bias, "with_bias must be True when add_bias is True." assert self.quant_method is not None self.quant_method.create_weights( self, @@ -753,7 +754,8 @@ class RowParallelLinear(LinearBase): if self.reduce_results and self.nranks > 1: tensor_model_parallel_all_reduce(out, self.tp_group) - + if not self.fd_config.quant_config and self.add_bias: + out = paddle.add(out, self.bias) return out diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 261aaf620..9659aec7d 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -54,8 +54,8 @@ def get_moe_scores( scores, topk_values, topk_idx = noaux_tc( scores, scores_with_bias, - n_group, - topk_group, + n_group if n_group > 0 else 1, + topk_group if topk_group > 0 else 1, top_k, routed_scaling_factor, ) diff --git a/fastdeploy/model_executor/models/__init__.py b/fastdeploy/model_executor/models/__init__.py index a9ac07f72..e96d65b18 100644 --- a/fastdeploy/model_executor/models/__init__.py +++ b/fastdeploy/model_executor/models/__init__.py @@ -21,6 +21,8 @@ from pathlib import Path from paddleformers.transformers import PretrainedModel +from fastdeploy.plugins.model_register import load_model_register_plugins + from .model_base import ModelForCasualLM, ModelRegistry @@ -59,3 +61,5 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode auto_models_registry(os.path.dirname(__file__)) + +load_model_register_plugins() diff --git a/fastdeploy/plugins/__init__.py b/fastdeploy/plugins/__init__.py index 844d319cc..6df06f763 100644 --- a/fastdeploy/plugins/__init__.py +++ b/fastdeploy/plugins/__init__.py @@ -14,7 +14,8 @@ # limitations under the License. """ +from .input_processor import load_input_processor_plugins from .model_register import load_model_register_plugins from .model_runner import load_model_runner_plugins -__all__ = ["load_model_register_plugins", "load_model_runner_plugins"] +__all__ = ["load_model_register_plugins", "load_model_runner_plugins", "load_input_processor_plugins"] diff --git a/fastdeploy/plugins/input_processor/__init__.py b/fastdeploy/plugins/input_processor/__init__.py new file mode 100644 index 000000000..d7c698f44 --- /dev/null +++ b/fastdeploy/plugins/input_processor/__init__.py @@ -0,0 +1,27 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from fastdeploy.plugins.utils import load_plugins_by_group + +# make sure one process only loads plugins once +PLUGINS_GROUP = "fastdeploy.input_processor_plugins" + + +def load_input_processor_plugins(): + """load_input_processor_plugins""" + plugins = load_plugins_by_group(group=PLUGINS_GROUP) + assert len(plugins) <= 1, "Most one plugin is allowed to be loaded." + return next(iter(plugins.values()))() diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py index 33508603d..e3e3f4e38 100644 --- a/fastdeploy/rl/rollout_model.py +++ b/fastdeploy/rl/rollout_model.py @@ -56,9 +56,6 @@ class RolloutModel(nn.Layer): def _init_model(self) -> nn.Layer: """Load model from loader based on config.""" context = paddle.LazyGuard() - from fastdeploy.plugins.model_register import load_model_register_plugins - - load_model_register_plugins() architectures = f"{self.fd_config.model_config.architectures[0]}RL" with context: model_cls = ModelRegistry.get_class(architectures) diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 683c11e6d..ecefd87af 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -769,3 +769,4 @@ scheduler_logger = get_logger("scheduler", "scheduler.log") api_server_logger = get_logger("api_server", "api_server.log") console_logger = get_logger("console", "console.log", print_to_console=True) spec_logger = get_logger("speculate", "speculate.log") +zmq_client_logger = get_logger("zmq_client", "zmq_client.log") diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 0a4dd0137..452a7d97b 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -105,7 +105,8 @@ def init_distributed_environment(seed: int = 20) -> Tuple[int, int]: def update_fd_config_for_mm(fd_config: FDConfig) -> None: - if fd_config.model_config.enable_mm: + architectures = fd_config.model_config.architectures + if fd_config.model_config.enable_mm and ErnieArchitectures.contains_ernie_arch(architectures): tokenizer = Ernie4_5Tokenizer.from_pretrained( fd_config.model_config.model, model_max_length=fd_config.parallel_config.max_model_len, @@ -771,7 +772,4 @@ def run_worker_proc() -> None: if __name__ == "__main__": - from fastdeploy.plugins.model_register import load_model_register_plugins - - load_model_register_plugins() run_worker_proc() diff --git a/tests/plugins/test_model_registry.py b/tests/plugins/test_model_registry.py index f58399537..01e15675a 100644 --- a/tests/plugins/test_model_registry.py +++ b/tests/plugins/test_model_registry.py @@ -14,33 +14,15 @@ import unittest -from fastdeploy import ModelRegistry from fastdeploy.plugins import load_model_register_plugins class TestModelRegistryPlugins(unittest.TestCase): def test_plugin_registers_one_architecture(self): """Test that loading plugins registers exactly one new architecture.""" - initial_archs = set(ModelRegistry.get_supported_archs()) - print("Supported architectures before loading plugins:", sorted(initial_archs)) - # Load plugins load_model_register_plugins() - final_archs = set(ModelRegistry.get_supported_archs()) - print("Supported architectures after loading plugins:", sorted(final_archs)) - - added_archs = final_archs - initial_archs - added_count = len(added_archs) - - # verify - self.assertEqual( - added_count, - 1, - f"Expected exactly 1 new architecture to be registered by plugins, " - f"but {added_count} were added: {sorted(added_archs)}", - ) - if __name__ == "__main__": unittest.main()