[Feature] support bad_words (#3055)

* support bad_words

* support online infer bad_words

* update

* add CI test

* update

* update

* update

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
Sunny-bot1
2025-07-30 09:31:29 +08:00
committed by GitHub
parent 9c962343f2
commit 74aa31d15b
10 changed files with 263 additions and 15 deletions

View File

@@ -20,6 +20,8 @@ import random
from dataclasses import dataclass, fields
from typing import Any, List, Optional, Union
from fastdeploy.utils import llm_logger as logger
@dataclass
class SamplingParams:
@@ -97,6 +99,7 @@ class SamplingParams:
min_tokens: int = 1
logprobs: Optional[int] = None
bad_words: Optional[List[str]] = None
_bad_words_token_ids: Optional[List[int]] = None
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
@@ -201,11 +204,42 @@ class SamplingParams:
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
def update_from_tokenizer(self, tokenizer):
    """Resolve ``self.bad_words`` into token ids via the given tokenizer.

    Each bad word must encode to exactly one token id in
    ``[0, tokenizer.vocab_size)``; words that encode to multiple tokens
    or to an out-of-range id are skipped with a warning. Populates
    ``self._bad_words_token_ids`` (flat list of ints, deduplicated).
    No-op when ``self.bad_words`` is None.

    Args:
        tokenizer: object exposing ``encode(text=..., add_special_tokens=...)``
            returning a dict with an ``"input_ids"`` list, and a
            ``vocab_size`` attribute (HF-tokenizer-like interface).
    """
    if self.bad_words is None:
        return
    self._bad_words_token_ids = []
    for bad_word in self.bad_words:
        # To prohibit words both at the beginning
        # and in the middle of text
        # (related to add_prefix_space tokenizer parameter)
        for add_prefix_space in [False, True]:
            prefix = " " if add_prefix_space else ""
            prompt = prefix + bad_word.lstrip()
            prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)["input_ids"]
            if len(prompt_token_ids) != 1:
                logger.warning(
                    f"Skip bad_words: {prompt}."
                    f"Bad words should be a single token."
                    f"Got tokens: {prompt_token_ids}."
                )
                continue
            # Fix: was `>`, which let the invalid id == vocab_size through,
            # contradicting the warning's stated range 0 <= token_id < vocab_size.
            if prompt_token_ids[0] >= tokenizer.vocab_size:
                logger.warning(
                    f"Skip bad_words: {prompt}."
                    f"All token id values should be satisfying:"
                    f" 0 <= token_id < {tokenizer.vocab_size}."
                    f"Got token: {prompt_token_ids}."
                )
                continue
            # Fix: was `prompt_token_ids not in ...` — a list compared against a
            # flat list of ints is never a member, so duplicates were appended.
            if prompt_token_ids[0] not in self._bad_words_token_ids:
                self._bad_words_token_ids.extend(prompt_token_ids)
@property
def bad_words_token_ids(self) -> Optional[List[int]]:
    """Flat list of single-token ids derived from ``bad_words``.

    Fix: annotation was ``Optional[List[list[int]]]``, but the backing
    field ``_bad_words_token_ids`` is declared ``Optional[List[int]]``
    and is populated as a flat list of ints.
    None until ``update_from_tokenizer`` has run (or if no bad words).
    """
    return self._bad_words_token_ids
@dataclass