[Feature] Multimodal Model P / D Separation (#5323)

* RouterArgs port str -> int

* fix race condition [is_fetching] causing multiple fetch requests

* bugfix: Delete duplicate input_ids tensor creation

* mm pd splitwise json -> pickle5; multimodal_inputs only pos id;
debuglog f to %s

* fix ENABLE_V1_KVCACHE_SCHEDULER=0 mm model lack pos_id, ...

* update cr

* Apply suggestions from code review

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* pre-commit fix

* rm multimodal_inputs deepcopy & fix rdma_cache_transfer.py tpsize=0

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Daci
2025-12-09 10:47:42 +08:00
committed by GitHub
parent a8ffc22032
commit 2f208db4e9
5 changed files with 80 additions and 33 deletions

View File

@@ -14,7 +14,7 @@
# limitations under the License.
"""
import json
import pickle
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
@@ -94,12 +94,12 @@ class SplitwiseConnector:
if not socks:
continue
else:
self.logger.debug(f"start_receiver: receive {socks}")
self.logger.debug("start_receiver: receive %s", socks)
frames = self.router_socket.recv_multipart()
self.logger.debug(f"start_receiver: frames: {frames}")
message = frames[-1]
self.io_executor.submit(self._process_message, message)
self.logger.debug("start_receiver: frames: %s", frames)
# message = frames[-1]
self.io_executor.submit(self._process_message, frames)
time.sleep(0.001)
else:
time.sleep(5)
@@ -147,7 +147,10 @@ class SplitwiseConnector:
try:
self.logger.info(f"_send_message: msg_type={msg_type} addr={addr}")
sock = self._get_push_socket(addr)
sock.send_multipart([b"", message])
sock.send_multipart(message)
self.logger.info(f"Sent {msg_type} to {addr}")
except ConnectionError:
self.logger.warning(f"_send_message: Connection to {addr} not established")
except zmq.Again:
@@ -383,21 +386,51 @@ class SplitwiseConnector:
if msg_type == "decode" or msg_type == "prefill":
payload = [output.to_dict() for output in payload]
json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")
return json_data
# Prepare data
data = {"type": msg_type, "payload": payload}
def _deserialize_message(self, data: bytes):
# Pickle protocol 5 supports extracting large arrays (buffers)
buffers = []
# Serialize main data, strip large arrays as references into buffers
main_bytes = pickle.dumps(data, protocol=5, buffer_callback=buffers.append)
# Serialize using pickle protocol 5 which provides efficient handling
# of large numpy arrays through out-of-band buffers.
# Returns: [main_bytes, buffer1, buffer2, ...]
# where main_bytes contains the serialized structure and buffers contain
# the actual array data extracted for efficient transmission.
return [main_bytes] + buffers
# JSON反序列化
message = json.loads(data.decode("utf-8"))
def _deserialize_message(self, frames: List[bytes]):
"""
Deserialize message from ZMQ frames using pickle protocol 5.
Args:
frames: List of byte frames where:
- frames[0]: Identity frame (sender address)
- frames[1]: Main pickled data structure
- frames[2:]: Out-of-band buffers (numpy arrays)
Returns:
Tuple of (message_type: str, payload: Any)
"""
# identity = frames[0]
if len(frames) < 2:
raise ValueError(f"Received frames too short: expected at least 2 frames but got {len(frames)}")
main_bytes = frames[1]
buffers = frames[2:]
# Restore data, pickle will automatically fill buffers back into numpy arrays
message = pickle.loads(main_bytes, buffers=buffers)
return message["type"], message["payload"]
def _process_message(self, message: bytes):
def _process_message(self, frames: List[bytes]):
"""
process message
"""
try:
msg_type, payload = self._deserialize_message(message)
msg_type, payload = self._deserialize_message(frames)
self.logger.info(f"_process_message: {msg_type}")
if msg_type == "prefill":