mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Executor]move batch_id_per_token (#4853)
This commit is contained in:
@@ -1266,6 +1266,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
|
||||
# NOTE: (changwenbin) Initialize to max_num_seq '-1' values before copying, to mark illegal positions
|
||||
self.share_inputs["batch_id_per_token"][:] = -1
|
||||
self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
|
||||
self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
|
||||
self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
|
||||
|
||||
@@ -1279,7 +1280,6 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
# Initialize forward meta data
|
||||
self.initialize_forward_meta()
|
||||
self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
|
||||
|
||||
# Get sampling metadata
|
||||
self.sampling_metadata = SamplingMetadata(
|
||||
|
||||
Reference in New Issue
Block a user