mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Feature] bad words support v1 scheduler and specify token ids (#3608)
* support bad_words_token_ids
* docs
* fix test
* fix
* bad words support kvcache v1 and token ids
* fix
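The runner change below reads the banned ids straight off the request via request.get("bad_words_token_ids"), so a scheduled request only needs to carry that key alongside the existing sampling fields. The dict below is a hypothetical payload shape for illustration, matching the keys the diff reads; it is not FastDeploy's actual request class.

# Hypothetical request payload (Python sketch); keys match what the diff reads
# via request.get(...), values are made-up examples.
request = {
    "seed": 42,                             # optional per-request RNG seed
    "bad_words_token_ids": [128006, 9906],  # token ids the sampler must never emit
}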
@@ -339,6 +339,16 @@ class GPUModelRunner(ModelRunnerBase):
             if request.get("seed") is not None:
                 self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")
 
+            if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0:
+                bad_words_len = len(request.get("bad_words_token_ids"))
+                self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
+                self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
+                    request.get("bad_words_token_ids"), dtype="int64"
+                )
+            else:
+                self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1
+                self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64")
+
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):