From ff353b922f2b81ae194e0c45e89ab26b2e097b32 Mon Sep 17 00:00:00 2001
From: 周周周 <39978853+zhoutianzi666@users.noreply.github.com>
Date: Thu, 11 Dec 2025 12:34:46 +0800
Subject: [PATCH] [Others] update tbo related code (#5485)

---
 fastdeploy/model_executor/layers/moe/ep.py |  2 ++
 fastdeploy/worker/tbo.py                   | 10 ++++++++--
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py
index b61fe48f6..4065de51f 100644
--- a/fastdeploy/model_executor/layers/moe/ep.py
+++ b/fastdeploy/model_executor/layers/moe/ep.py
@@ -536,6 +536,8 @@ class EPPrefillRunner(EPRunner):
         )

     def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False):
+        if EPPrefillRunner.allocate_on_comm_stream == allocate_on_comm_stream:
+            return
         logger.info(
             f"set allocate_on_comm_stream to {allocate_on_comm_stream}; this will force Prefill dispatch's output tensor to be allocated on the communication stream"
         )
diff --git a/fastdeploy/worker/tbo.py b/fastdeploy/worker/tbo.py
index 856437959..051d0499a 100644
--- a/fastdeploy/worker/tbo.py
+++ b/fastdeploy/worker/tbo.py
@@ -73,7 +73,12 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta):
             caches=forward_meta.caches,
         )

-        res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]
+        if len(forward_meta.rotary_embs.shape) == 6:
+            max_bs = forward_meta.rotary_embs.shape[0]
+            assert max_bs == forward_meta.block_tables.shape[0]
+            assert forward_meta.rotary_embs.shape[1:3] == [2, 1]
+            assert forward_meta.rotary_embs.shape[4] == 1
+            res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]
         res[i].ids_remove_padding = forward_meta.ids_remove_padding[start_token_id:end_token_id]
         res[i].batch_id_per_token = forward_meta.batch_id_per_token[start_token_id:end_token_id] - start_bs

@@ -100,9 +105,10 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta):
         else:
             assert False, "Invalid attn_mask_offsets shape"

-        # This is to adapt 5
+        # This is to adapt 5.0
         if hasattr(forward_meta, "hidden_states"):
             res[i].hidden_states = forward_meta.hidden_states[start_token_id:end_token_id]
             res[i].decode_states = forward_meta.decode_states[start_bs:end_bs]
+        res[i].attn_backend.init_attention_metadata(res[i])

     return res
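
Note for reviewers: below is a minimal, self-contained sketch (NumPy stand-ins
for Paddle tensors) of the rotary_embs guard this patch adds to
split_batch_decoder_layers. The 6-D layout [max_bs, 2, 1, max_seq_len, 1,
head_dim], the helper name split_rotary_embs, and the micro-batch split point
are illustrative assumptions inferred from the asserts, not FastDeploy API.

import numpy as np

def split_rotary_embs(rotary_embs, start_bs, end_bs):
    """Slice per-request rotary embeddings for one TBO micro-batch.

    Mirrors the patched branch: only the 6-D, per-request layout is sliced
    on dim 0; any other layout is assumed to be shared by the whole batch
    and is returned unchanged.
    """
    if len(rotary_embs.shape) == 6:
        # Dim 0 indexes requests; dims 1:3 plausibly hold the (cos, sin)
        # pair plus a broadcast axis, and dim 4 is another broadcast axis.
        assert rotary_embs.shape[1:3] == (2, 1)  # tuple here; Paddle's .shape is a list
        assert rotary_embs.shape[4] == 1
        return rotary_embs[start_bs:end_bs]
    return rotary_embs

# Usage: split 8 requests into the two micro-batches of two-batch overlap (TBO).
max_bs, max_seq_len, head_dim = 8, 16, 64
rotary = np.zeros((max_bs, 2, 1, max_seq_len, 1, head_dim), dtype=np.float32)
first = split_rotary_embs(rotary, 0, max_bs // 2)
second = split_rotary_embs(rotary, max_bs // 2, max_bs)
assert first.shape[0] == 4 and second.shape[0] == 4

The early return added to EPPrefillRunner.set_allocate_on_comm_stream is the
usual idempotent-setter pattern: when the requested value already matches the
stored flag, it skips the repeated log line and any downstream work.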