[Perf] Remove unnecessary operations in non-cuda_graph (#3010)

* [Perf] Remove unnecessary operations in non-cuda_graph

* fix code logic

* use suggestion comment

* reduce function call

* reduce function call

* reduce function call

* reduce function call
This commit is contained in:
begin2023
2025-07-28 11:38:29 +08:00
committed by GitHub
parent 247010d298
commit dd877f38b1
2 changed files with 15 additions and 9 deletions

View File

@@ -576,8 +576,9 @@ class GCUModelRunner(ModelRunnerBase):
) )
# Update Batch type for cuda graph # Update Batch type for cuda graph
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0) self.forward_meta.step_use_cudagraph = self.use_cudagraph and (
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0)
)
# Initialize attention meta data # Initialize attention meta data
for attn_backend in self.attn_backends: for attn_backend in self.attn_backends:

View File

@@ -793,16 +793,21 @@ class GPUModelRunner(ModelRunnerBase):
# Update Batch type for cuda graph # Update Batch type for cuda graph
# TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0) only_decode_batch = True
prefill_exists = None
# mix ep in single node # mix ep in single node
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed": if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
is_decode_batch_list = [] only_decode_batch_list = []
paddle.distributed.all_gather_object(is_decode_batch_list, is_decode_batch) prefill_exists = self.exist_prefill()
is_decode_batch = all(is_decode_batch_list) paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
self.fd_config.parallel_config.moe_phase.phase = "decode" if is_decode_batch else "prefill" only_decode_batch = all(only_decode_batch_list)
self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch self.forward_meta.step_use_cudagraph = (
self.use_cudagraph
and only_decode_batch
and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
)
# Initialize attention meta data # Initialize attention meta data
for attn_backend in self.attn_backends: for attn_backend in self.attn_backends: