[Feature] Multimodal Model P / D Separation (#5323)

* RouterArgs port str -> int

* fix race condition [is_fetching] causing multiple fetch requests

* bugfix: Delete duplicate input_ids tensor creation

* mm pd splitwise json -> pickle5; multimodal_inputs only pos id;
debuglog f to %s

* fix ENABLE_V1_KVCACHE_SCHEDULER=0 mm model lack pos_id, ...

* update cr

* Apply suggestions from code review

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* pre-commit fix

* rm multimodal_inputs deepcopy & fix rdma_cache_transfer.py tpsize=0

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Daci
2025-12-09 10:47:42 +08:00
committed by GitHub
parent a8ffc22032
commit 2f208db4e9
5 changed files with 80 additions and 33 deletions

View File

@@ -826,6 +826,20 @@ class GPUModelRunner(ModelRunnerBase):
dtype="int64",
)
self.seq_lens_this_time_buffer[idx : idx + 1] = num_prefill_send_token
if self.enable_mm:
# Fix for V0 mode: Add position encoding for decode nodes in multimodal models
# to prevent garbled output. Position_ids are transmitted from prefill nodes.
if (
"position_ids" in request.multimodal_inputs
and request.multimodal_inputs["position_ids"] is not None
):
position_ids = paddle.to_tensor(
request.multimodal_inputs["position_ids"],
dtype="int64",
)
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, [request.get("max_tokens", 2048)], [0, position_ids.shape[0]]
)[0]
else:
self.share_inputs["pre_ids"][idx : idx + 1] = -1
self.share_inputs["step_idx"][idx : idx + 1] = 0
@@ -2709,7 +2723,7 @@ class GPUModelRunner(ModelRunnerBase):
token_type_ids = one["token_type_ids"][np.newaxis, :]
token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
if one["images"] is not None:
if "images" in one and one["images"] is not None:
image_type_ids = one["image_type_ids"][np.newaxis, :]
images = one["images"]
image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64)