[Feature] support bad_words (#3055)

* support bad_words

* support online infer bad_words

* update

* add CI test

* update

* update

* update

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
Sunny-bot1
2025-07-30 09:31:29 +08:00
committed by GitHub
parent 9c962343f2
commit 74aa31d15b
10 changed files with 263 additions and 15 deletions

View File

@@ -180,7 +180,7 @@ void token_penalty_multi_scores_kernel(
int64_t token_num = shape[0]; int64_t token_num = shape[0];
int64_t length = shape[1]; int64_t length = shape[1];
int64_t length_id = pre_ids.shape()[1]; int64_t length_id = pre_ids.shape()[1];
int64_t length_bad_words = bad_tokens.shape()[0]; int64_t length_bad_words = bad_tokens.shape()[1];
int64_t end_length = eos_token_id.shape()[0]; int64_t end_length = eos_token_id.shape()[0];

View File

@@ -171,7 +171,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
int64_t vocab_size = shape[1]; int64_t vocab_size = shape[1];
int64_t max_dec_len = pre_ids.shape()[1]; int64_t max_dec_len = pre_ids.shape()[1];
int64_t bad_words_len = bad_tokens.shape()[0]; int64_t bad_words_len = bad_tokens.shape()[1];
int64_t eos_len = eos_token_id.shape()[0]; int64_t eos_len = eos_token_id.shape()[0];
int64_t max_model_len = prompt_ids.shape()[1]; int64_t max_model_len = prompt_ids.shape()[1];

View File

@@ -491,6 +491,7 @@ class LLMEngine:
request = Request.from_dict(task) request = Request.from_dict(task)
llm_logger.info(f"Receive request {request}") llm_logger.info(f"Receive request {request}")
if sampling_params is not None: if sampling_params is not None:
sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
request.sampling_params = sampling_params request.sampling_params = sampling_params
request.preprocess_start_time = time.time() request.preprocess_start_time = time.time()
@@ -747,6 +748,8 @@ class LLMEngine:
""" """
for task in tasks: for task in tasks:
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER) start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
if task.sampling_params.bad_words is not None:
task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
# TODO 返回至 scheduler # TODO 返回至 scheduler
if allocated: if allocated:
current_tasks = [] current_tasks = []

View File

@@ -20,6 +20,8 @@ import random
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import Any, List, Optional, Union from typing import Any, List, Optional, Union
from fastdeploy.utils import llm_logger as logger
@dataclass @dataclass
class SamplingParams: class SamplingParams:
@@ -97,6 +99,7 @@ class SamplingParams:
min_tokens: int = 1 min_tokens: int = 1
logprobs: Optional[int] = None logprobs: Optional[int] = None
bad_words: Optional[List[str]] = None bad_words: Optional[List[str]] = None
_bad_words_token_ids: Optional[List[int]] = None
@classmethod @classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams: def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
@@ -201,11 +204,42 @@ class SamplingParams:
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.") raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
def update_from_tokenizer(self, tokenizer): def update_from_tokenizer(self, tokenizer):
""" """Support bad words"""
# TODO: Implement stop tokens and bad words support if self.bad_words is None:
# Currently stop tokens and bad words are not supported yet return
""" self._bad_words_token_ids = []
pass for bad_word in self.bad_words:
# To prohibit words both at the beginning
# and in the middle of text
# (related to add_prefix_space tokenizer parameter)
for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip()
prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)["input_ids"]
if len(prompt_token_ids) != 1:
logger.warning(
f"Skip bad_words: {prompt}."
f"Bad words should be a single token."
f"Got tokens: {prompt_token_ids}."
)
continue
if prompt_token_ids[0] > tokenizer.vocab_size:
logger.warning(
f"Skip bad_words: {prompt}."
f"All token id values should be satisfying:"
f" 0 <= token_id < {tokenizer.vocab_size}."
f"Got token: {prompt_token_ids}."
)
continue
if prompt_token_ids not in self._bad_words_token_ids:
self._bad_words_token_ids.extend(prompt_token_ids)
@property
def bad_words_token_ids(self) -> Optional[List[list[int]]]:
return self._bad_words_token_ids
@dataclass @dataclass

View File

@@ -349,6 +349,7 @@ class CompletionRequest(BaseModel):
extra_body: Optional[dict] = None extra_body: Optional[dict] = None
return_token_ids: Optional[bool] = False return_token_ids: Optional[bool] = False
prompt_token_ids: Optional[List[int]] = None prompt_token_ids: Optional[List[int]] = None
bad_words: Optional[List[str]] = None
response_format: Optional[AnyResponseFormat] = None response_format: Optional[AnyResponseFormat] = None
guided_json: Optional[Union[str, dict, BaseModel]] = None guided_json: Optional[Union[str, dict, BaseModel]] = None
@@ -484,6 +485,7 @@ class ChatCompletionRequest(BaseModel):
return_token_ids: Optional[bool] = False return_token_ids: Optional[bool] = False
prompt_token_ids: Optional[List[int]] = None prompt_token_ids: Optional[List[int]] = None
disable_chat_template: Optional[bool] = False disable_chat_template: Optional[bool] = False
bad_words: Optional[List[str]] = None
response_format: Optional[AnyResponseFormat] = None response_format: Optional[AnyResponseFormat] = None
guided_json: Optional[Union[str, dict, BaseModel]] = None guided_json: Optional[Union[str, dict, BaseModel]] = None

View File

@@ -270,6 +270,14 @@ class GCUModelRunner(ModelRunnerBase):
request.block_tables, dtype="int32" request.block_tables, dtype="int32"
) )
if request.get("bad_words_token_ids") is not None:
bad_words_len = len(request.get("bad_words_token_ids"))
if bad_words_len > 0:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
request.get("bad_words_token_ids"), dtype="int64"
)
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len")) stop_seqs_num = len(request.get("stop_seqs_len"))
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
@@ -382,7 +390,8 @@ class GCUModelRunner(ModelRunnerBase):
self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool") self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64") self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64") self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64") self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool") self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32") self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
@@ -511,6 +520,9 @@ class GCUModelRunner(ModelRunnerBase):
self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False) self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False) self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False)
# Update bad tokens len
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
# Initialize forward meta data # Initialize forward meta data
self.initialize_forward_meta() self.initialize_forward_meta()
@@ -528,7 +540,7 @@ class GCUModelRunner(ModelRunnerBase):
presence_penalties=self.share_inputs["presence_score"], presence_penalties=self.share_inputs["presence_score"],
repetition_penalties=self.share_inputs["penalty_score"], repetition_penalties=self.share_inputs["penalty_score"],
min_dec_lens=self.share_inputs["min_dec_len"], min_dec_lens=self.share_inputs["min_dec_len"],
bad_words_token_ids=self.share_inputs["bad_tokens"], bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
eos_token_ids=self.share_inputs["eos_token_id"], eos_token_ids=self.share_inputs["eos_token_id"],
max_num_logprobs=20 if self.enable_logprob else None, max_num_logprobs=20 if self.enable_logprob else None,
) )

View File

@@ -448,6 +448,14 @@ class GPUModelRunner(ModelRunnerBase):
request.block_tables, dtype="int32" request.block_tables, dtype="int32"
) )
if request.get("bad_words_token_ids") is not None:
bad_words_len = len(request.get("bad_words_token_ids"))
if bad_words_len > 0:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
request.get("bad_words_token_ids"), dtype="int64"
)
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len")) stop_seqs_num = len(request.get("stop_seqs_len"))
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
@@ -567,7 +575,8 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool") self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64") self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64") self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64") self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool") self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32") self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
@@ -733,6 +742,9 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False) self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False) self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False)
# Update bad tokens len
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
# Initialize forward meta data # Initialize forward meta data
self.initialize_forward_meta() self.initialize_forward_meta()
@@ -750,7 +762,7 @@ class GPUModelRunner(ModelRunnerBase):
presence_penalties=self.share_inputs["presence_score"], presence_penalties=self.share_inputs["presence_score"],
repetition_penalties=self.share_inputs["penalty_score"], repetition_penalties=self.share_inputs["penalty_score"],
min_dec_lens=self.share_inputs["min_dec_len"], min_dec_lens=self.share_inputs["min_dec_len"],
bad_words_token_ids=self.share_inputs["bad_tokens"], bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
eos_token_ids=self.share_inputs["eos_token_id"], eos_token_ids=self.share_inputs["eos_token_id"],
max_num_logprobs=20 if self.enable_logprob else None, max_num_logprobs=20 if self.enable_logprob else None,
enable_early_stop=self.enable_early_stop, enable_early_stop=self.enable_early_stop,

View File

@@ -242,6 +242,14 @@ class IluvatarModelRunner(ModelRunnerBase):
request.block_tables, dtype="int32" request.block_tables, dtype="int32"
) )
if request.get("bad_words_token_ids") is not None:
bad_words_len = len(request.get("bad_words_token_ids"))
if bad_words_len > 0:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
request.get("bad_words_token_ids"), dtype="int64"
)
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len")) stop_seqs_num = len(request.get("stop_seqs_len"))
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
@@ -347,7 +355,8 @@ class IluvatarModelRunner(ModelRunnerBase):
self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool") self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64") self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64") self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64") self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool") self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32") self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
@@ -484,6 +493,9 @@ class IluvatarModelRunner(ModelRunnerBase):
self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False) self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False) self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False)
# Update bad tokens len
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
# Initialize forward meta data # Initialize forward meta data
self.initialize_forward_meta() self.initialize_forward_meta()
@@ -500,7 +512,7 @@ class IluvatarModelRunner(ModelRunnerBase):
presence_penalties=self.share_inputs["presence_score"], presence_penalties=self.share_inputs["presence_score"],
repetition_penalties=self.share_inputs["penalty_score"], repetition_penalties=self.share_inputs["penalty_score"],
min_dec_lens=self.share_inputs["min_dec_len"], min_dec_lens=self.share_inputs["min_dec_len"],
bad_words_token_ids=self.share_inputs["bad_tokens"], bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
eos_token_ids=self.share_inputs["eos_token_id"], eos_token_ids=self.share_inputs["eos_token_id"],
) )

View File

@@ -506,6 +506,14 @@ class XPUModelRunner(ModelRunnerBase):
request.block_tables, dtype="int32" request.block_tables, dtype="int32"
) )
if request.get("bad_words_token_ids") is not None:
bad_words_len = len(request.get("bad_words_token_ids"))
if bad_words_len > 0:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
request.get("bad_words_token_ids"), dtype="int64"
)
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len")) stop_seqs_num = len(request.get("stop_seqs_len"))
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
@@ -574,7 +582,8 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool") self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64") self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64") self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64") self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool") self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32") self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
@@ -652,6 +661,9 @@ class XPUModelRunner(ModelRunnerBase):
seq_lens_encoder=self.share_inputs["seq_lens_encoder"], seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"], seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
) )
# Update bad tokens len
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
self.forward_meta.attn_backend = self.attn_backends[0] self.forward_meta.attn_backend = self.attn_backends[0]
self.initialize_attention_backend() self.initialize_attention_backend()
@@ -667,7 +679,7 @@ class XPUModelRunner(ModelRunnerBase):
presence_penalties=self.share_inputs["presence_score"], presence_penalties=self.share_inputs["presence_score"],
repetition_penalties=self.share_inputs["penalty_score"], repetition_penalties=self.share_inputs["penalty_score"],
min_dec_lens=self.share_inputs["min_dec_len"], min_dec_lens=self.share_inputs["min_dec_len"],
bad_words_token_ids=self.share_inputs["bad_tokens"], bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
eos_token_ids=self.share_inputs["eos_token_id"], eos_token_ids=self.share_inputs["eos_token_id"],
) )

View File

@@ -718,3 +718,164 @@ def test_non_streaming_min_max_token_equals_one(openai_client, capsys):
# Verify usage shows exactly 1 completion token # Verify usage shows exactly 1 completion token
assert hasattr(response, "usage") assert hasattr(response, "usage")
assert response.usage.completion_tokens == 1 assert response.usage.completion_tokens == 1
def test_non_streaming_chat_with_bad_words(openai_client, capsys):
"""
Test bad_words option in non-streaming chat functionality with the local service
"""
response_0 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
top_p=0.0,
max_tokens=10,
stream=False,
)
output_0 = []
assert hasattr(response_0, "choices")
assert len(response_0.choices) > 0
assert hasattr(response_0.choices[0], "message")
assert hasattr(response_0.choices[0].message, "content")
text_split = response_0.choices[0].message.content.split(" ")
for text in text_split:
output_0.append(text)
# add bad words
response_1 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
top_p=0.0,
max_tokens=10,
extra_body={"bad_words": output_0[-5:]},
stream=False,
)
output_1 = []
assert hasattr(response_1, "choices")
assert len(response_1.choices) > 0
assert hasattr(response_1.choices[0], "message")
assert hasattr(response_1.choices[0].message, "content")
text_split = response_1.choices[0].message.content.split(" ")
for text in text_split:
output_1.append(text)
assert output_0 not in output_1
def test_streaming_chat_with_bad_words(openai_client, capsys):
"""
Test bad_words option in streaming chat functionality with the local service
"""
response_0 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
top_p=0.0,
max_tokens=10,
stream=True,
)
output_0 = []
for chunk in response_0:
assert hasattr(chunk, "choices")
assert len(chunk.choices) > 0
assert hasattr(chunk.choices[0], "delta")
assert hasattr(chunk.choices[0].delta, "content")
output_0.append(chunk.choices[0].delta.content)
# add bad words
response_1 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
top_p=0.0,
max_tokens=10,
extra_body={"bad_words": output_0[-5:]},
stream=True,
)
output_1 = []
for chunk in response_1:
assert hasattr(chunk, "choices")
assert len(chunk.choices) > 0
assert hasattr(chunk.choices[0], "delta")
assert hasattr(chunk.choices[0].delta, "content")
output_1.append(chunk.choices[0].delta.content)
assert output_0 not in output_1
def test_non_streaming_completion_with_bad_words(openai_client, capsys):
"""
Test bad_words option in non-streaming completion functionality with the local service
"""
response_0 = openai_client.completions.create(
model="default",
prompt="Hello, how are you?",
temperature=1,
top_p=0.0,
max_tokens=10,
stream=False,
)
output_0 = []
assert hasattr(response_0, "choices")
assert len(response_0.choices) > 0
assert hasattr(response_0.choices[0], "text")
text_split = response_0.choices[0].text.split(" ")
for text in text_split:
output_0.append(text)
# add bad words
response_1 = openai_client.completions.create(
model="default",
prompt="Hello, how are you?",
temperature=1,
top_p=0.0,
max_tokens=10,
extra_body={"bad_words": output_0[-5:]},
stream=False,
)
output_1 = []
assert hasattr(response_1, "choices")
assert len(response_1.choices) > 0
assert hasattr(response_1.choices[0], "text")
text_split = response_1.choices[0].text.split(" ")
for text in text_split:
output_1.append(text)
assert output_0 not in output_1
def test_streaming_completion_with_bad_words(openai_client, capsys):
"""
Test bad_words option in streaming completion functionality with the local service
"""
response_0 = openai_client.completions.create(
model="default",
prompt="Hello, how are you?",
temperature=1,
top_p=0.0,
max_tokens=10,
stream=True,
)
output_0 = []
for chunk in response_0:
assert hasattr(chunk, "choices")
assert len(chunk.choices) > 0
assert hasattr(chunk.choices[0], "text")
output_0.append(chunk.choices[0].text)
# add bad words
response_1 = openai_client.completions.create(
model="default",
prompt="Hello, how are you?",
temperature=1,
top_p=0.0,
max_tokens=10,
extra_body={"bad_words": output_0[-5:]},
stream=True,
)
output_1 = []
for chunk in response_1:
assert hasattr(chunk, "choices")
assert len(chunk.choices) > 0
assert hasattr(chunk.choices[0], "text")
output_1.append(chunk.choices[0].text)
assert output_0 not in output_1