[SOT] Remove breakgraph in post processing && fix datatype (#2780)

This commit is contained in:
Ryan
2025-07-10 11:26:00 +08:00
committed by GitHub
parent 2ea267f624
commit b0f525955c
3 changed files with 20 additions and 17 deletions

View File

@@ -929,18 +929,17 @@ class GPUVLModelRunner(VLModelRunnerBase):
False,
) # multi ends
# update inputs
with paddle.framework._no_check_dy2st_diff():
update_inputs(
self.share_inputs["stop_flags"],
self.share_inputs["not_need_stop"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_encoder"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["input_ids"],
self.share_inputs["stop_nums"],
next_tokens,
self.share_inputs["is_block_step"],
)
update_inputs(
self.share_inputs["stop_flags"],
self.share_inputs["not_need_stop"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_encoder"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["input_ids"],
self.share_inputs["stop_nums"],
next_tokens,
self.share_inputs["is_block_step"],
)
save_output(
next_tokens,
self.share_inputs["not_need_stop"],