[Feature] Multimodal Model P / D Separation (#5323)

* RouterArgs port str -> int

* fix race condition [is_fetching] causing multiple fetch requests

* bugfix: Delete duplicate input_ids tensor creation

* multimodal PD splitwise: switch serialization from json to pickle5;
multimodal_inputs carries only position_ids; debug logs: f-strings -> %s

* fix: with ENABLE_V1_KVCACHE_SCHEDULER=0, multimodal model lacks pos_id, ...

* update cr

* Apply suggestions from code review

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* pre-commit fix

* remove multimodal_inputs deepcopy & fix tpsize=0 in rdma_cache_transfer.py

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Daci
2025-12-09 10:47:42 +08:00
committed by GitHub
parent a8ffc22032
commit 2f208db4e9
5 changed files with 80 additions and 33 deletions

View File

@@ -16,7 +16,6 @@
from __future__ import annotations
import copy
import time
import traceback
from dataclasses import asdict, dataclass, fields
@@ -274,20 +273,6 @@ class Request:
def to_dict(self) -> dict:
"""convert Request into a serializable dict"""
multimodal_inputs = copy.deepcopy(self.multimodal_inputs)
if (
isinstance(multimodal_inputs, dict)
and isinstance(multimodal_inputs.get("mm_positions"), list)
and len(multimodal_inputs["mm_positions"]) > 0
):
# if mm_positions is ImagePosition, convert to dict
try:
for i, mm_pos in enumerate(multimodal_inputs["mm_positions"]):
multimodal_inputs["mm_positions"][i] = (
asdict(mm_pos) if isinstance(mm_pos, ImagePosition) else mm_pos
)
except Exception as e:
data_processor_logger.error(f"Convert ImagePosition to dict error: {e}, {str(traceback.format_exc())}")
data = {
"request_id": self.request_id,
@@ -299,7 +284,6 @@ class Request:
"history": self.history,
"tools": self.tools,
"eos_token_ids": self.eos_token_ids,
"multimodal_inputs": multimodal_inputs,
"multimodal_data": self.multimodal_data,
"disable_chat_template": self.disable_chat_template,
"disaggregate_info": self.disaggregate_info,
@@ -319,6 +303,21 @@ class Request:
"audio_end": self.audio_end,
"ic_req_data": self.ic_req_data,
}
# During multimodal PD separation, position_ids are required
if isinstance(self.multimodal_inputs, dict):
# Optimize multimodal data transfer during PD separation:
# - V1 mode (ENABLE_V1_KVCACHE_SCHEDULER=1): Only position_ids needed for decode nodes
# - V0 mode (ENABLE_V1_KVCACHE_SCHEDULER=0): Full field set required for compatibility
# This filtering significantly reduces serialized data size for large numpy arrays
allowed_keys = {"position_ids"}
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
allowed_keys.update(["input_ids", "token_type_ids", "images", "image_type_ids", "grid_thw"])
data["multimodal_inputs"] = {
key: value for key, value in self.multimodal_inputs.items() if key in allowed_keys
}
add_params = [
"guided_json",
"guided_regex",