[Feature] support bad_words (#3055)
* support bad_words
* support online infer bad_words
* add CI test

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
@@ -270,6 +270,14 @@ class GCUModelRunner(ModelRunnerBase):
                 request.block_tables, dtype="int32"
             )
 
+            if request.get("bad_words_token_ids") is not None:
+                bad_words_len = len(request.get("bad_words_token_ids"))
+                if bad_words_len > 0:
+                    self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
+                    self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
+                        request.get("bad_words_token_ids"), dtype="int64"
+                    )
+
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
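In plain terms, the new block copies each request's optional bad_words_token_ids list into that request's slot of the shared, pre-padded bad_tokens buffer and records the list length in bad_tokens_len. Below is a minimal standalone sketch of that padding logic; the buffer sizes, the request dict, and the slot index idx are illustrative assumptions, not FastDeploy's real objects.

    import numpy as np

    max_num_seqs, vocab_size = 4, 32000              # hypothetical sizes
    bad_tokens = np.full((max_num_seqs, vocab_size), -1, dtype="int64")
    bad_tokens_len = np.full((max_num_seqs,), 1, dtype="int64")

    request = {"bad_words_token_ids": [128, 9021, 17]}  # hypothetical request
    idx = 0                                              # slot assigned to this request

    ids = request.get("bad_words_token_ids")
    if ids is not None and len(ids) > 0:
        # record how many leading entries of this row are valid,
        # then copy the ids into the -1-padded row for this slot
        bad_tokens_len[idx] = len(ids)
        bad_tokens[idx, : len(ids)] = np.array(ids, dtype="int64")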
@@ -382,7 +390,8 @@ class GCUModelRunner(ModelRunnerBase):
         self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
         self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
 
-        self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64")
+        self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
+        self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
         self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
         self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
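This allocation change is what makes the per-request copy possible: bad_tokens grows from a single placeholder value to one -1-padded row per sequence slot, sized to the vocabulary so that any bad-words list fits, while bad_tokens_len tracks how many leading entries of each row are valid. The default length of 1 presumably leaves an untouched slot exposing a single -1 entry that the sampler ignores; that reading is an inference from the diff, not stated in the commit.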
@@ -511,6 +520,9 @@ class GCUModelRunner(ModelRunnerBase):
         self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
         self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False)
 
+        # Update bad tokens len
+        max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
+
         # Initialize forward meta data
         self.initialize_forward_meta()
 
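Taking the maximum of bad_tokens_len each step lets the runner later hand the sampler only the populated prefix of the buffer instead of all vocab_size columns. A small sketch of the effect, using hypothetical sizes (the slicing mirrors the diff; the concrete numbers are assumptions):

    import paddle

    bad_tokens = paddle.full([4, 32000], -1, dtype="int64")       # hypothetical buffer
    bad_tokens_len = paddle.to_tensor([3, 1, 5, 1], dtype="int64")

    max_bad_tokens_len = paddle.max(bad_tokens_len)               # longest list -> 5
    compact = bad_tokens[:, :max_bad_tokens_len]                  # shape [4, 5], not [4, 32000]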
@@ -528,7 +540,7 @@ class GCUModelRunner(ModelRunnerBase):
             presence_penalties=self.share_inputs["presence_score"],
             repetition_penalties=self.share_inputs["penalty_score"],
             min_dec_lens=self.share_inputs["min_dec_len"],
-            bad_words_token_ids=self.share_inputs["bad_tokens"],
+            bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
             eos_token_ids=self.share_inputs["eos_token_id"],
             max_num_logprobs=20 if self.enable_logprob else None,
         )
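With the length cap applied, the sampling metadata receives a [max_num_seqs, max_bad_tokens_len] tensor of token ids to suppress. Below is a minimal sketch of how such ids are typically consumed, by masking the corresponding logits; this illustrates the general technique only and is not FastDeploy's actual sampling kernel. The -1 padding entries are skipped.

    import numpy as np

    logits = np.random.randn(2, 32000).astype("float32")     # hypothetical [batch, vocab]
    bad_words = np.array([[128, 17, -1],
                          [9021, -1, -1]], dtype="int64")     # -1 is padding

    for row in range(logits.shape[0]):
        ids = bad_words[row]
        # set the logits of valid bad-word ids to -inf so they can never be sampled
        logits[row, ids[ids >= 0]] = -np.inf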