[Feature] Multimodal Model P / D Separation (#5323)

* RouterArgs port str -> int

* fix race condition [is_fetching] causing multiple fetch requests

* bugfix: Delete duplicate input_ids tensor creation

* mm pd splitwise json -> pickle5; multimodal_inputs only pos id;
debuglog f to %s

* fix ENABLE_V1_KVCACHE_SCHEDULER=0 mm model lack pos_id, ...

* update cr

* Apply suggestions from code review

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* pre-commit fix

* rm multimodal_inputs deepcopy & fix rdma_cache_transfer.py tpsize=0

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Daci
2025-12-09 10:47:42 +08:00
committed by GitHub
parent a8ffc22032
commit 2f208db4e9
5 changed files with 80 additions and 33 deletions

View File

@@ -61,7 +61,7 @@ class RDMACommManager:
f"init rdma messager {gpu_id} {rdma_port}, prefill_tp_size: {prefill_tp_size}, prefill_tp_idx: {prefill_tp_idx}"
)
def connect(self, ip, port, tp_size):
def connect(self, ip, port, tp_size=0):
"""
Connect to remote gpu and write cache.
"""

View File

@@ -16,7 +16,6 @@
from __future__ import annotations
import copy
import time
import traceback
from dataclasses import asdict, dataclass, fields
@@ -274,20 +273,6 @@ class Request:
def to_dict(self) -> dict:
"""convert Request into a serializable dict"""
multimodal_inputs = copy.deepcopy(self.multimodal_inputs)
if (
isinstance(multimodal_inputs, dict)
and isinstance(multimodal_inputs.get("mm_positions"), list)
and len(multimodal_inputs["mm_positions"]) > 0
):
# if mm_positions is ImagePosition, convert to dict
try:
for i, mm_pos in enumerate(multimodal_inputs["mm_positions"]):
multimodal_inputs["mm_positions"][i] = (
asdict(mm_pos) if isinstance(mm_pos, ImagePosition) else mm_pos
)
except Exception as e:
data_processor_logger.error(f"Convert ImagePosition to dict error: {e}, {str(traceback.format_exc())}")
data = {
"request_id": self.request_id,
@@ -299,7 +284,6 @@ class Request:
"history": self.history,
"tools": self.tools,
"eos_token_ids": self.eos_token_ids,
"multimodal_inputs": multimodal_inputs,
"multimodal_data": self.multimodal_data,
"disable_chat_template": self.disable_chat_template,
"disaggregate_info": self.disaggregate_info,
@@ -319,6 +303,21 @@ class Request:
"audio_end": self.audio_end,
"ic_req_data": self.ic_req_data,
}
# During multimodal PD separation, position_ids are required
if isinstance(self.multimodal_inputs, dict):
# Optimize multimodal data transfer during PD separation:
# - V1 mode (ENABLE_V1_KVCACHE_SCHEDULER=1): Only position_ids needed for decode nodes
# - V0 mode (ENABLE_V1_KVCACHE_SCHEDULER=0): Full field set required for compatibility
# This filtering significantly reduces serialized data size for large numpy arrays
allowed_keys = {"position_ids"}
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
allowed_keys.update(["input_ids", "token_type_ids", "images", "image_type_ids", "grid_thw"])
data["multimodal_inputs"] = {
key: value for key, value in self.multimodal_inputs.items() if key in allowed_keys
}
add_params = [
"guided_json",
"guided_regex",

View File

@@ -17,6 +17,7 @@
import copy
import hashlib
import math
import pickle
import random
import threading
import time
@@ -545,7 +546,7 @@ class APIScheduler:
pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
req_dict = req.to_dict()
req_dict["group"] = group
req_str = orjson.dumps(req_dict)
req_str = pickle.dumps(req_dict, protocol=5)
# logger.info(f"Schedule Req {req_str}")
self.client.lpush(dkey, req_str)
self.client.lpush(pkey, req_str)
@@ -795,7 +796,7 @@ class InferScheduler:
reqs = [ret[1]]
for req_str in reqs:
req = orjson.loads(req_str)
req = pickle.loads(req_str)
group = req.get("group", "")
req = Request.from_dict(req)
writer_idx = select_writer(req)

View File

@@ -14,7 +14,7 @@
# limitations under the License.
"""
import json
import pickle
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
@@ -94,12 +94,12 @@ class SplitwiseConnector:
if not socks:
continue
else:
self.logger.debug(f"start_receiver: receive {socks}")
self.logger.debug("start_receiver: receive %s", socks)
frames = self.router_socket.recv_multipart()
self.logger.debug(f"start_receiver: frames: {frames}")
message = frames[-1]
self.io_executor.submit(self._process_message, message)
self.logger.debug("start_receiver: frames: %s", frames)
# message = frames[-1]
self.io_executor.submit(self._process_message, frames)
time.sleep(0.001)
else:
time.sleep(5)
@@ -147,7 +147,10 @@ class SplitwiseConnector:
try:
self.logger.info(f"_send_message: msg_type={msg_type} addr={addr}")
sock = self._get_push_socket(addr)
sock.send_multipart([b"", message])
sock.send_multipart(message)
self.logger.info(f"Sent {msg_type} to {addr}")
except ConnectionError:
self.logger.warning(f"_send_message: Connection to {addr} not established")
except zmq.Again:
@@ -383,21 +386,51 @@ class SplitwiseConnector:
if msg_type == "decode" or msg_type == "prefill":
payload = [output.to_dict() for output in payload]
json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")
return json_data
# Prepare data
data = {"type": msg_type, "payload": payload}
def _deserialize_message(self, data: bytes):
# Pickle protocol 5 supports extracting large arrays (buffers)
buffers = []
# Serialize main data, strip large arrays as references into buffers
main_bytes = pickle.dumps(data, protocol=5, buffer_callback=buffers.append)
# Serialize using pickle protocol 5 which provides efficient handling
# of large numpy arrays through out-of-band buffers.
# Returns: [main_bytes, buffer1, buffer2, ...]
# where main_bytes contains the serialized structure and buffers contain
# the actual array data extracted for efficient transmission.
return [main_bytes] + buffers
# JSON反序列化
message = json.loads(data.decode("utf-8"))
def _deserialize_message(self, frames: List[bytes]):
"""
Deserialize message from ZMQ frames using pickle protocol 5.
Args:
frames: List of byte frames where:
- frames[0]: Identity frame (sender address)
- frames[1]: Main pickled data structure
- frames[2:]: Out-of-band buffers (numpy arrays)
Returns:
Tuple of (message_type: str, payload: Any)
"""
# identity = frames[0]
if len(frames) < 2:
raise ValueError(f"Received frames too short: expected at least 2 frames but got {len(frames)}")
main_bytes = frames[1]
buffers = frames[2:]
# Restore data, pickle will automatically fill buffers back into numpy arrays
message = pickle.loads(main_bytes, buffers=buffers)
return message["type"], message["payload"]
def _process_message(self, message: bytes):
def _process_message(self, frames: List[bytes]):
"""
process message
"""
try:
msg_type, payload = self._deserialize_message(message)
msg_type, payload = self._deserialize_message(frames)
self.logger.info(f"_process_message: {msg_type}")
if msg_type == "prefill":

View File

@@ -826,6 +826,20 @@ class GPUModelRunner(ModelRunnerBase):
dtype="int64",
)
self.seq_lens_this_time_buffer[idx : idx + 1] = num_prefill_send_token
if self.enable_mm:
# Fix for V0 mode: Add position encoding for decode nodes in multimodal models
# to prevent garbled output. Position_ids are transmitted from prefill nodes.
if (
"position_ids" in request.multimodal_inputs
and request.multimodal_inputs["position_ids"] is not None
):
position_ids = paddle.to_tensor(
request.multimodal_inputs["position_ids"],
dtype="int64",
)
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, [request.get("max_tokens", 2048)], [0, position_ids.shape[0]]
)[0]
else:
self.share_inputs["pre_ids"][idx : idx + 1] = -1
self.share_inputs["step_idx"][idx : idx + 1] = 0
@@ -2709,7 +2723,7 @@ class GPUModelRunner(ModelRunnerBase):
token_type_ids = one["token_type_ids"][np.newaxis, :]
token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
if one["images"] is not None:
if "images" in one and one["images"] is not None:
image_type_ids = one["image_type_ids"][np.newaxis, :]
images = one["images"]
image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64)