[Metax] fix GetStopFlagsMulti kernel crash issue (#5556)

This commit is contained in:
MingkunZhang
2025-12-15 17:56:20 +08:00
committed by GitHub
parent 0100ee885f
commit 5265d844e9

View File

@@ -1741,6 +1741,7 @@ class MetaxModelRunner(ModelRunnerBase):
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
min_tokens=self.share_inputs["min_dec_len"],
prompt_lens=self.share_inputs["prompt_lens"],
)
@@ -1841,6 +1842,7 @@ class MetaxModelRunner(ModelRunnerBase):
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
min_tokens=self.share_inputs["min_dec_len"],
prompt_lens=self.share_inputs["prompt_lens"],
mask_rollback=self.share_inputs["mask_rollback"],
)
@@ -2286,6 +2288,7 @@ class MetaxModelRunner(ModelRunnerBase):
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
min_tokens=self.share_inputs["min_dec_len"],
prompt_lens=self.share_inputs["prompt_lens"],
)
@@ -2391,6 +2394,7 @@ class MetaxModelRunner(ModelRunnerBase):
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
min_tokens=self.share_inputs["min_dec_len"],
prompt_lens=self.share_inputs["prompt_lens"],
mask_rollback=self.share_inputs["mask_rollback"],
prompt_logprobs_list=prompt_logprobs_list,