[Feature] bad words support v1 scheduler and specify token ids (#3608)
* support bad_words_token_ids
* docs
* fix test
* fix
* bad words support kvcache v1 and token ids
* fix
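The diff below promotes `bad_words_token_ids` from a private field behind a read-only property to a public `SamplingParams` field, so callers can pass pre-tokenized bad words directly instead of (or alongside) surface strings. A minimal usage sketch, assuming the module path `fastdeploy.engine.sampling_params` and placeholder token ids:

```python
from fastdeploy.engine.sampling_params import SamplingParams  # module path assumed

# Ban words by surface string; they are tokenized later in the pipeline.
by_string = SamplingParams(temperature=0.8, bad_words=["badword"])

# Ban tokens directly by id (the ids here are placeholders): the new public
# field added in this commit, which skips the tokenizer round-trip entirely.
by_token = SamplingParams(temperature=0.8, bad_words_token_ids=[128, 7211])
```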
```diff
@@ -20,8 +20,6 @@ import random
 from dataclasses import dataclass, fields
 from typing import Any, List, Optional, Union
 
-from fastdeploy.utils import llm_logger as logger
-
 
 @dataclass
 class SamplingParams:
@@ -102,7 +100,7 @@ class SamplingParams:
     temp_scaled_logprobs: bool = False
     top_p_normalized_logprobs: bool = False
     bad_words: Optional[List[str]] = None
-    _bad_words_token_ids: Optional[List[int]] = None
+    bad_words_token_ids: Optional[List[int]] = None
 
     @classmethod
     def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
@@ -134,6 +132,7 @@ class SamplingParams:
         min_tokens=1,
         logprobs=None,
         bad_words=None,
+        bad_words_token_ids=None,
     ) -> SamplingParams:
         """Create instance from command line arguments"""
         return cls(
@@ -154,6 +153,7 @@ class SamplingParams:
             min_tokens=min_tokens,
             logprobs=logprobs,
             bad_words=bad_words,
+            bad_words_token_ids=bad_words_token_ids,
         )
 
     def __post_init__(self):
@@ -206,46 +206,6 @@ class SamplingParams:
         if not 0 <= self.seed <= 922337203685477580:
             raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
 
-    def update_from_tokenizer(self, tokenizer):
-        """Support bad words"""
-        if self.bad_words is None:
-            return
-        self._bad_words_token_ids = []
-        for bad_word in self.bad_words:
-            # To prohibit words both at the beginning
-            # and in the middle of text
-            # (related to add_prefix_space tokenizer parameter)
-            for add_prefix_space in [False, True]:
-                prefix = " " if add_prefix_space else ""
-                prompt = prefix + bad_word.lstrip()
-                prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)["input_ids"]
-
-                if len(prompt_token_ids) != 1:
-                    if not add_prefix_space:
-                        logger.warning(
-                            f"Skip bad_words: <{prompt}>."
-                            f"Bad words should be a single token."
-                            f"Got tokens: {prompt_token_ids}."
-                        )
-                    continue
-
-                if prompt_token_ids[0] > tokenizer.vocab_size:
-                    if not add_prefix_space:
-                        logger.warning(
-                            f"Skip bad_words: <{prompt}>."
-                            f"All token id values should be satisfying:"
-                            f" 0 <= token_id < {tokenizer.vocab_size}."
-                            f"Got token: {prompt_token_ids}."
-                        )
-                    continue
-
-                if prompt_token_ids not in self._bad_words_token_ids:
-                    self._bad_words_token_ids.extend(prompt_token_ids)
-
-    @property
-    def bad_words_token_ids(self) -> Optional[List[list[int]]]:
-        return self._bad_words_token_ids
-
 
 @dataclass
 class BeamSearchParams:
```
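The deleted `update_from_tokenizer` encoded each bad word twice, with and without a leading space, so the ban held both at the start of text and mid-sentence (the `add_prefix_space` behaviour of many tokenizers); only the deletion is visible in this diff, so where that logic now lives is not shown here. A standalone sketch of the same single-token filtering, assuming a PaddleNLP-style tokenizer whose `encode` returns a dict with `input_ids` (the function name is hypothetical):

```python
from typing import List


def bad_words_to_token_ids(tokenizer, bad_words: List[str]) -> List[int]:
    """Sketch of the deleted update_from_tokenizer logic as a free function."""
    token_ids: List[int] = []
    for bad_word in bad_words:
        # Encode with and without a leading space so the word is banned both
        # at the beginning of the text and after other words.
        for prefix in ("", " "):
            ids = tokenizer.encode(text=prefix + bad_word.lstrip(), add_special_tokens=False)["input_ids"]
            if len(ids) != 1:
                # Multi-token words cannot be banned by masking a single logit.
                continue
            if not 0 <= ids[0] <= tokenizer.vocab_size:
                # Id falls outside the tokenizer vocabulary.
                continue
            if ids[0] not in token_ids:
                token_ids.append(ids[0])
    return token_ids
```

The original also logged a `logger.warning` for each skipped word; the sketch skips silently to stay self-contained.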