mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
[Feature] bad words support v1 scheduler and specifiy token ids (#3608)
* support bad_words_token_ids * docs * fix test * fix * bad words support kvcache v1 and token ids * fix
This commit is contained in:
@@ -214,6 +214,12 @@ class DataProcessor(BaseDataProcessor):
|
||||
request.set("stop_token_ids", stop_seqs)
|
||||
request.set("stop_seqs_len", stop_seqs_len)
|
||||
|
||||
bad_words = request.get("bad_words")
|
||||
bad_words_token_ids = request.get("bad_words_token_ids")
|
||||
if bad_words:
|
||||
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
|
||||
request["bad_words_token_ids"] = bad_words_token_ids
|
||||
|
||||
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
|
||||
if request.prompt is not None:
|
||||
request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
|
||||
@@ -270,6 +276,13 @@ class DataProcessor(BaseDataProcessor):
|
||||
request["stop_token_ids"] = stop_seqs
|
||||
request["stop_seqs_len"] = stop_seqs_len
|
||||
|
||||
# processing bad_words
|
||||
bad_words = request.get("bad_words")
|
||||
bad_words_token_ids = request.get("bad_words_token_ids")
|
||||
if bad_words:
|
||||
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
|
||||
request["bad_words_token_ids"] = bad_words_token_ids
|
||||
|
||||
data_processor_logger.info(f"Processing request {request}")
|
||||
# processing prompt_token_ids
|
||||
if not request.get("prompt_token_ids"):
|
||||
@@ -652,3 +665,42 @@ class DataProcessor(BaseDataProcessor):
|
||||
stop_seqs, stop_seqs_len = self.pad_batch_data(stop_seqs, pad_id=-1, return_seq_len=True, return_array=False)
|
||||
data_processor_logger.debug(f"processed stop_seqs: {stop_seqs}, {stop_seqs_len}")
|
||||
return stop_seqs, stop_seqs_len
|
||||
|
||||
def update_bad_words(self, bad_words, bad_words_token_ids):
|
||||
"""Support bad words"""
|
||||
|
||||
token_ids = bad_words_token_ids
|
||||
|
||||
if token_ids is None:
|
||||
token_ids = []
|
||||
for bad_word in bad_words:
|
||||
# To prohibit words both at the beginning
|
||||
# and in the middle of text
|
||||
# (related to add_prefix_space tokenizer parameter)
|
||||
for add_prefix_space in [False, True]:
|
||||
prefix = " " if add_prefix_space else ""
|
||||
prompt = prefix + bad_word.lstrip()
|
||||
prompt_token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt))
|
||||
|
||||
if len(prompt_token_ids) != 1:
|
||||
if not add_prefix_space:
|
||||
data_processor_logger.warning(
|
||||
f"Skip bad_words: <{prompt}>."
|
||||
f"Bad words should be a single token."
|
||||
f"Got tokens: {prompt_token_ids}."
|
||||
)
|
||||
continue
|
||||
|
||||
if prompt_token_ids[0] > self.tokenizer.vocab_size:
|
||||
if not add_prefix_space:
|
||||
data_processor_logger.warning(
|
||||
f"Skip bad_words: <{prompt}>."
|
||||
f"All token id values should be satisfying:"
|
||||
f" 0 <= token_id < {self.tokenizer.vocab_size}."
|
||||
f"Got token: {prompt_token_ids}."
|
||||
)
|
||||
continue
|
||||
|
||||
if prompt_token_ids not in token_ids:
|
||||
token_ids.extend(prompt_token_ids)
|
||||
return token_ids
|
||||
|
Reference in New Issue
Block a user