[Executor]move batch_id_per_token (#4853)

This commit is contained in:
周周周
2025-11-14 15:38:48 +08:00
committed by GitHub
parent c0a4393d72
commit 51b1f13547
2 changed files with 9 additions and 9 deletions

View File

@@ -1266,6 +1266,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
# NOTE: (changwenbin) Initialized to max_num_seq '-1' before copying, marking illegal positions
self.share_inputs["batch_id_per_token"][:] = -1
self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
@@ -1279,7 +1280,6 @@ class GPUModelRunner(ModelRunnerBase):
# Initialize forward meta data
self.initialize_forward_meta()
self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
# Get sampling metadata
self.sampling_metadata = SamplingMetadata(