mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Feature] support bad_words (#3055)
* support bad_words * support online infer bad_words * update * add CI test * update * update * update --------- Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
@@ -180,7 +180,7 @@ void token_penalty_multi_scores_kernel(
|
||||
int64_t token_num = shape[0];
|
||||
int64_t length = shape[1];
|
||||
int64_t length_id = pre_ids.shape()[1];
|
||||
int64_t length_bad_words = bad_tokens.shape()[0];
|
||||
int64_t length_bad_words = bad_tokens.shape()[1];
|
||||
|
||||
int64_t end_length = eos_token_id.shape()[0];
|
||||
|
||||
|
@@ -171,7 +171,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
|
||||
|
||||
int64_t vocab_size = shape[1];
|
||||
int64_t max_dec_len = pre_ids.shape()[1];
|
||||
int64_t bad_words_len = bad_tokens.shape()[0];
|
||||
int64_t bad_words_len = bad_tokens.shape()[1];
|
||||
int64_t eos_len = eos_token_id.shape()[0];
|
||||
int64_t max_model_len = prompt_ids.shape()[1];
|
||||
|
||||
|
@@ -491,6 +491,7 @@ class LLMEngine:
|
||||
request = Request.from_dict(task)
|
||||
llm_logger.info(f"Receive request {request}")
|
||||
if sampling_params is not None:
|
||||
sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
|
||||
request.sampling_params = sampling_params
|
||||
request.preprocess_start_time = time.time()
|
||||
|
||||
@@ -747,6 +748,8 @@ class LLMEngine:
|
||||
"""
|
||||
for task in tasks:
|
||||
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
|
||||
if task.sampling_params.bad_words is not None:
|
||||
task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
|
||||
# TODO 返回至 scheduler
|
||||
if allocated:
|
||||
current_tasks = []
|
||||
|
@@ -20,6 +20,8 @@ import random
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from fastdeploy.utils import llm_logger as logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingParams:
|
||||
@@ -97,6 +99,7 @@ class SamplingParams:
|
||||
min_tokens: int = 1
|
||||
logprobs: Optional[int] = None
|
||||
bad_words: Optional[List[str]] = None
|
||||
_bad_words_token_ids: Optional[List[int]] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
|
||||
@@ -201,11 +204,42 @@ class SamplingParams:
|
||||
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
|
||||
|
||||
def update_from_tokenizer(self, tokenizer):
|
||||
"""
|
||||
# TODO: Implement stop tokens and bad words support
|
||||
# Currently stop tokens and bad words are not supported yet
|
||||
"""
|
||||
pass
|
||||
"""Support bad words"""
|
||||
if self.bad_words is None:
|
||||
return
|
||||
self._bad_words_token_ids = []
|
||||
for bad_word in self.bad_words:
|
||||
# To prohibit words both at the beginning
|
||||
# and in the middle of text
|
||||
# (related to add_prefix_space tokenizer parameter)
|
||||
for add_prefix_space in [False, True]:
|
||||
prefix = " " if add_prefix_space else ""
|
||||
prompt = prefix + bad_word.lstrip()
|
||||
prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)["input_ids"]
|
||||
|
||||
if len(prompt_token_ids) != 1:
|
||||
logger.warning(
|
||||
f"Skip bad_words: {prompt}."
|
||||
f"Bad words should be a single token."
|
||||
f"Got tokens: {prompt_token_ids}."
|
||||
)
|
||||
continue
|
||||
|
||||
if prompt_token_ids[0] > tokenizer.vocab_size:
|
||||
logger.warning(
|
||||
f"Skip bad_words: {prompt}."
|
||||
f"All token id values should be satisfying:"
|
||||
f" 0 <= token_id < {tokenizer.vocab_size}."
|
||||
f"Got token: {prompt_token_ids}."
|
||||
)
|
||||
continue
|
||||
|
||||
if prompt_token_ids not in self._bad_words_token_ids:
|
||||
self._bad_words_token_ids.extend(prompt_token_ids)
|
||||
|
||||
@property
|
||||
def bad_words_token_ids(self) -> Optional[List[list[int]]]:
|
||||
return self._bad_words_token_ids
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@@ -349,6 +349,7 @@ class CompletionRequest(BaseModel):
|
||||
extra_body: Optional[dict] = None
|
||||
return_token_ids: Optional[bool] = False
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
bad_words: Optional[List[str]] = None
|
||||
|
||||
response_format: Optional[AnyResponseFormat] = None
|
||||
guided_json: Optional[Union[str, dict, BaseModel]] = None
|
||||
@@ -484,6 +485,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
return_token_ids: Optional[bool] = False
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
disable_chat_template: Optional[bool] = False
|
||||
bad_words: Optional[List[str]] = None
|
||||
|
||||
response_format: Optional[AnyResponseFormat] = None
|
||||
guided_json: Optional[Union[str, dict, BaseModel]] = None
|
||||
|
@@ -270,6 +270,14 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
request.block_tables, dtype="int32"
|
||||
)
|
||||
|
||||
if request.get("bad_words_token_ids") is not None:
|
||||
bad_words_len = len(request.get("bad_words_token_ids"))
|
||||
if bad_words_len > 0:
|
||||
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
|
||||
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
|
||||
request.get("bad_words_token_ids"), dtype="int64"
|
||||
)
|
||||
|
||||
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
|
||||
stop_seqs_num = len(request.get("stop_seqs_len"))
|
||||
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
|
||||
@@ -382,7 +390,8 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
|
||||
self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
|
||||
|
||||
self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64")
|
||||
self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
|
||||
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
|
||||
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
|
||||
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
|
||||
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
@@ -511,6 +520,9 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
|
||||
self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False)
|
||||
|
||||
# Update bad tokens len
|
||||
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
|
||||
|
||||
# Initialize forward meta data
|
||||
self.initialize_forward_meta()
|
||||
|
||||
@@ -528,7 +540,7 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
presence_penalties=self.share_inputs["presence_score"],
|
||||
repetition_penalties=self.share_inputs["penalty_score"],
|
||||
min_dec_lens=self.share_inputs["min_dec_len"],
|
||||
bad_words_token_ids=self.share_inputs["bad_tokens"],
|
||||
bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
|
||||
eos_token_ids=self.share_inputs["eos_token_id"],
|
||||
max_num_logprobs=20 if self.enable_logprob else None,
|
||||
)
|
||||
|
@@ -448,6 +448,14 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
request.block_tables, dtype="int32"
|
||||
)
|
||||
|
||||
if request.get("bad_words_token_ids") is not None:
|
||||
bad_words_len = len(request.get("bad_words_token_ids"))
|
||||
if bad_words_len > 0:
|
||||
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
|
||||
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
|
||||
request.get("bad_words_token_ids"), dtype="int64"
|
||||
)
|
||||
|
||||
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
|
||||
stop_seqs_num = len(request.get("stop_seqs_len"))
|
||||
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
|
||||
@@ -567,7 +575,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
|
||||
self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
|
||||
|
||||
self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64")
|
||||
self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
|
||||
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
|
||||
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
|
||||
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
|
||||
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
@@ -733,6 +742,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
|
||||
self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False)
|
||||
|
||||
# Update bad tokens len
|
||||
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
|
||||
|
||||
# Initialize forward meta data
|
||||
self.initialize_forward_meta()
|
||||
|
||||
@@ -750,7 +762,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
presence_penalties=self.share_inputs["presence_score"],
|
||||
repetition_penalties=self.share_inputs["penalty_score"],
|
||||
min_dec_lens=self.share_inputs["min_dec_len"],
|
||||
bad_words_token_ids=self.share_inputs["bad_tokens"],
|
||||
bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
|
||||
eos_token_ids=self.share_inputs["eos_token_id"],
|
||||
max_num_logprobs=20 if self.enable_logprob else None,
|
||||
enable_early_stop=self.enable_early_stop,
|
||||
|
@@ -242,6 +242,14 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
request.block_tables, dtype="int32"
|
||||
)
|
||||
|
||||
if request.get("bad_words_token_ids") is not None:
|
||||
bad_words_len = len(request.get("bad_words_token_ids"))
|
||||
if bad_words_len > 0:
|
||||
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
|
||||
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
|
||||
request.get("bad_words_token_ids"), dtype="int64"
|
||||
)
|
||||
|
||||
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
|
||||
stop_seqs_num = len(request.get("stop_seqs_len"))
|
||||
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
|
||||
@@ -347,7 +355,8 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
|
||||
self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
|
||||
|
||||
self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64")
|
||||
self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
|
||||
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
|
||||
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
|
||||
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
|
||||
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
@@ -484,6 +493,9 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
|
||||
self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False)
|
||||
|
||||
# Update bad tokens len
|
||||
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
|
||||
|
||||
# Initialize forward meta data
|
||||
self.initialize_forward_meta()
|
||||
|
||||
@@ -500,7 +512,7 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
presence_penalties=self.share_inputs["presence_score"],
|
||||
repetition_penalties=self.share_inputs["penalty_score"],
|
||||
min_dec_lens=self.share_inputs["min_dec_len"],
|
||||
bad_words_token_ids=self.share_inputs["bad_tokens"],
|
||||
bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
|
||||
eos_token_ids=self.share_inputs["eos_token_id"],
|
||||
)
|
||||
|
||||
|
@@ -506,6 +506,14 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
request.block_tables, dtype="int32"
|
||||
)
|
||||
|
||||
if request.get("bad_words_token_ids") is not None:
|
||||
bad_words_len = len(request.get("bad_words_token_ids"))
|
||||
if bad_words_len > 0:
|
||||
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
|
||||
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
|
||||
request.get("bad_words_token_ids"), dtype="int64"
|
||||
)
|
||||
|
||||
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
|
||||
stop_seqs_num = len(request.get("stop_seqs_len"))
|
||||
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
|
||||
@@ -574,7 +582,8 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
|
||||
self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
|
||||
|
||||
self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64")
|
||||
self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
|
||||
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
|
||||
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
|
||||
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
|
||||
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
@@ -652,6 +661,9 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
|
||||
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
|
||||
)
|
||||
# Update bad tokens len
|
||||
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
|
||||
|
||||
self.forward_meta.attn_backend = self.attn_backends[0]
|
||||
self.initialize_attention_backend()
|
||||
|
||||
@@ -667,7 +679,7 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
presence_penalties=self.share_inputs["presence_score"],
|
||||
repetition_penalties=self.share_inputs["penalty_score"],
|
||||
min_dec_lens=self.share_inputs["min_dec_len"],
|
||||
bad_words_token_ids=self.share_inputs["bad_tokens"],
|
||||
bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
|
||||
eos_token_ids=self.share_inputs["eos_token_id"],
|
||||
)
|
||||
|
||||
|
@@ -718,3 +718,164 @@ def test_non_streaming_min_max_token_equals_one(openai_client, capsys):
|
||||
# Verify usage shows exactly 1 completion token
|
||||
assert hasattr(response, "usage")
|
||||
assert response.usage.completion_tokens == 1
|
||||
|
||||
|
||||
def test_non_streaming_chat_with_bad_words(openai_client, capsys):
|
||||
"""
|
||||
Test bad_words option in non-streaming chat functionality with the local service
|
||||
"""
|
||||
response_0 = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=1,
|
||||
top_p=0.0,
|
||||
max_tokens=10,
|
||||
stream=False,
|
||||
)
|
||||
output_0 = []
|
||||
assert hasattr(response_0, "choices")
|
||||
assert len(response_0.choices) > 0
|
||||
assert hasattr(response_0.choices[0], "message")
|
||||
assert hasattr(response_0.choices[0].message, "content")
|
||||
|
||||
text_split = response_0.choices[0].message.content.split(" ")
|
||||
for text in text_split:
|
||||
output_0.append(text)
|
||||
|
||||
# add bad words
|
||||
response_1 = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=1,
|
||||
top_p=0.0,
|
||||
max_tokens=10,
|
||||
extra_body={"bad_words": output_0[-5:]},
|
||||
stream=False,
|
||||
)
|
||||
output_1 = []
|
||||
assert hasattr(response_1, "choices")
|
||||
assert len(response_1.choices) > 0
|
||||
assert hasattr(response_1.choices[0], "message")
|
||||
assert hasattr(response_1.choices[0].message, "content")
|
||||
text_split = response_1.choices[0].message.content.split(" ")
|
||||
for text in text_split:
|
||||
output_1.append(text)
|
||||
assert output_0 not in output_1
|
||||
|
||||
|
||||
def test_streaming_chat_with_bad_words(openai_client, capsys):
|
||||
"""
|
||||
Test bad_words option in streaming chat functionality with the local service
|
||||
"""
|
||||
response_0 = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=1,
|
||||
top_p=0.0,
|
||||
max_tokens=10,
|
||||
stream=True,
|
||||
)
|
||||
output_0 = []
|
||||
for chunk in response_0:
|
||||
assert hasattr(chunk, "choices")
|
||||
assert len(chunk.choices) > 0
|
||||
assert hasattr(chunk.choices[0], "delta")
|
||||
assert hasattr(chunk.choices[0].delta, "content")
|
||||
output_0.append(chunk.choices[0].delta.content)
|
||||
|
||||
# add bad words
|
||||
response_1 = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=1,
|
||||
top_p=0.0,
|
||||
max_tokens=10,
|
||||
extra_body={"bad_words": output_0[-5:]},
|
||||
stream=True,
|
||||
)
|
||||
output_1 = []
|
||||
for chunk in response_1:
|
||||
assert hasattr(chunk, "choices")
|
||||
assert len(chunk.choices) > 0
|
||||
assert hasattr(chunk.choices[0], "delta")
|
||||
assert hasattr(chunk.choices[0].delta, "content")
|
||||
output_1.append(chunk.choices[0].delta.content)
|
||||
assert output_0 not in output_1
|
||||
|
||||
|
||||
def test_non_streaming_completion_with_bad_words(openai_client, capsys):
|
||||
"""
|
||||
Test bad_words option in non-streaming completion functionality with the local service
|
||||
"""
|
||||
response_0 = openai_client.completions.create(
|
||||
model="default",
|
||||
prompt="Hello, how are you?",
|
||||
temperature=1,
|
||||
top_p=0.0,
|
||||
max_tokens=10,
|
||||
stream=False,
|
||||
)
|
||||
output_0 = []
|
||||
assert hasattr(response_0, "choices")
|
||||
assert len(response_0.choices) > 0
|
||||
assert hasattr(response_0.choices[0], "text")
|
||||
text_split = response_0.choices[0].text.split(" ")
|
||||
for text in text_split:
|
||||
output_0.append(text)
|
||||
|
||||
# add bad words
|
||||
response_1 = openai_client.completions.create(
|
||||
model="default",
|
||||
prompt="Hello, how are you?",
|
||||
temperature=1,
|
||||
top_p=0.0,
|
||||
max_tokens=10,
|
||||
extra_body={"bad_words": output_0[-5:]},
|
||||
stream=False,
|
||||
)
|
||||
output_1 = []
|
||||
assert hasattr(response_1, "choices")
|
||||
assert len(response_1.choices) > 0
|
||||
assert hasattr(response_1.choices[0], "text")
|
||||
text_split = response_1.choices[0].text.split(" ")
|
||||
for text in text_split:
|
||||
output_1.append(text)
|
||||
assert output_0 not in output_1
|
||||
|
||||
|
||||
def test_streaming_completion_with_bad_words(openai_client, capsys):
|
||||
"""
|
||||
Test bad_words option in streaming completion functionality with the local service
|
||||
"""
|
||||
response_0 = openai_client.completions.create(
|
||||
model="default",
|
||||
prompt="Hello, how are you?",
|
||||
temperature=1,
|
||||
top_p=0.0,
|
||||
max_tokens=10,
|
||||
stream=True,
|
||||
)
|
||||
output_0 = []
|
||||
for chunk in response_0:
|
||||
assert hasattr(chunk, "choices")
|
||||
assert len(chunk.choices) > 0
|
||||
assert hasattr(chunk.choices[0], "text")
|
||||
output_0.append(chunk.choices[0].text)
|
||||
|
||||
# add bad words
|
||||
response_1 = openai_client.completions.create(
|
||||
model="default",
|
||||
prompt="Hello, how are you?",
|
||||
temperature=1,
|
||||
top_p=0.0,
|
||||
max_tokens=10,
|
||||
extra_body={"bad_words": output_0[-5:]},
|
||||
stream=True,
|
||||
)
|
||||
output_1 = []
|
||||
for chunk in response_1:
|
||||
assert hasattr(chunk, "choices")
|
||||
assert len(chunk.choices) > 0
|
||||
assert hasattr(chunk.choices[0], "text")
|
||||
output_1.append(chunk.choices[0].text)
|
||||
assert output_0 not in output_1
|
||||
|
Reference in New Issue
Block a user