Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[Feature] Multimodal Model P / D Separation (#5323)
* RouterArgs port: str -> int
* Fix race condition on [is_fetching] causing multiple fetch requests
* Bugfix: delete duplicate input_ids tensor creation
* MM PD splitwise: json -> pickle protocol 5; multimodal_inputs carries only position_ids; debug logs f-string -> %s
* Fix ENABLE_V1_KVCACHE_SCHEDULER=0 mm model lacking position_ids, ...
* update cr
* Apply suggestions from code review

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* pre-commit fix
* Remove multimodal_inputs deepcopy & fix rdma_cache_transfer.py tp_size=0

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -61,7 +61,7 @@ class RDMACommManager:
             f"init rdma messager {gpu_id} {rdma_port}, prefill_tp_size: {prefill_tp_size}, prefill_tp_idx: {prefill_tp_idx}"
         )
 
-    def connect(self, ip, port, tp_size):
+    def connect(self, ip, port, tp_size=0):
         """
         Connect to remote gpu and write cache.
         """
@@ -16,7 +16,6 @@
 
 from __future__ import annotations
 
-import copy
 import time
 import traceback
 from dataclasses import asdict, dataclass, fields
@@ -274,20 +273,6 @@ class Request:
 
     def to_dict(self) -> dict:
         """convert Request into a serializable dict"""
-        multimodal_inputs = copy.deepcopy(self.multimodal_inputs)
-        if (
-            isinstance(multimodal_inputs, dict)
-            and isinstance(multimodal_inputs.get("mm_positions"), list)
-            and len(multimodal_inputs["mm_positions"]) > 0
-        ):
-            # if mm_positions is ImagePosition, convert to dict
-            try:
-                for i, mm_pos in enumerate(multimodal_inputs["mm_positions"]):
-                    multimodal_inputs["mm_positions"][i] = (
-                        asdict(mm_pos) if isinstance(mm_pos, ImagePosition) else mm_pos
-                    )
-            except Exception as e:
-                data_processor_logger.error(f"Convert ImagePosition to dict error: {e}, {str(traceback.format_exc())}")
-
         data = {
             "request_id": self.request_id,
@@ -299,7 +284,6 @@ class Request:
             "history": self.history,
             "tools": self.tools,
             "eos_token_ids": self.eos_token_ids,
-            "multimodal_inputs": multimodal_inputs,
             "multimodal_data": self.multimodal_data,
             "disable_chat_template": self.disable_chat_template,
             "disaggregate_info": self.disaggregate_info,
@@ -319,6 +303,21 @@ class Request:
             "audio_end": self.audio_end,
             "ic_req_data": self.ic_req_data,
         }
 
+        # During multimodal PD separation, position_ids are required.
+        if isinstance(self.multimodal_inputs, dict):
+            # Optimize multimodal data transfer during PD separation:
+            #   - V1 mode (ENABLE_V1_KVCACHE_SCHEDULER=1): only position_ids are needed on decode nodes
+            #   - V0 mode (ENABLE_V1_KVCACHE_SCHEDULER=0): the full field set is required for compatibility
+            # This filtering significantly reduces the serialized data size for large numpy arrays.
+            allowed_keys = {"position_ids"}
+            if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                allowed_keys.update(["input_ids", "token_type_ids", "images", "image_type_ids", "grid_thw"])
+
+            data["multimodal_inputs"] = {
+                key: value for key, value in self.multimodal_inputs.items() if key in allowed_keys
+            }
+
         add_params = [
             "guided_json",
             "guided_regex",
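As an illustration of why to_dict filters multimodal_inputs this way, here is a standalone sketch (field names and sizes are made up for the example, not taken from FastDeploy): dropping the large image tensors is what shrinks the request payload shipped between prefill and decode nodes.

    import pickle

    import numpy as np

    # Hypothetical multimodal inputs: one large image tensor plus small index arrays.
    multimodal_inputs = {
        "images": np.zeros((16, 3, 336, 336), dtype=np.float32),  # ~21 MB
        "position_ids": np.arange(4096 * 3, dtype=np.int64).reshape(-1, 3),
        "grid_thw": np.array([[1, 24, 24]], dtype=np.int64),
    }

    # V1 mode: decode nodes only need position_ids.
    allowed_keys = {"position_ids"}
    filtered = {k: v for k, v in multimodal_inputs.items() if k in allowed_keys}

    full_size = len(pickle.dumps(multimodal_inputs, protocol=5))
    slim_size = len(pickle.dumps(filtered, protocol=5))
    print(f"full: {full_size} bytes, filtered: {slim_size} bytes")  # filtered is ~200x smaller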
@@ -17,6 +17,7 @@
 import copy
 import hashlib
 import math
+import pickle
 import random
 import threading
 import time
@@ -545,7 +546,7 @@ class APIScheduler:
         pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
         req_dict = req.to_dict()
         req_dict["group"] = group
-        req_str = orjson.dumps(req_dict)
+        req_str = pickle.dumps(req_dict, protocol=5)
         # logger.info(f"Schedule Req {req_str}")
         self.client.lpush(dkey, req_str)
         self.client.lpush(pkey, req_str)
@@ -795,7 +796,7 @@ class InferScheduler:
                 reqs = [ret[1]]
 
             for req_str in reqs:
-                req = orjson.loads(req_str)
+                req = pickle.loads(req_str)
                 group = req.get("group", "")
                 req = Request.from_dict(req)
                 writer_idx = select_writer(req)
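A plausible standalone illustration of why the Redis request queue switched from orjson to pickle protocol 5 (the req_dict contents here are hypothetical): after the to_dict change above, request dicts carry raw numpy arrays such as position_ids, which orjson rejects by default, while pickle round-trips them with dtype and shape intact.

    import pickle

    import numpy as np

    req_dict = {
        "request_id": "req-0",  # hypothetical request contents
        "group": "g0",
        "multimodal_inputs": {"position_ids": np.arange(12, dtype=np.int64).reshape(-1, 3)},
    }

    # orjson.dumps(req_dict) would raise TypeError on the numpy array (orjson
    # only accepts numpy with OPT_SERIALIZE_NUMPY, and even then the value
    # decodes back as nested plain lists, losing the dtype).
    req_str = pickle.dumps(req_dict, protocol=5)  # what APIScheduler lpush-es
    restored = pickle.loads(req_str)              # what InferScheduler pops
    assert (restored["multimodal_inputs"]["position_ids"]
            == req_dict["multimodal_inputs"]["position_ids"]).all()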
@@ -14,7 +14,7 @@
 # limitations under the License.
 """
 
-import json
+import pickle
 import time
 import traceback
 from concurrent.futures import ThreadPoolExecutor
@@ -94,12 +94,12 @@ class SplitwiseConnector:
                 if not socks:
                     continue
                 else:
-                    self.logger.debug(f"start_receiver: receive {socks}")
+                    self.logger.debug("start_receiver: receive %s", socks)
 
                 frames = self.router_socket.recv_multipart()
-                self.logger.debug(f"start_receiver: frames: {frames}")
-                message = frames[-1]
-                self.io_executor.submit(self._process_message, message)
+                self.logger.debug("start_receiver: frames: %s", frames)
+                # message = frames[-1]
+                self.io_executor.submit(self._process_message, frames)
                 time.sleep(0.001)
             else:
                 time.sleep(5)
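The f-string to "%s" change is the standard lazy-logging pattern: with an f-string, the argument is formatted before logger.debug is even called, so the (potentially multi-megabyte) repr of frames is built and then discarded whenever DEBUG is disabled. A minimal demonstration using only the standard library:

    import logging

    logging.basicConfig(level=logging.INFO)  # DEBUG records will be discarded
    logger = logging.getLogger("splitwise")

    frames = [b"\x00" * 10_000_000]

    # f-string: repr() of the 10 MB frame list is built before logger.debug is
    # even called, then thrown away because DEBUG is disabled.
    logger.debug(f"start_receiver: frames: {frames}")

    # %s style: logging checks the level first, so formatting is skipped entirely.
    logger.debug("start_receiver: frames: %s", frames)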
@@ -147,7 +147,10 @@ class SplitwiseConnector:
         try:
             self.logger.info(f"_send_message: msg_type={msg_type} addr={addr}")
             sock = self._get_push_socket(addr)
-            sock.send_multipart([b"", message])
+            sock.send_multipart(message)
+
+            self.logger.info(f"Sent {msg_type} to {addr}")
+
         except ConnectionError:
             self.logger.warning(f"_send_message: Connection to {addr} not established")
         except zmq.Again:
@@ -383,21 +386,51 @@ class SplitwiseConnector:
         if msg_type == "decode" or msg_type == "prefill":
             payload = [output.to_dict() for output in payload]
 
-        json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")
-        return json_data
+        # Prepare data
+        data = {"type": msg_type, "payload": payload}
+
+        # Pickle protocol 5 supports extracting large arrays into out-of-band buffers,
+        # which gives efficient handling of large numpy arrays.
+        buffers = []
+        # Serialize the main data; large arrays are stripped out as references
+        # and collected into `buffers`.
+        main_bytes = pickle.dumps(data, protocol=5, buffer_callback=buffers.append)
+        # Returns: [main_bytes, buffer1, buffer2, ...]
+        # where main_bytes contains the serialized structure and the buffers contain
+        # the actual array data, extracted for efficient transmission.
+        return [main_bytes] + buffers
 
-    def _deserialize_message(self, data: bytes):
-        # JSON deserialization
-        message = json.loads(data.decode("utf-8"))
+    def _deserialize_message(self, frames: List[bytes]):
+        """
+        Deserialize a message from ZMQ frames using pickle protocol 5.
+
+        Args:
+            frames: List of byte frames where:
+                - frames[0]: identity frame (sender address)
+                - frames[1]: main pickled data structure
+                - frames[2:]: out-of-band buffers (numpy arrays)
+
+        Returns:
+            Tuple of (message_type: str, payload: Any)
+        """
+        # identity = frames[0]
+
+        if len(frames) < 2:
+            raise ValueError(f"Received frames too short: expected at least 2 frames but got {len(frames)}")
+
+        main_bytes = frames[1]
+        buffers = frames[2:]
+
+        # Restore the data; pickle automatically fills the buffers back into numpy arrays.
+        message = pickle.loads(main_bytes, buffers=buffers)
         return message["type"], message["payload"]
 
-    def _process_message(self, message: bytes):
+    def _process_message(self, frames: List[bytes]):
         """
         process message
         """
         try:
-            msg_type, payload = self._deserialize_message(message)
+            msg_type, payload = self._deserialize_message(frames)
             self.logger.info(f"_process_message: {msg_type}")
 
             if msg_type == "prefill":
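A self-contained round trip of the framing that _serialize_message/_deserialize_message now use, simulated in memory rather than over real ZMQ sockets (the b"peer" identity frame stands in for what a ZMQ ROUTER socket prepends on recv_multipart):

    import pickle

    import numpy as np


    def serialize(msg_type, payload):
        data = {"type": msg_type, "payload": payload}
        buffers = []
        # numpy arrays support out-of-band pickling: their bytes land in
        # `buffers` instead of being copied into the pickle stream.
        main_bytes = pickle.dumps(data, protocol=5, buffer_callback=buffers.append)
        # One ZMQ message part per element: [main_bytes, buffer1, buffer2, ...]
        return [main_bytes] + [bytes(b) for b in buffers]


    def deserialize(frames):
        main_bytes, buffers = frames[1], frames[2:]  # frames[0] is the identity
        message = pickle.loads(main_bytes, buffers=buffers)
        return message["type"], message["payload"]


    array = np.ones((1024, 1024), dtype=np.float32)  # ~4 MB of array data
    parts = serialize("prefill", {"position_ids": array})
    frames = [b"peer"] + parts  # what router_socket.recv_multipart() would yield
    msg_type, payload = deserialize(frames)
    assert msg_type == "prefill" and payload["position_ids"].shape == (1024, 1024)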
@@ -826,6 +826,20 @@ class GPUModelRunner(ModelRunnerBase):
                     dtype="int64",
                 )
                 self.seq_lens_this_time_buffer[idx : idx + 1] = num_prefill_send_token
+                if self.enable_mm:
+                    # Fix for V0 mode: add position encoding for decode nodes in multimodal
+                    # models to prevent garbled output. position_ids are transmitted from
+                    # prefill nodes.
+                    if (
+                        "position_ids" in request.multimodal_inputs
+                        and request.multimodal_inputs["position_ids"] is not None
+                    ):
+                        position_ids = paddle.to_tensor(
+                            request.multimodal_inputs["position_ids"],
+                            dtype="int64",
+                        )
+                        self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
+                            position_ids, [request.get("max_tokens", 2048)], [0, position_ids.shape[0]]
+                        )[0]
             else:
                 self.share_inputs["pre_ids"][idx : idx + 1] = -1
                 self.share_inputs["step_idx"][idx : idx + 1] = 0
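Background on why the decode node cannot rebuild position_ids locally, offered as a general property of 3-D rotary schemes (e.g. Qwen2-VL-style mrope) rather than a description of FastDeploy's prepare_rope3d: once image patch grids are interleaved with text, a token's rotary position no longer equals its index, so the decode side needs the prefill-computed values. An illustrative sketch with made-up values:

    import numpy as np

    # Illustrative 3-axis position ids: each token gets a (t, h, w) triple.
    # Text tokens share one value across axes; image tokens get a 2-D grid.
    prefill_position_ids = np.array(
        [
            [0, 0, 0], [1, 1, 1],  # two text tokens
            [2, 2, 2], [2, 2, 3],  # 2x2 image patch grid, row 0
            [2, 3, 2], [2, 3, 3],  # row 1
            [4, 4, 4],             # text token after the image
        ]
    )

    # The next text position continues from max(existing) + 1 on every axis ...
    print(prefill_position_ids.max() + 1)  # 5

    # ... which differs from the naive 1-D position (the token count, 7),
    # so it cannot be derived from token ids alone on the decode node.
    print(len(prefill_position_ids))  # 7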
@@ -2709,7 +2723,7 @@ class GPUModelRunner(ModelRunnerBase):
             token_type_ids = one["token_type_ids"][np.newaxis, :]
             token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
 
-            if one["images"] is not None:
+            if "images" in one and one["images"] is not None:
                 image_type_ids = one["image_type_ids"][np.newaxis, :]
                 images = one["images"]
                 image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
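This guard follows from the allowed_keys filtering in Request.to_dict above: in V1 mode the transferred multimodal_inputs dict omits the "images" key entirely, so the old one["images"] lookup would raise KeyError on decode nodes. A minimal illustration (hypothetical dict contents):

    one = {"position_ids": [0, 1, 2]}  # V1-mode dict: "images" was filtered out upstream
    # one["images"] would raise KeyError; the new guard short-circuits instead:
    assert not ("images" in one and one["images"] is not None)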