mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] Multimodal Model P / D Separation (#5323)
* RouterArgs port str -> int * fix race condition [is_fetching] causing multiple fetch requests * bugfix: Delete duplicate input_ids tensor creation * mm pd splitwise json -> pickle5; multimodal_inputs only pos id; debuglog f to %s * fix ENABLE_V1_KVCACHE_SCHEDULER=0 mm model lack pos_id, ... * update cr * Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * pre-commit fix * rm multimodal_inputs deepcopy & fix rdma_cache_transfer.py tpsize=0 --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pickle
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
@@ -94,12 +94,12 @@ class SplitwiseConnector:
|
||||
if not socks:
|
||||
continue
|
||||
else:
|
||||
self.logger.debug(f"start_receiver: receive {socks}")
|
||||
self.logger.debug("start_receiver: receive %s", socks)
|
||||
|
||||
frames = self.router_socket.recv_multipart()
|
||||
self.logger.debug(f"start_receiver: frames: {frames}")
|
||||
message = frames[-1]
|
||||
self.io_executor.submit(self._process_message, message)
|
||||
self.logger.debug("start_receiver: frames: %s", frames)
|
||||
# message = frames[-1]
|
||||
self.io_executor.submit(self._process_message, frames)
|
||||
time.sleep(0.001)
|
||||
else:
|
||||
time.sleep(5)
|
||||
@@ -147,7 +147,10 @@ class SplitwiseConnector:
|
||||
try:
|
||||
self.logger.info(f"_send_message: msg_type={msg_type} addr={addr}")
|
||||
sock = self._get_push_socket(addr)
|
||||
sock.send_multipart([b"", message])
|
||||
sock.send_multipart(message)
|
||||
|
||||
self.logger.info(f"Sent {msg_type} to {addr}")
|
||||
|
||||
except ConnectionError:
|
||||
self.logger.warning(f"_send_message: Connection to {addr} not established")
|
||||
except zmq.Again:
|
||||
@@ -383,21 +386,51 @@ class SplitwiseConnector:
|
||||
if msg_type == "decode" or msg_type == "prefill":
|
||||
payload = [output.to_dict() for output in payload]
|
||||
|
||||
json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")
|
||||
return json_data
|
||||
# Prepare data
|
||||
data = {"type": msg_type, "payload": payload}
|
||||
|
||||
def _deserialize_message(self, data: bytes):
|
||||
# Pickle protocol 5 supports extracting large arrays (buffers)
|
||||
buffers = []
|
||||
# Serialize main data, strip large arrays as references into buffers
|
||||
main_bytes = pickle.dumps(data, protocol=5, buffer_callback=buffers.append)
|
||||
# Serialize using pickle protocol 5 which provides efficient handling
|
||||
# of large numpy arrays through out-of-band buffers.
|
||||
# Returns: [main_bytes, buffer1, buffer2, ...]
|
||||
# where main_bytes contains the serialized structure and buffers contain
|
||||
# the actual array data extracted for efficient transmission.
|
||||
return [main_bytes] + buffers
|
||||
|
||||
# JSON反序列化
|
||||
message = json.loads(data.decode("utf-8"))
|
||||
def _deserialize_message(self, frames: List[bytes]):
|
||||
"""
|
||||
Deserialize message from ZMQ frames using pickle protocol 5.
|
||||
|
||||
Args:
|
||||
frames: List of byte frames where:
|
||||
- frames[0]: Identity frame (sender address)
|
||||
- frames[1]: Main pickled data structure
|
||||
- frames[2:]: Out-of-band buffers (numpy arrays)
|
||||
|
||||
Returns:
|
||||
Tuple of (message_type: str, payload: Any)
|
||||
"""
|
||||
# identity = frames[0]
|
||||
|
||||
if len(frames) < 2:
|
||||
raise ValueError(f"Received frames too short: expected at least 2 frames but got {len(frames)}")
|
||||
|
||||
main_bytes = frames[1]
|
||||
buffers = frames[2:]
|
||||
|
||||
# Restore data, pickle will automatically fill buffers back into numpy arrays
|
||||
message = pickle.loads(main_bytes, buffers=buffers)
|
||||
return message["type"], message["payload"]
|
||||
|
||||
def _process_message(self, message: bytes):
|
||||
def _process_message(self, frames: List[bytes]):
|
||||
"""
|
||||
process message
|
||||
"""
|
||||
try:
|
||||
msg_type, payload = self._deserialize_message(message)
|
||||
msg_type, payload = self._deserialize_message(frames)
|
||||
self.logger.info(f"_process_message: {msg_type}")
|
||||
|
||||
if msg_type == "prefill":
|
||||
|
||||
Reference in New Issue
Block a user