diff --git a/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py b/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py index f90a5d232..1892d8edd 100644 --- a/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py +++ b/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py @@ -61,7 +61,7 @@ class RDMACommManager: f"init rdma messager {gpu_id} {rdma_port}, prefill_tp_size: {prefill_tp_size}, prefill_tp_idx: {prefill_tp_idx}" ) - def connect(self, ip, port, tp_size): + def connect(self, ip, port, tp_size=0): """ Connect to remote gpu and write cache. """ diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index dbc94bf8a..e64716fe8 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -16,7 +16,6 @@ from __future__ import annotations -import copy import time import traceback from dataclasses import asdict, dataclass, fields @@ -274,20 +273,6 @@ class Request: def to_dict(self) -> dict: """convert Request into a serializable dict""" - multimodal_inputs = copy.deepcopy(self.multimodal_inputs) - if ( - isinstance(multimodal_inputs, dict) - and isinstance(multimodal_inputs.get("mm_positions"), list) - and len(multimodal_inputs["mm_positions"]) > 0 - ): - # if mm_positions is ImagePosition, convert to dict - try: - for i, mm_pos in enumerate(multimodal_inputs["mm_positions"]): - multimodal_inputs["mm_positions"][i] = ( - asdict(mm_pos) if isinstance(mm_pos, ImagePosition) else mm_pos - ) - except Exception as e: - data_processor_logger.error(f"Convert ImagePosition to dict error: {e}, {str(traceback.format_exc())}") data = { "request_id": self.request_id, @@ -299,7 +284,6 @@ class Request: "history": self.history, "tools": self.tools, "eos_token_ids": self.eos_token_ids, - "multimodal_inputs": multimodal_inputs, "multimodal_data": self.multimodal_data, "disable_chat_template": self.disable_chat_template, "disaggregate_info": self.disaggregate_info, @@ -319,6 +303,21 @@ class Request: "audio_end": self.audio_end, "ic_req_data": self.ic_req_data, } + + # During multimodal PD separation, position_ids are required + if isinstance(self.multimodal_inputs, dict): + # Optimize multimodal data transfer during PD separation: + # - V1 mode (ENABLE_V1_KVCACHE_SCHEDULER=1): Only position_ids needed for decode nodes + # - V0 mode (ENABLE_V1_KVCACHE_SCHEDULER=0): Full field set required for compatibility + # This filtering significantly reduces serialized data size for large numpy arrays + allowed_keys = {"position_ids"} + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: + allowed_keys.update(["input_ids", "token_type_ids", "images", "image_type_ids", "grid_thw"]) + + data["multimodal_inputs"] = { + key: value for key, value in self.multimodal_inputs.items() if key in allowed_keys + } + add_params = [ "guided_json", "guided_regex", diff --git a/fastdeploy/scheduler/splitwise_scheduler.py b/fastdeploy/scheduler/splitwise_scheduler.py index 289515d1b..350fbf173 100644 --- a/fastdeploy/scheduler/splitwise_scheduler.py +++ b/fastdeploy/scheduler/splitwise_scheduler.py @@ -17,6 +17,7 @@ import copy import hashlib import math +import pickle import random import threading import time @@ -545,7 +546,7 @@ class APIScheduler: pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}" req_dict = req.to_dict() req_dict["group"] = group - req_str = orjson.dumps(req_dict) + req_str = pickle.dumps(req_dict, protocol=5) # logger.info(f"Schedule Req {req_str}") self.client.lpush(dkey, req_str) self.client.lpush(pkey, req_str) @@ -795,7 +796,7 @@ class InferScheduler: reqs = [ret[1]] for req_str in reqs: - req = orjson.loads(req_str) + req = pickle.loads(req_str) group = req.get("group", "") req = Request.from_dict(req) writer_idx = select_writer(req) diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index d82fbec84..3aafe3dbe 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -14,7 +14,7 @@ # limitations under the License. """ -import json +import pickle import time import traceback from concurrent.futures import ThreadPoolExecutor @@ -94,12 +94,12 @@ class SplitwiseConnector: if not socks: continue else: - self.logger.debug(f"start_receiver: receive {socks}") + self.logger.debug("start_receiver: receive %s", socks) frames = self.router_socket.recv_multipart() - self.logger.debug(f"start_receiver: frames: {frames}") - message = frames[-1] - self.io_executor.submit(self._process_message, message) + self.logger.debug("start_receiver: frames: %s", frames) + # message = frames[-1] + self.io_executor.submit(self._process_message, frames) time.sleep(0.001) else: time.sleep(5) @@ -147,7 +147,10 @@ class SplitwiseConnector: try: self.logger.info(f"_send_message: msg_type={msg_type} addr={addr}") sock = self._get_push_socket(addr) - sock.send_multipart([b"", message]) + sock.send_multipart(message) + + self.logger.info(f"Sent {msg_type} to {addr}") + except ConnectionError: self.logger.warning(f"_send_message: Connection to {addr} not established") except zmq.Again: @@ -383,21 +386,51 @@ class SplitwiseConnector: if msg_type == "decode" or msg_type == "prefill": payload = [output.to_dict() for output in payload] - json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8") - return json_data + # Prepare data + data = {"type": msg_type, "payload": payload} - def _deserialize_message(self, data: bytes): + # Pickle protocol 5 supports extracting large arrays (buffers) + buffers = [] + # Serialize main data, strip large arrays as references into buffers + main_bytes = pickle.dumps(data, protocol=5, buffer_callback=buffers.append) + # Serialize using pickle protocol 5 which provides efficient handling + # of large numpy arrays through out-of-band buffers. + # Returns: [main_bytes, buffer1, buffer2, ...] + # where main_bytes contains the serialized structure and buffers contain + # the actual array data extracted for efficient transmission. + return [main_bytes] + buffers - # JSON反序列化 - message = json.loads(data.decode("utf-8")) + def _deserialize_message(self, frames: List[bytes]): + """ + Deserialize message from ZMQ frames using pickle protocol 5. + + Args: + frames: List of byte frames where: + - frames[0]: Identity frame (sender address) + - frames[1]: Main pickled data structure + - frames[2:]: Out-of-band buffers (numpy arrays) + + Returns: + Tuple of (message_type: str, payload: Any) + """ + # identity = frames[0] + + if len(frames) < 2: + raise ValueError(f"Received frames too short: expected at least 2 frames but got {len(frames)}") + + main_bytes = frames[1] + buffers = frames[2:] + + # Restore data, pickle will automatically fill buffers back into numpy arrays + message = pickle.loads(main_bytes, buffers=buffers) return message["type"], message["payload"] - def _process_message(self, message: bytes): + def _process_message(self, frames: List[bytes]): """ process message """ try: - msg_type, payload = self._deserialize_message(message) + msg_type, payload = self._deserialize_message(frames) self.logger.info(f"_process_message: {msg_type}") if msg_type == "prefill": diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 2a0248894..b67ed7c5c 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -826,6 +826,20 @@ class GPUModelRunner(ModelRunnerBase): dtype="int64", ) self.seq_lens_this_time_buffer[idx : idx + 1] = num_prefill_send_token + if self.enable_mm: + # Fix for V0 mode: Add position encoding for decode nodes in multimodal models + # to prevent garbled output. Position_ids are transmitted from prefill nodes. + if ( + "position_ids" in request.multimodal_inputs + and request.multimodal_inputs["position_ids"] is not None + ): + position_ids = paddle.to_tensor( + request.multimodal_inputs["position_ids"], + dtype="int64", + ) + self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d( + position_ids, [request.get("max_tokens", 2048)], [0, position_ids.shape[0]] + )[0] else: self.share_inputs["pre_ids"][idx : idx + 1] = -1 self.share_inputs["step_idx"][idx : idx + 1] = 0 @@ -2709,7 +2723,7 @@ class GPUModelRunner(ModelRunnerBase): token_type_ids = one["token_type_ids"][np.newaxis, :] token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64) - if one["images"] is not None: + if "images" in one and one["images"] is not None: image_type_ids = one["image_type_ids"][np.newaxis, :] images = one["images"] image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64)