[Others] update tbo related code (#5485)

周周周
2025-12-11 12:34:46 +08:00
committed by GitHub
parent 510b82173a
commit ff353b922f
2 changed files with 10 additions and 2 deletions


@@ -536,6 +536,8 @@ class EPPrefillRunner(EPRunner):
            )

    def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False):
        if EPPrefillRunner.allocate_on_comm_stream == allocate_on_comm_stream:
            return
        logger.info(
            f"set allocate_on_comm_stream to {allocate_on_comm_stream}, this will force Prefill dispatch's output tensor to be allocated on the communication stream"
        )
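
For context, a minimal standalone sketch of the pattern this setter follows (hypothetical class and logger names, not the FastDeploy implementation): a class-level flag plus an early-return guard, so repeated calls with the same value skip both the reconfiguration and the log line.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ep_runner_sketch")


class PrefillRunnerSketch:
    # hypothetical stand-in for EPPrefillRunner.allocate_on_comm_stream
    allocate_on_comm_stream: bool = False

    @staticmethod
    def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False):
        if PrefillRunnerSketch.allocate_on_comm_stream == allocate_on_comm_stream:
            return  # value unchanged, nothing to do
        logger.info(
            f"set allocate_on_comm_stream to {allocate_on_comm_stream}; dispatch "
            "output tensors would then be allocated on the communication stream"
        )
        PrefillRunnerSketch.allocate_on_comm_stream = allocate_on_comm_stream


PrefillRunnerSketch.set_allocate_on_comm_stream(True)   # logs and flips the flag
PrefillRunnerSketch.set_allocate_on_comm_stream(True)   # early return, no second log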


@@ -73,7 +73,12 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta):
            caches=forward_meta.caches,
        )
        res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]
        if len(forward_meta.rotary_embs.shape) == 6:
            max_bs = forward_meta.rotary_embs.shape[0]
            assert max_bs == forward_meta.block_tables.shape[0]
            assert forward_meta.rotary_embs.shape[1:3] == [2, 1]
            assert forward_meta.rotary_embs.shape[4] == 1
            res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]
        res[i].ids_remove_padding = forward_meta.ids_remove_padding[start_token_id:end_token_id]
        res[i].batch_id_per_token = forward_meta.batch_id_per_token[start_token_id:end_token_id] - start_bs
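
Illustration of what the new shape checks and slices assume (numpy stand-in with made-up sizes, not FastDeploy code). A paddle.Tensor exposes .shape as a Python list, which is why the diff can compare shape[1:3] against [2, 1] directly; with numpy the tuple is converted to a list to mimic the same assertions.

import numpy as np

max_bs, seq_len, head_dim = 8, 16, 64                         # hypothetical sizes
rotary_embs = np.zeros((max_bs, 2, 1, seq_len, 1, head_dim))  # assumed 6-D layout
block_tables = np.zeros((max_bs, 4), dtype=np.int64)

assert rotary_embs.shape[0] == block_tables.shape[0]   # one entry per request slot
assert list(rotary_embs.shape[1:3]) == [2, 1]          # dims 1-2 must be exactly (2, 1)
assert rotary_embs.shape[4] == 1

# Batch-indexed tensors are sliced along dim 0 for each micro-batch ...
start_bs, end_bs = max_bs // 2, max_bs                 # second micro-batch
sub_rotary = rotary_embs[start_bs:end_bs]              # (4, 2, 1, seq_len, 1, head_dim)

# ... while token-indexed tensors use a token range, and per-token batch ids are
# re-based so they start at 0 inside the micro-batch (the `- start_bs` above).
batch_id_per_token = np.repeat(np.arange(max_bs), 3)   # 3 tokens per request (made up)
start_token_id, end_token_id = start_bs * 3, end_bs * 3
sub_ids = batch_id_per_token[start_token_id:end_token_id] - start_bs
print(sub_rotary.shape, sub_ids.min(), sub_ids.max())  # (4, 2, 1, 16, 1, 64) 0 3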
@@ -100,9 +105,10 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta):
        else:
            assert False, "Invalid attn_mask_offsets shape"
        # This is to adapt 5
        # This is to adapt to 5.0
        if hasattr(forward_meta, "hidden_states"):
            res[i].hidden_states = forward_meta.hidden_states[start_token_id:end_token_id]
        res[i].decode_states = forward_meta.decode_states[start_bs:end_bs]
        res[i].attn_backend.init_attention_metadata(res[i])
    return res
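
To show how these per-field slices fit together, here is a heavily simplified, hypothetical skeleton of a batch-splitting helper in the same spirit (plain dataclass instead of ForwardMeta, a boolean token mask instead of the start/end token bookkeeping, no attention backend): batch-indexed fields are sliced by request range, token-indexed fields by token membership, and optional fields are copied only when present, mirroring the hasattr guard above.

from dataclasses import dataclass, replace
from typing import List, Optional

import numpy as np


@dataclass
class MetaSketch:                                # hypothetical stand-in for ForwardMeta
    rotary_embs: np.ndarray                      # batch-indexed, sliced by [start_bs:end_bs]
    ids_remove_padding: np.ndarray               # token-indexed
    batch_id_per_token: np.ndarray               # token-indexed, re-based per micro-batch
    hidden_states: Optional[np.ndarray] = None   # optional, copied only if present


def split_meta(meta: MetaSketch, num_splits: int = 2) -> List[MetaSketch]:
    step = meta.rotary_embs.shape[0] // num_splits
    res = []
    for i in range(num_splits):
        start_bs, end_bs = i * step, (i + 1) * step
        tok = (meta.batch_id_per_token >= start_bs) & (meta.batch_id_per_token < end_bs)
        sub = replace(
            meta,
            rotary_embs=meta.rotary_embs[start_bs:end_bs],
            ids_remove_padding=meta.ids_remove_padding[tok],
            batch_id_per_token=meta.batch_id_per_token[tok] - start_bs,
        )
        if meta.hidden_states is not None:       # analogue of the hasattr guard
            sub.hidden_states = meta.hidden_states[tok]
        res.append(sub)
    return res


meta = MetaSketch(
    rotary_embs=np.zeros((4, 2, 1, 16, 1, 64)),
    ids_remove_padding=np.arange(12),
    batch_id_per_token=np.repeat(np.arange(4), 3),
    hidden_states=np.zeros((12, 64)),
)
halves = split_meta(meta)
print(halves[0].batch_id_per_token, halves[1].batch_id_per_token)  # both [0 0 0 1 1 1]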