[Feature] bad words support for the v1 scheduler and specifying token ids (#3608)

* support bad_words_token_ids

* docs

* fix test

* fix

* bad words support kvcache v1 and token ids

* fix
Sunny-bot1
2025-08-26 11:14:51 +08:00
committed by GitHub
parent c43a4bec00
commit c68c3c4b8b
16 changed files with 420 additions and 62 deletions


@@ -183,7 +183,7 @@ Used to prevent the model from generating certain specific words during the infe
## Usage Instructions
Include the `bad_words` or `bad_words_token_ids` parameter in the request:
* Example request with curl:
@@ -192,9 +192,22 @@ curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"messages": [ "messages": [
{"role": "user", "content": "How old are you"} {"role": "user", "content": "How are you"}
], ],
"bad_words": ["age", "I"] "bad_words": [" well", " Today"]
}'
```
Equal to
```bash
curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "How are you"}
],
"bad_words_token_ids": [1622, 25062]
}' }'
``` ```
@@ -203,15 +216,37 @@ curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \
```python
import openai
host = "0.0.0.0"
port = "9222"
client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")
response = client.chat.completions.create(
    model="null",
    messages=[
        {"role": "user", "content": "Hello, how are you?"},
    ],
    extra_body={"bad_words": [" well", " Today"]},
    stream=True,
)
for chunk in response:
    if chunk.choices[0].delta:
        print(chunk.choices[0].delta.content, end='')
print('\n')
```
Equivalent to:
```python
import openai
host = "0.0.0.0"
port = "9222"
client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")
response = client.chat.completions.create(
    model="null",
    messages=[
        {"role": "user", "content": "Hello, how are you?"},
    ],
    extra_body={"bad_words_token_ids": [1622, 25062]},
    stream=True,
)
for chunk in response:
@@ -223,3 +258,5 @@ print('\n')
## Parameter Description
`bad_words`: List of forbidden words. Type: list of str. Each word must be a single token.
`bad_words_token_ids`: List of forbidden token ids. Type: list of int.
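One practical way to find the ids to pass in `bad_words_token_ids` is to run the served model's tokenizer offline. The snippet below is only a minimal sketch assuming a Hugging Face-style tokenizer; `your-model-path` is a placeholder, and the ids it prints apply only to that particular tokenizer:
```python
# Sketch: derive single-token ids that can be passed as bad_words_token_ids.
# Assumes a Hugging Face-style tokenizer; "your-model-path" is a placeholder.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-model-path")

for word in [" well", " Today"]:
    ids = tokenizer.encode(word, add_special_tokens=False)
    if len(ids) == 1:
        print(f"{word!r} -> token id {ids[0]}")   # safe to put in bad_words_token_ids
    else:
        print(f"{word!r} maps to {ids}; a bad word must be a single token")
```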


@@ -153,6 +153,9 @@ include_stop_str_in_output: Optional[bool] = False
bad_words: Optional[List[str]] = None
# List of forbidden words (e.g., sensitive words) that the model should avoid generating (default None means no restriction).
bad_words_token_ids: Optional[List[int]] = None
# List of forbidden token ids that the model should avoid generating (default None means no restriction).
repetition_penalty: Optional[float] = None
# Repetition penalty coefficient, reducing the probability of repeating already generated tokens (`>1.0` suppresses repetition, `<1.0` encourages repetition, default None means disabled).
```
@@ -340,6 +343,9 @@ include_stop_str_in_output: Optional[bool] = False
bad_words: Optional[List[str]] = None
# List of forbidden words (e.g., sensitive words) that the model should avoid generating (default None means no restriction).
bad_words_token_ids: Optional[List[int]] = None
# List of forbidden token ids that the model should avoid generating (default None means no restriction).
repetition_penalty: Optional[float] = None
# Repetition penalty coefficient, reducing the probability of repeating already generated tokens (`>1.0` suppresses repetition, `<1.0` encourages repetition, default None means disabled).
```
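Because the data-processor changes later in this commit append the tokenized `bad_words` to whatever `bad_words_token_ids` the request already carries, the two fields can be combined in a single request. A minimal sketch (the token id below is illustrative only):
```python
# Sketch: combining bad_words and bad_words_token_ids in one request.
# 25062 is an illustrative id; real values depend on the served tokenizer.
import openai

client = openai.Client(base_url="http://0.0.0.0:9222/v1", api_key="null")
response = client.chat.completions.create(
    model="null",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    extra_body={
        "bad_words": [" well"],          # tokenized server-side and merged in
        "bad_words_token_ids": [25062],  # used as the starting list of banned ids
    },
    stream=False,
)
print(response.choices[0].message.content)
```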


@@ -183,7 +183,7 @@ print('\n')
## Usage Instructions
Include the `bad_words` parameter, or alternatively the `bad_words_token_ids` parameter, in the request:
* Example of sending a user request with curl:
@@ -192,9 +192,22 @@ curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"messages": [ "messages": [
{"role": "user", "content": "How old are you"} {"role": "user", "content": "How are you"}
], ],
"bad_words": ["age", "I"] "bad_words": [" well", " Today"]
}'
```
等价于
```bash
curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "How are you"}
],
"bad_words_token_ids": [1622, 25062]
}' }'
``` ```
@@ -203,15 +216,37 @@ curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \
```python
import openai
host = "0.0.0.0"
port = "9222"
client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")
response = client.chat.completions.create(
    model="null",
    messages=[
        {"role": "user", "content": "Hello, how are you?"},
    ],
    extra_body={"bad_words": [" well", " Today"]},
    stream=True,
)
for chunk in response:
    if chunk.choices[0].delta:
        print(chunk.choices[0].delta.content, end='')
print('\n')
```
Equivalent to:
```python
import openai
host = "0.0.0.0"
port = "9222"
client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")
response = client.chat.completions.create(
    model="null",
    messages=[
        {"role": "user", "content": "Hello, how are you?"},
    ],
    extra_body={"bad_words_token_ids": [1622, 25062]},
    stream=True,
)
for chunk in response:
@@ -223,3 +258,4 @@ print('\n')
## Parameter Description
* `bad_words`: List of forbidden words. Type: list of str. Each element must be a single token.
* `bad_words_token_ids`: List of forbidden token ids. Type: list of int.


@@ -153,6 +153,9 @@ include_stop_str_in_output: Optional[bool] = False
bad_words: Optional[List[str]] = None
# List of forbidden words (e.g., sensitive words) that the model should avoid generating (default None means no restriction).
bad_words_token_ids: Optional[List[int]] = None
# List of forbidden token ids that the model should avoid generating (default None means no restriction).
repetition_penalty: Optional[float] = None
# Repetition penalty coefficient, reducing the probability of repeating already generated tokens (>1.0 suppresses repetition, <1.0 encourages repetition, default None means disabled).
```


@@ -461,7 +461,6 @@ class LLMEngine:
request = Request.from_dict(task)
llm_logger.info(f"Receive request {request}")
if sampling_params is not None:
sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
request.sampling_params = sampling_params
request.preprocess_start_time = time.time()
@@ -762,8 +761,6 @@ class LLMEngine:
for task in tasks:
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
if task.sampling_params.bad_words is not None:
task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
self.resource_manager.check_and_free_block_tables()


@@ -20,8 +20,6 @@ import random
from dataclasses import dataclass, fields
from typing import Any, List, Optional, Union
from fastdeploy.utils import llm_logger as logger
@dataclass
class SamplingParams:
@@ -102,7 +100,7 @@ class SamplingParams:
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
bad_words: Optional[List[str]] = None
bad_words_token_ids: Optional[List[int]] = None
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
@@ -134,6 +132,7 @@ class SamplingParams:
min_tokens=1,
logprobs=None,
bad_words=None,
bad_words_token_ids=None,
) -> SamplingParams:
"""Create instance from command line arguments"""
return cls(
@@ -154,6 +153,7 @@ class SamplingParams:
min_tokens=min_tokens,
logprobs=logprobs,
bad_words=bad_words,
bad_words_token_ids=bad_words_token_ids,
)
def __post_init__(self):
@@ -206,46 +206,6 @@ class SamplingParams:
if not 0 <= self.seed <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
def update_from_tokenizer(self, tokenizer):
"""Support bad words"""
if self.bad_words is None:
return
self._bad_words_token_ids = []
for bad_word in self.bad_words:
# To prohibit words both at the beginning
# and in the middle of text
# (related to add_prefix_space tokenizer parameter)
for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip()
prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)["input_ids"]
if len(prompt_token_ids) != 1:
if not add_prefix_space:
logger.warning(
f"Skip bad_words: <{prompt}>."
f"Bad words should be a single token."
f"Got tokens: {prompt_token_ids}."
)
continue
if prompt_token_ids[0] > tokenizer.vocab_size:
if not add_prefix_space:
logger.warning(
f"Skip bad_words: <{prompt}>."
f"All token id values should be satisfying:"
f" 0 <= token_id < {tokenizer.vocab_size}."
f"Got token: {prompt_token_ids}."
)
continue
if prompt_token_ids not in self._bad_words_token_ids:
self._bad_words_token_ids.extend(prompt_token_ids)
@property
def bad_words_token_ids(self) -> Optional[List[list[int]]]:
return self._bad_words_token_ids
@dataclass
class BeamSearchParams:


@@ -426,6 +426,7 @@ class CompletionRequest(BaseModel):
min_tokens: Optional[int] = None
include_stop_str_in_output: Optional[bool] = False
bad_words: Optional[List[str]] = None
bad_words_token_ids: Optional[List[int]] = None
# doc: end-completion-sampling-params
# doc: start-completion-extra-params
@@ -566,6 +567,7 @@ class ChatCompletionRequest(BaseModel):
min_tokens: Optional[int] = None
include_stop_str_in_output: Optional[bool] = False
bad_words: Optional[List[str]] = None
bad_words_token_ids: Optional[List[int]] = None
repetition_penalty: Optional[float] = None
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
# doc: end-chat-completion-sampling-params


@@ -97,6 +97,12 @@ class ErnieProcessor(BaseDataProcessor):
request.set("stop_token_ids", stop_seqs) request.set("stop_token_ids", stop_seqs)
request.set("stop_seqs_len", stop_seqs_len) request.set("stop_seqs_len", stop_seqs_len)
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")
if bad_words:
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is None and request.messages is None:
raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")
@@ -160,6 +166,13 @@ class ErnieProcessor(BaseDataProcessor):
request["stop_token_ids"] = stop_seqs request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len request["stop_seqs_len"] = stop_seqs_len
# processing bad_words
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")
if bad_words:
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
# processing prompt_token_ids
if not request.get("prompt_token_ids"):
if request.get("prompt") is None and request.get("messages") is None:
@@ -481,3 +494,43 @@ class ErnieProcessor(BaseDataProcessor):
def process_logprob_response(self, token_ids, **kwargs):
full_text = self.tokenizer.decode(token_ids, **kwargs)
return full_text
def update_bad_words(self, bad_words, bad_words_token_ids):
"""Support bad words"""
token_ids = bad_words_token_ids
if token_ids is None:
token_ids = []
for bad_word in bad_words:
# To prohibit words both at the beginning
# and in the middle of text
# (related to add_prefix_space tokenizer parameter)
for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip()
prompt_token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt))
data_processor_logger.debug(f"processed bad_words: {prompt}, {prompt_token_ids}")
if len(prompt_token_ids) != 1:
if not add_prefix_space:
data_processor_logger.warning(
f"Skip bad_words: <{prompt}>."
f"Bad words should be a single token."
f"Got tokens: {prompt_token_ids}."
)
continue
if prompt_token_ids[0] > self.tokenizer.vocab_size:
if not add_prefix_space:
data_processor_logger.warning(
f"Skip bad_words: <{prompt}>."
f"All token id values should be satisfying:"
f" 0 <= token_id < {self.tokenizer.vocab_size}."
f"Got token: {prompt_token_ids}."
)
continue
if prompt_token_ids[0] not in token_ids:
token_ids.extend(prompt_token_ids)
return token_ids
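The loop over `add_prefix_space` above exists because BPE-style tokenizers usually assign different ids to a word at the start of a text and to the same word preceded by a space, so both surface forms have to be collected. The following is a self-contained sketch of that filtering with a made-up vocabulary; it only mirrors the structure of `update_bad_words`, and real ids come from the model's tokenizer:
```python
# Stand-alone illustration of the single-token filtering in update_bad_words.
# The vocabulary below is invented purely for demonstration.
VOCAB = {"well": 11, " well": 12, " Today": 23}

def encode(text):
    """Stub for tokenizer.tokenize + convert_tokens_to_ids."""
    return [VOCAB[text]] if text in VOCAB else [1, 2]  # unknown words pretend to be multi-token

def collect_bad_token_ids(bad_words, token_ids=None):
    token_ids = list(token_ids or [])
    for word in bad_words:
        for prefix in ("", " "):                 # try the bare and the space-prefixed form
            ids = encode(prefix + word.lstrip())
            if len(ids) != 1:                    # multi-token words are skipped
                continue
            if ids[0] not in token_ids:          # avoid duplicate ids
                token_ids.append(ids[0])
    return token_ids

print(collect_bad_token_ids([" well", " Today"]))  # -> [11, 12, 23] with this stub vocabulary
```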


@@ -208,6 +208,12 @@ class ErnieMoEVLProcessor(ErnieProcessor):
request["stop_token_ids"] = stop_seqs request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len request["stop_seqs_len"] = stop_seqs_len
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")
if bad_words:
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
if request.get("prompt"): if request.get("prompt"):
multimodal_data = request.get("multimodal_data") multimodal_data = request.get("multimodal_data")
if multimodal_data is None: if multimodal_data is None:


@@ -212,6 +212,12 @@ class QwenVLProcessor(TextProcessor):
request["stop_token_ids"] = stop_seqs request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len request["stop_seqs_len"] = stop_seqs_len
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")
if bad_words:
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
if request.get("prompt"): if request.get("prompt"):
multimodal_data = request.get("multimodal_data") multimodal_data = request.get("multimodal_data")
if multimodal_data is None: if multimodal_data is None:


@@ -214,6 +214,12 @@ class DataProcessor(BaseDataProcessor):
request.set("stop_token_ids", stop_seqs) request.set("stop_token_ids", stop_seqs)
request.set("stop_seqs_len", stop_seqs_len) request.set("stop_seqs_len", stop_seqs_len)
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")
if bad_words:
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is not None:
request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
@@ -270,6 +276,13 @@ class DataProcessor(BaseDataProcessor):
request["stop_token_ids"] = stop_seqs request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len request["stop_seqs_len"] = stop_seqs_len
# processing bad_words
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")
if bad_words:
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
data_processor_logger.info(f"Processing request {request}")
# processing prompt_token_ids
if not request.get("prompt_token_ids"):
@@ -652,3 +665,42 @@ class DataProcessor(BaseDataProcessor):
stop_seqs, stop_seqs_len = self.pad_batch_data(stop_seqs, pad_id=-1, return_seq_len=True, return_array=False)
data_processor_logger.debug(f"processed stop_seqs: {stop_seqs}, {stop_seqs_len}")
return stop_seqs, stop_seqs_len
def update_bad_words(self, bad_words, bad_words_token_ids):
"""Support bad words"""
token_ids = bad_words_token_ids
if token_ids is None:
token_ids = []
for bad_word in bad_words:
# To prohibit words both at the beginning
# and in the middle of text
# (related to add_prefix_space tokenizer parameter)
for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip()
prompt_token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt))
if len(prompt_token_ids) != 1:
if not add_prefix_space:
data_processor_logger.warning(
f"Skip bad_words: <{prompt}>."
f"Bad words should be a single token."
f"Got tokens: {prompt_token_ids}."
)
continue
if prompt_token_ids[0] > self.tokenizer.vocab_size:
if not add_prefix_space:
data_processor_logger.warning(
f"Skip bad_words: <{prompt}>."
f"All token id values should be satisfying:"
f" 0 <= token_id < {self.tokenizer.vocab_size}."
f"Got token: {prompt_token_ids}."
)
continue
if prompt_token_ids[0] not in token_ids:
token_ids.extend(prompt_token_ids)
return token_ids


@@ -339,6 +339,16 @@ class GPUModelRunner(ModelRunnerBase):
if request.get("seed") is not None: if request.get("seed") is not None:
self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")
if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0:
bad_words_len = len(request.get("bad_words_token_ids"))
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
request.get("bad_words_token_ids"), dtype="int64"
)
else:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1
self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64")
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len")) stop_seqs_num = len(request.get("stop_seqs_len"))
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
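For context, `bad_tokens` is a fixed-width per-request buffer in `share_inputs`: the first `bad_tokens_len` entries hold the banned ids and the remaining slots are padded with -1. The actual suppression is done inside FastDeploy's sampling kernels; the numpy sketch below only illustrates, under that reading, how such a padded buffer translates into a logits mask:
```python
# Illustration only: applying a padded bad_tokens buffer to logits.
# FastDeploy performs the real masking inside its sampling ops.
import numpy as np

max_bad_tokens = 8
vocab_size = 32000

bad_words_token_ids = [1622, 25062]                         # per-request banned ids
bad_tokens = np.full((1, max_bad_tokens), -1, dtype="int64")
bad_tokens[0, : len(bad_words_token_ids)] = bad_words_token_ids
bad_tokens_len = len(bad_words_token_ids)

logits = np.random.randn(1, vocab_size).astype("float32")   # fake next-token logits
valid = bad_tokens[0, :bad_tokens_len]
logits[0, valid] = -np.inf                                   # banned ids can never be sampled

assert int(np.argmax(logits[0])) not in bad_words_token_ids
```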


@@ -322,6 +322,16 @@ class MetaxModelRunner(ModelRunnerBase):
if request.get("seed") is not None: if request.get("seed") is not None:
self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")
if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0:
bad_words_len = len(request.get("bad_words_token_ids"))
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
request.get("bad_words_token_ids"), dtype="int64"
)
else:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1
self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64")
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len")) stop_seqs_num = len(request.get("stop_seqs_len"))
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):


@@ -455,6 +455,16 @@ class XPUModelRunner(ModelRunnerBase):
if request.get("seed") is not None: if request.get("seed") is not None:
self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")
if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0:
bad_words_len = len(request.get("bad_words_token_ids"))
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
request.get("bad_words_token_ids"), dtype="int64"
)
else:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1
self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64")
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len")) stop_seqs_num = len(request.get("stop_seqs_len"))
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):


@@ -847,7 +847,24 @@ def test_non_streaming_chat_with_bad_words(openai_client, capsys):
assert hasattr(response_1.choices[0], "message") assert hasattr(response_1.choices[0], "message")
assert hasattr(response_1.choices[0].message, "completion_token_ids") assert hasattr(response_1.choices[0].message, "completion_token_ids")
assert isinstance(response_1.choices[0].message.completion_token_ids, list) assert isinstance(response_1.choices[0].message.completion_token_ids, list)
response_2 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
top_p=0.0,
max_tokens=20,
extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True},
stream=False,
)
assert hasattr(response_2, "choices")
assert len(response_2.choices) > 0
assert hasattr(response_2.choices[0], "message")
assert hasattr(response_2.choices[0].message, "completion_token_ids")
assert isinstance(response_2.choices[0].message.completion_token_ids, list)
assert not any(ids in response_1.choices[0].message.completion_token_ids for ids in bad_token_ids)
assert not any(ids in response_2.choices[0].message.completion_token_ids for ids in bad_token_ids)
def test_streaming_chat_with_bad_words(openai_client, capsys):
@@ -906,7 +923,34 @@ def test_streaming_chat_with_bad_words(openai_client, capsys):
assert isinstance(chunk.choices[0].delta.completion_token_ids, list)
output_tokens_1.append(chunk.choices[0].delta.content)
output_ids_1.extend(chunk.choices[0].delta.completion_token_ids)
response_2 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
top_p=0.0,
max_tokens=20,
extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True},
stream=True,
)
output_tokens_2 = []
output_ids_2 = []
is_first_chunk = True
for chunk in response_2:
assert hasattr(chunk, "choices")
assert len(chunk.choices) > 0
assert hasattr(chunk.choices[0], "delta")
assert hasattr(chunk.choices[0].delta, "content")
assert hasattr(chunk.choices[0].delta, "completion_token_ids")
if is_first_chunk:
is_first_chunk = False
else:
assert isinstance(chunk.choices[0].delta.completion_token_ids, list)
output_tokens_2.append(chunk.choices[0].delta.content)
output_ids_2.extend(chunk.choices[0].delta.completion_token_ids)
assert not any(ids in output_ids_1 for ids in bad_token_ids)
assert not any(ids in output_ids_2 for ids in bad_token_ids)
def test_non_streaming_completion_with_bad_words(openai_client, capsys):
@@ -956,9 +1000,25 @@ def test_non_streaming_completion_with_bad_words(openai_client, capsys):
)
assert hasattr(response_1, "choices")
assert len(response_1.choices) > 0
assert hasattr(response_1.choices[0], "completion_token_ids")
assert isinstance(response_1.choices[0].completion_token_ids, list)
response_2 = openai_client.completions.create(
model="default",
prompt="Hello, how are you?",
temperature=1,
top_p=0.0,
max_tokens=20,
extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True},
stream=False,
)
assert hasattr(response_2, "choices")
assert len(response_2.choices) > 0
assert hasattr(response_2.choices[0], "completion_token_ids")
assert isinstance(response_2.choices[0].completion_token_ids, list)
assert not any(ids in response_1.choices[0].completion_token_ids for ids in bad_token_ids)
assert not any(ids in response_2.choices[0].completion_token_ids for ids in bad_token_ids)
def test_streaming_completion_with_bad_words(openai_client, capsys):
@@ -1013,7 +1073,32 @@ def test_streaming_completion_with_bad_words(openai_client, capsys):
assert hasattr(chunk.choices[0], "completion_token_ids") assert hasattr(chunk.choices[0], "completion_token_ids")
output_tokens_1.append(chunk.choices[0].text) output_tokens_1.append(chunk.choices[0].text)
output_ids_1.extend(chunk.choices[0].completion_token_ids) output_ids_1.extend(chunk.choices[0].completion_token_ids)
# add bad words token ids
response_2 = openai_client.completions.create(
model="default",
prompt="Hello, how are you?",
temperature=1,
top_p=0.0,
max_tokens=20,
extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True},
stream=True,
)
output_tokens_2 = []
output_ids_2 = []
is_first_chunk = True
for chunk in response_2:
if is_first_chunk:
is_first_chunk = False
else:
assert hasattr(chunk, "choices")
assert len(chunk.choices) > 0
assert hasattr(chunk.choices[0], "text")
assert hasattr(chunk.choices[0], "completion_token_ids")
output_tokens_2.append(chunk.choices[0].text)
output_ids_2.extend(chunk.choices[0].completion_token_ids)
assert not any(ids in output_ids_1 for ids in bad_token_ids)
assert not any(ids in output_ids_2 for ids in bad_token_ids)
def test_profile_reset_block_num():


@@ -842,7 +842,24 @@ def test_non_streaming_chat_with_bad_words(openai_client, capsys):
assert hasattr(response_1.choices[0], "message") assert hasattr(response_1.choices[0], "message")
assert hasattr(response_1.choices[0].message, "completion_token_ids") assert hasattr(response_1.choices[0].message, "completion_token_ids")
assert isinstance(response_1.choices[0].message.completion_token_ids, list) assert isinstance(response_1.choices[0].message.completion_token_ids, list)
response_2 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
top_p=0.0,
max_tokens=20,
extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True},
stream=False,
)
assert hasattr(response_2, "choices")
assert len(response_2.choices) > 0
assert hasattr(response_2.choices[0], "message")
assert hasattr(response_2.choices[0].message, "completion_token_ids")
assert isinstance(response_2.choices[0].message.completion_token_ids, list)
assert not any(ids in response_1.choices[0].message.completion_token_ids for ids in bad_token_ids)
assert not any(ids in response_2.choices[0].message.completion_token_ids for ids in bad_token_ids)
def test_streaming_chat_with_bad_words(openai_client, capsys):
@@ -901,7 +918,34 @@ def test_streaming_chat_with_bad_words(openai_client, capsys):
assert isinstance(chunk.choices[0].delta.completion_token_ids, list)
output_tokens_1.append(chunk.choices[0].delta.content)
output_ids_1.extend(chunk.choices[0].delta.completion_token_ids)
response_2 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
top_p=0.0,
max_tokens=20,
extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True},
stream=True,
)
output_tokens_2 = []
output_ids_2 = []
is_first_chunk = True
for chunk in response_2:
assert hasattr(chunk, "choices")
assert len(chunk.choices) > 0
assert hasattr(chunk.choices[0], "delta")
assert hasattr(chunk.choices[0].delta, "content")
assert hasattr(chunk.choices[0].delta, "completion_token_ids")
if is_first_chunk:
is_first_chunk = False
else:
assert isinstance(chunk.choices[0].delta.completion_token_ids, list)
output_tokens_2.append(chunk.choices[0].delta.content)
output_ids_2.extend(chunk.choices[0].delta.completion_token_ids)
assert not any(ids in output_ids_1 for ids in bad_token_ids)
assert not any(ids in output_ids_2 for ids in bad_token_ids)
def test_non_streaming_completion_with_bad_words(openai_client, capsys):
@@ -951,9 +995,25 @@ def test_non_streaming_completion_with_bad_words(openai_client, capsys):
)
assert hasattr(response_1, "choices")
assert len(response_1.choices) > 0
assert hasattr(response_1.choices[0], "completion_token_ids")
assert isinstance(response_1.choices[0].completion_token_ids, list)
response_2 = openai_client.completions.create(
model="default",
prompt="Hello, how are you?",
temperature=1,
top_p=0.0,
max_tokens=20,
extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True},
stream=False,
)
assert hasattr(response_2, "choices")
assert len(response_2.choices) > 0
assert hasattr(response_2.choices[0], "completion_token_ids")
assert isinstance(response_2.choices[0].completion_token_ids, list)
assert not any(ids in response_1.choices[0].completion_token_ids for ids in bad_token_ids)
assert not any(ids in response_2.choices[0].completion_token_ids for ids in bad_token_ids)
def test_streaming_completion_with_bad_words(openai_client, capsys):
@@ -1008,7 +1068,32 @@ def test_streaming_completion_with_bad_words(openai_client, capsys):
assert hasattr(chunk.choices[0], "completion_token_ids") assert hasattr(chunk.choices[0], "completion_token_ids")
output_tokens_1.append(chunk.choices[0].text) output_tokens_1.append(chunk.choices[0].text)
output_ids_1.extend(chunk.choices[0].completion_token_ids) output_ids_1.extend(chunk.choices[0].completion_token_ids)
# add bad words token ids
response_2 = openai_client.completions.create(
model="default",
prompt="Hello, how are you?",
temperature=1,
top_p=0.0,
max_tokens=20,
extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True},
stream=True,
)
output_tokens_2 = []
output_ids_2 = []
is_first_chunk = True
for chunk in response_2:
if is_first_chunk:
is_first_chunk = False
else:
assert hasattr(chunk, "choices")
assert len(chunk.choices) > 0
assert hasattr(chunk.choices[0], "text")
assert hasattr(chunk.choices[0], "completion_token_ids")
output_tokens_2.append(chunk.choices[0].text)
output_ids_2.extend(chunk.choices[0].completion_token_ids)
assert not any(ids in output_ids_1 for ids in bad_token_ids)
assert not any(ids in output_ids_2 for ids in bad_token_ids)
def test_profile_reset_block_num():