Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[Others] update tbo related code (#5485)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
@@ -536,6 +536,8 @@ class EPPrefillRunner(EPRunner):
        )

    def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False):
        if EPPrefillRunner.allocate_on_comm_stream == allocate_on_comm_stream:
            return
        logger.info(
            f"set allocate_on_comm_stream to {allocate_on_comm_stream}, this will force Prefill dispatch's output tensor to be allocated on the communication stream"
        )
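For readers outside the codebase, here is a minimal, self-contained sketch of the one-way-toggle pattern this hunk adds. The class-level flag default, the @staticmethod decorator, and the final assignment are assumptions not visible in the diff; only the compare-early-return-log shape comes from the hunk above.

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("ep_prefill")

    class EPPrefillRunner:
        # Class-level flag shared by every runner instance (assumption:
        # the diff compares against EPPrefillRunner.allocate_on_comm_stream).
        allocate_on_comm_stream: bool = False

        @staticmethod  # assumed; the hunk does not show a decorator
        def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False):
            # Early return keeps repeated calls with the same value
            # from re-logging.
            if EPPrefillRunner.allocate_on_comm_stream == allocate_on_comm_stream:
                return
            logger.info(
                f"set allocate_on_comm_stream to {allocate_on_comm_stream}, this "
                "will force Prefill dispatch's output tensor to be allocated on "
                "the communication stream"
            )
            # Assumed: the real method presumably stores the new value.
            EPPrefillRunner.allocate_on_comm_stream = allocate_on_comm_stream

    EPPrefillRunner.set_allocate_on_comm_stream(True)   # logs once
    EPPrefillRunner.set_allocate_on_comm_stream(True)   # no-op, no second log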
@@ -73,7 +73,12 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta):
             caches=forward_meta.caches,
         )
 
-        res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]
+        if len(forward_meta.rotary_embs.shape) == 6:
+            max_bs = forward_meta.rotary_embs.shape[0]
+            assert max_bs == forward_meta.block_tables.shape[0]
+            assert forward_meta.rotary_embs.shape[1:3] == [2, 1]
+            assert forward_meta.rotary_embs.shape[4] == 1
+            res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]
 
         res[i].ids_remove_padding = forward_meta.ids_remove_padding[start_token_id:end_token_id]
         res[i].batch_id_per_token = forward_meta.batch_id_per_token[start_token_id:end_token_id] - start_bs
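A hedged sketch of the shape contract those new assertions encode, using NumPy stand-ins. The 6-D layout [max_bs, 2, 1, max_seq_len, 1, head_dim] is inferred from the asserts, not documented here, and note that NumPy's .shape is a tuple while Paddle's is a Python list, which is why the diff compares against [2, 1].

    import numpy as np

    # Assumed layout, inferred from the asserts in the hunk:
    # [max_bs, 2, 1, max_seq_len, 1, head_dim]
    max_bs, max_seq_len, head_dim = 8, 16, 64
    rotary_embs = np.zeros((max_bs, 2, 1, max_seq_len, 1, head_dim), np.float32)
    block_tables = np.zeros((max_bs, 4), np.int64)

    # The same invariants the hunk checks before slicing per micro-batch.
    assert rotary_embs.shape[0] == block_tables.shape[0]
    assert rotary_embs.shape[1:3] == (2, 1)  # tuple here; Paddle .shape is a list
    assert rotary_embs.shape[4] == 1

    # Slicing along dim 0 keeps only the rows for batches [start_bs, end_bs).
    start_bs, end_bs = 2, 5
    print(rotary_embs[start_bs:end_bs].shape)  # (3, 2, 1, 16, 1, 64)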
@@ -100,9 +105,10 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta):
        else:
            assert False, "Invalid attn_mask_offsets shape"

        # This is to adapt 5
        # This is adapt 5.0
        if hasattr(forward_meta, "hidden_states"):
            res[i].hidden_states = forward_meta.hidden_states[start_token_id:end_token_id]
            res[i].decode_states = forward_meta.decode_states[start_bs:end_bs]

        res[i].attn_backend.init_attention_metadata(res[i])
    return res
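Finally, a toy illustration of the token-range split and batch-id rebasing used throughout split_batch_decoder_layers. The Meta dataclass, the split_micro_batch helper, and the contiguity assumption are hypothetical; only the slice-and-subtract pattern [start_token_id:end_token_id] - start_bs comes from the diff.

    import numpy as np
    from dataclasses import dataclass

    @dataclass
    class Meta:  # hypothetical stand-in for ForwardMeta's per-token fields
        ids_remove_padding: np.ndarray
        batch_id_per_token: np.ndarray

    def split_micro_batch(meta: Meta, start_bs: int, end_bs: int) -> Meta:
        # Tokens of batches [start_bs, end_bs) form one contiguous run,
        # assuming batch_id_per_token is nondecreasing.
        idx = np.flatnonzero(
            (meta.batch_id_per_token >= start_bs) & (meta.batch_id_per_token < end_bs)
        )
        start_token_id, end_token_id = idx[0], idx[-1] + 1
        return Meta(
            ids_remove_padding=meta.ids_remove_padding[start_token_id:end_token_id],
            # Rebase so the micro-batch's batch ids start at 0 again.
            batch_id_per_token=meta.batch_id_per_token[start_token_id:end_token_id] - start_bs,
        )

    meta = Meta(
        ids_remove_padding=np.arange(10),
        batch_id_per_token=np.array([0, 0, 0, 1, 1, 2, 2, 2, 3, 3]),
    )
    half = split_micro_batch(meta, start_bs=2, end_bs=4)
    print(half.batch_id_per_token)  # [0 0 0 1 1]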