[Feature] support bad_words (#3055)

* support bad_words

* support bad_words for online inference

* update

* add CI test

* update

* update

* update

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
Author: Sunny-bot1
Date: 2025-07-30 09:31:29 +08:00
Committed by: GitHub
Parent: 9c962343f2
Commit: 74aa31d15b
10 changed files with 263 additions and 15 deletions

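The hunks below read a bad_words_token_ids field off each incoming request, so an upstream layer is expected to have already expanded the user's banned words into token IDs. A minimal, hypothetical sketch of that step (the tokenizer object and the flat request-dict layout are assumptions for illustration, not part of this diff):

# Hypothetical sketch: how a caller might populate "bad_words_token_ids".
# The tokenizer and the request-dict layout are assumptions, not this PR's API.
def build_request(prompt, bad_words, tokenizer):
    bad_words_token_ids = []
    for word in bad_words:
        # Collect every token id the word maps to; single-token words are the common case.
        bad_words_token_ids.extend(tokenizer.encode(word, add_special_tokens=False))
    return {
        "prompt": prompt,
        "bad_words_token_ids": bad_words_token_ids,  # consumed by XPUModelRunner below
    }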

@@ -506,6 +506,14 @@ class XPUModelRunner(ModelRunnerBase):
                     request.block_tables, dtype="int32"
                 )
+            if request.get("bad_words_token_ids") is not None:
+                bad_words_len = len(request.get("bad_words_token_ids"))
+                if bad_words_len > 0:
+                    self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
+                    self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
+                        request.get("bad_words_token_ids"), dtype="int64"
+                    )
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
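Standalone sketch of the copy pattern above: each request writes its banned token IDs into the first bad_words_len columns of its row in a -1-padded buffer and records the length per slot (shapes are illustrative, not the real configuration):

import numpy as np
import paddle

# Illustrative shapes only; the real buffers use max_num_seqs and vocab_size.
max_num_seqs, vocab_size = 4, 32
bad_tokens = paddle.full([max_num_seqs, vocab_size], -1, dtype="int64")
bad_tokens_len = paddle.full([max_num_seqs], 1, dtype="int64")

idx = 2            # slot assigned to this request
ids = [5, 17, 23]  # request["bad_words_token_ids"]
bad_tokens_len[idx : idx + 1] = len(ids)
bad_tokens[idx : idx + 1, : len(ids)] = np.array(ids, dtype="int64")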
@@ -574,7 +582,8 @@ class XPUModelRunner(ModelRunnerBase):
         self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
         self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
-        self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64")
+        self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
+        self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
         self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
         self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
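The placeholder [1] tensor becomes a dense per-sequence table of shape [max_num_seqs, vocab_size], padded with -1 (never a valid token id), plus a per-slot length. A back-of-the-envelope cost check with made-up numbers:

# Illustrative numbers only; real values come from the deployment config.
max_num_seqs, vocab_size = 64, 100_000
buffer_bytes = max_num_seqs * vocab_size * 8   # int64 entries
print(f"{buffer_bytes / (1 << 20):.1f} MiB")   # -> 48.8 MiB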
@@ -652,6 +661,9 @@ class XPUModelRunner(ModelRunnerBase):
             seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
             seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
         )
+        # Update bad tokens len
+        max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
         self.forward_meta.attn_backend = self.attn_backends[0]
         self.initialize_attention_backend()
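Taking the max over bad_tokens_len lets the runner hand the sampler only the populated prefix of the padded buffer instead of all vocab_size columns; the slice in the next hunk relies on it. A tiny sketch of the idea:

import paddle

bad_tokens_len = paddle.to_tensor([1, 3, 1, 5], dtype="int64")
max_bad_tokens_len = paddle.max(bad_tokens_len)  # -> 5
# Later, only bad_tokens[:, :max_bad_tokens_len] reaches the sampler;
# rows shorter than the max still carry trailing -1 padding.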
@@ -667,7 +679,7 @@ class XPUModelRunner(ModelRunnerBase):
             presence_penalties=self.share_inputs["presence_score"],
             repetition_penalties=self.share_inputs["penalty_score"],
             min_dec_lens=self.share_inputs["min_dec_len"],
-            bad_words_token_ids=self.share_inputs["bad_tokens"],
+            bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
             eos_token_ids=self.share_inputs["eos_token_id"],
         )
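For intuition, a sketch of how a sampler could consume the padded bad_words_token_ids tensor (illustrative only, not the actual FastDeploy sampling kernel): skip the -1 padding and push the listed token logits to -inf so those tokens can never be sampled.

import paddle

def apply_bad_words(logits, bad_tokens, bad_tokens_len):
    # Illustrative only. logits: [num_seqs, vocab_size], bad_tokens: [num_seqs, max_len].
    for i in range(logits.shape[0]):
        n = int(bad_tokens_len[i])
        for token_id in bad_tokens[i, :n].tolist():
            if token_id >= 0:  # -1 entries are padding
                logits[i, token_id] = float("-inf")
    return logits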