From c68c3c4b8b1ebe81663f08410ac7f9ec79e99b77 Mon Sep 17 00:00:00 2001 From: Sunny-bot1 <68891411+Sunny-bot1@users.noreply.github.com> Date: Tue, 26 Aug 2025 11:14:51 +0800 Subject: [PATCH] [Feature] bad words support v1 scheduler and specifiy token ids (#3608) * support bad_words_token_ids * docs * fix test * fix * bad words support kvcache v1 and token ids * fix --- docs/features/sampling.md | 49 +++++++++-- docs/online_serving/README.md | 6 ++ docs/zh/features/sampling.md | 48 +++++++++-- docs/zh/online_serving/README.md | 3 + fastdeploy/engine/engine.py | 3 - fastdeploy/engine/sampling_params.py | 46 +--------- fastdeploy/entrypoints/openai/protocol.py | 2 + fastdeploy/input/ernie_processor.py | 53 ++++++++++++ fastdeploy/input/ernie_vl_processor.py | 6 ++ fastdeploy/input/qwen_vl_processor.py | 6 ++ fastdeploy/input/text_processor.py | 52 ++++++++++++ fastdeploy/worker/gpu_model_runner.py | 10 +++ fastdeploy/worker/metax_model_runner.py | 10 +++ fastdeploy/worker/xpu_model_runner.py | 10 +++ tests/ci_use/EB_Lite/test_EB_Lite_serving.py | 89 +++++++++++++++++++- tests/e2e/test_EB_Lite_serving.py | 89 +++++++++++++++++++- 16 files changed, 420 insertions(+), 62 deletions(-) diff --git a/docs/features/sampling.md b/docs/features/sampling.md index 3a0d22869..4b2774fc2 100644 --- a/docs/features/sampling.md +++ b/docs/features/sampling.md @@ -183,7 +183,7 @@ Used to prevent the model from generating certain specific words during the infe ## Usage Instructions -Include the `bad_words` parameter in the request: +Include the `bad_words` or `bad_words_token_ids` parameter in the request: * Example request with curl: @@ -192,9 +192,22 @@ curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \ -H "Content-Type: application/json" \ -d '{ "messages": [ - {"role": "user", "content": "How old are you"} + {"role": "user", "content": "How are you"} ], - "bad_words": ["age", "I"] + "bad_words": [" well", " Today"] +}' +``` + +Equal to + +```bash +curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \ +-H "Content-Type: application/json" \ +-d '{ + "messages": [ + {"role": "user", "content": "How are you"} + ], + "bad_words_token_ids": [1622, 25062] }' ``` @@ -203,15 +216,37 @@ curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \ ```python import openai host = "0.0.0.0" -port = "8170" +port = "9222" client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null") response = client.chat.completions.create( model="null", messages=[ - {"role": "system", "content": "I'm a helpful AI assistant."}, + {"role": "user", "content": "Hello, how are you?"}, ], - extra_body={"bad_words": ["you", "me"]}, + extra_body={"bad_words": [" well", " Today"]}, + stream=True, +) +for chunk in response: + if chunk.choices[0].delta: + print(chunk.choices[0].delta.content, end='') +print('\n') +``` + +Equal to + +```python +import openai +host = "0.0.0.0" +port = "9222" +client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null") + +response = client.chat.completions.create( + model="null", + messages=[ + {"role": "user", "content": "Hello, how are you?"}, + ], + extra_body={"bad_words_token_ids": [1622, 25062]}, stream=True, ) for chunk in response: @@ -223,3 +258,5 @@ print('\n') ## Parameter Description `bad_words`: List of forbidden words. Type: list of str. Each word must be a single token. + +`bad_words_token_ids`: List of forbidden token ids. Type: list of int. 
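
The ids `1622` and `25062` used in the `bad_words_token_ids` examples above are specific to the tokenizer of the docs' example model; for any other model the ids will differ. They can be looked up with the model's own tokenizer. A minimal sketch, assuming a Hugging Face-style `AutoTokenizer` and a placeholder model path (not part of the diff above):

```python
# Minimal sketch: derive token ids to pass as bad_words_token_ids.
# Assumes a Hugging Face-style tokenizer; "your-model-path" is a placeholder.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-model-path")

bad_words = [" well", " Today"]
bad_words_token_ids = []
for word in bad_words:
    ids = tokenizer.encode(word, add_special_tokens=False)
    # Only single-token words can be banned; multi-token words are skipped,
    # mirroring the warning emitted by update_bad_words in the data processors.
    if len(ids) == 1:
        bad_words_token_ids.extend(ids)

print(bad_words_token_ids)  # e.g. [1622, 25062] for the docs' example tokenizer
```
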
diff --git a/docs/online_serving/README.md b/docs/online_serving/README.md index 53a8b60c8..178393489 100644 --- a/docs/online_serving/README.md +++ b/docs/online_serving/README.md @@ -153,6 +153,9 @@ include_stop_str_in_output: Optional[bool] = False bad_words: Optional[List[str]] = None # List of forbidden words (e.g., sensitive words) that the model should avoid generating (default None means no restriction). +bad_words_token_ids: Optional[List[int]] = None +# List of forbidden token ids that the model should avoid generating (default None means no restriction). + repetition_penalty: Optional[float] = None # Repetition penalty coefficient, reducing the probability of repeating already generated tokens (`>1.0` suppresses repetition, `<1.0` encourages repetition, default None means disabled). ``` @@ -340,6 +343,9 @@ include_stop_str_in_output: Optional[bool] = False bad_words: Optional[List[str]] = None # List of forbidden words (e.g., sensitive words) that the model should avoid generating (default None means no restriction). +bad_words_token_ids: Optional[List[int]] = None +# List of forbidden token ids that the model should avoid generating (default None means no restriction). + repetition_penalty: Optional[float] = None # Repetition penalty coefficient, reducing the probability of repeating already generated tokens (`>1.0` suppresses repetition, `<1.0` encourages repetition, default None means disabled). ``` diff --git a/docs/zh/features/sampling.md b/docs/zh/features/sampling.md index 24cc003b5..51464515d 100644 --- a/docs/zh/features/sampling.md +++ b/docs/zh/features/sampling.md @@ -183,7 +183,7 @@ print('\n') ## 使用说明 -请求中加入bad_words参数: +可以在请求中加入bad_words参数,也可以加入bad_words_token_ids参数 * 使用 curl 命令发送用户请求示例如下: @@ -192,9 +192,22 @@ curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \ -H "Content-Type: application/json" \ -d '{ "messages": [ - {"role": "user", "content": "How old are you"} + {"role": "user", "content": "How are you"} ], - "bad_words": ["age", "I"] + "bad_words": [" well", " Today"] +}' +``` + +等价于 + +```bash +curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \ +-H "Content-Type: application/json" \ +-d '{ + "messages": [ + {"role": "user", "content": "How are you"} + ], + "bad_words_token_ids": [1622, 25062] }' ``` @@ -203,15 +216,37 @@ curl -X POST "http://0.0.0.0:9222/v1/chat/completions" \ ```python import openai host = "0.0.0.0" -port = "8170" +port = "9222" client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null") response = client.chat.completions.create( model="null", messages=[ - {"role": "system", "content": "I'm a helpful AI assistant."}, + {"role": "user", "content": "Hello, how are you?"}, ], - extra_body={"bad_words": ["you", "me"]}, + extra_body={"bad_words": [" well", " Today"]}, + stream=True, +) +for chunk in response: + if chunk.choices[0].delta: + print(chunk.choices[0].delta.content, end='') +print('\n') +``` + +等价于 + +```python +import openai +host = "0.0.0.0" +port = "9222" +client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null") + +response = client.chat.completions.create( + model="null", + messages=[ + {"role": "user", "content": "Hello, how are you?"}, + ], + extra_body={"bad_words_token_ids": [1622, 25062]}, stream=True, ) for chunk in response: @@ -223,3 +258,4 @@ print('\n') ## 参数说明 * `bad_words`: 禁止生成的词列表。list类型,每个元素为str类型。仅支持每个元素为单个token。 +* `bad_words_token_ids`: 禁止生成的token id列表。list类型,每个元素为int类型。 diff --git a/docs/zh/online_serving/README.md b/docs/zh/online_serving/README.md index 
45e2168d2..5e72dd5ed 100644 --- a/docs/zh/online_serving/README.md +++ b/docs/zh/online_serving/README.md @@ -153,6 +153,9 @@ include_stop_str_in_output: Optional[bool] = False bad_words: Optional[List[str]] = None # 禁止生成的词汇列表(例如敏感词),模型会避免输出这些词(默认 None 表示不限制)。 +bad_words_token_ids: Optional[List[int]] = None +# 禁止生成的token id列表,模型会避免输出这些词(默认 None 表示不限制)。 + repetition_penalty: Optional[float] = None # 重复惩罚系数,降低已生成 token 的重复概率(>1.0 抑制重复,<1.0 鼓励重复,默认 None 表示禁用)。 ``` diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index d09f02122..8b49f2659 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -461,7 +461,6 @@ class LLMEngine: request = Request.from_dict(task) llm_logger.info(f"Receive request {request}") if sampling_params is not None: - sampling_params.update_from_tokenizer(self.data_processor.tokenizer) request.sampling_params = sampling_params request.preprocess_start_time = time.time() @@ -762,8 +761,6 @@ class LLMEngine: for task in tasks: start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER) - if task.sampling_params.bad_words is not None: - task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer) self.resource_manager.check_and_free_block_tables() diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index f95f09bd5..423434857 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -20,8 +20,6 @@ import random from dataclasses import dataclass, fields from typing import Any, List, Optional, Union -from fastdeploy.utils import llm_logger as logger - @dataclass class SamplingParams: @@ -102,7 +100,7 @@ class SamplingParams: temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False bad_words: Optional[List[str]] = None - _bad_words_token_ids: Optional[List[int]] = None + bad_words_token_ids: Optional[List[int]] = None @classmethod def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams: @@ -134,6 +132,7 @@ class SamplingParams: min_tokens=1, logprobs=None, bad_words=None, + bad_words_token_ids=None, ) -> SamplingParams: """Create instance from command line arguments""" return cls( @@ -154,6 +153,7 @@ class SamplingParams: min_tokens=min_tokens, logprobs=logprobs, bad_words=bad_words, + bad_words_token_ids=bad_words_token_ids, ) def __post_init__(self): @@ -206,46 +206,6 @@ class SamplingParams: if not 0 <= self.seed <= 922337203685477580: raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.") - def update_from_tokenizer(self, tokenizer): - """Support bad words""" - if self.bad_words is None: - return - self._bad_words_token_ids = [] - for bad_word in self.bad_words: - # To prohibit words both at the beginning - # and in the middle of text - # (related to add_prefix_space tokenizer parameter) - for add_prefix_space in [False, True]: - prefix = " " if add_prefix_space else "" - prompt = prefix + bad_word.lstrip() - prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)["input_ids"] - - if len(prompt_token_ids) != 1: - if not add_prefix_space: - logger.warning( - f"Skip bad_words: <{prompt}>." - f"Bad words should be a single token." - f"Got tokens: {prompt_token_ids}." - ) - continue - - if prompt_token_ids[0] > tokenizer.vocab_size: - if not add_prefix_space: - logger.warning( - f"Skip bad_words: <{prompt}>." - f"All token id values should be satisfying:" - f" 0 <= token_id < {tokenizer.vocab_size}." - f"Got token: {prompt_token_ids}." 
- ) - continue - - if prompt_token_ids not in self._bad_words_token_ids: - self._bad_words_token_ids.extend(prompt_token_ids) - - @property - def bad_words_token_ids(self) -> Optional[List[list[int]]]: - return self._bad_words_token_ids - @dataclass class BeamSearchParams: diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index 733701eea..882c0322b 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -426,6 +426,7 @@ class CompletionRequest(BaseModel): min_tokens: Optional[int] = None include_stop_str_in_output: Optional[bool] = False bad_words: Optional[List[str]] = None + bad_words_token_ids: Optional[List[int]] = None # doc: end-completion-sampling-params # doc: start-completion-extra-params @@ -566,6 +567,7 @@ class ChatCompletionRequest(BaseModel): min_tokens: Optional[int] = None include_stop_str_in_output: Optional[bool] = False bad_words: Optional[List[str]] = None + bad_words_token_ids: Optional[List[int]] = None repetition_penalty: Optional[float] = None stop_token_ids: Optional[List[int]] = Field(default_factory=list) # doc: end-chat-completion-sampling-params diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py index 2772c82ff..db397dbd0 100644 --- a/fastdeploy/input/ernie_processor.py +++ b/fastdeploy/input/ernie_processor.py @@ -97,6 +97,12 @@ class ErnieProcessor(BaseDataProcessor): request.set("stop_token_ids", stop_seqs) request.set("stop_seqs_len", stop_seqs_len) + bad_words = request.get("bad_words") + bad_words_token_ids = request.get("bad_words_token_ids") + if bad_words: + bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) + request["bad_words_token_ids"] = bad_words_token_ids + if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0: if request.prompt is None and request.messages is None: raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.") @@ -160,6 +166,13 @@ class ErnieProcessor(BaseDataProcessor): request["stop_token_ids"] = stop_seqs request["stop_seqs_len"] = stop_seqs_len + # processing bad_words + bad_words = request.get("bad_words") + bad_words_token_ids = request.get("bad_words_token_ids") + if bad_words: + bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) + request["bad_words_token_ids"] = bad_words_token_ids + # processing prompt_token_ids if not request.get("prompt_token_ids"): if request.get("prompt") is None and request.get("messages") is None: @@ -481,3 +494,43 @@ class ErnieProcessor(BaseDataProcessor): def process_logprob_response(self, token_ids, **kwargs): full_text = self.tokenizer.decode(token_ids, **kwargs) return full_text + + def update_bad_words(self, bad_words, bad_words_token_ids): + """Support bad words""" + + token_ids = bad_words_token_ids + + if token_ids is None: + token_ids = [] + for bad_word in bad_words: + # To prohibit words both at the beginning + # and in the middle of text + # (related to add_prefix_space tokenizer parameter) + for add_prefix_space in [False, True]: + prefix = " " if add_prefix_space else "" + prompt = prefix + bad_word.lstrip() + prompt_token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt)) + data_processor_logger.debug(f"processed bad_words: {prompt}, {prompt_token_ids}") + + if len(prompt_token_ids) != 1: + if not add_prefix_space: + data_processor_logger.warning( + f"Skip bad_words: <{prompt}>." 
+ f"Bad words should be a single token." + f"Got tokens: {prompt_token_ids}." + ) + continue + + if prompt_token_ids[0] > self.tokenizer.vocab_size: + if not add_prefix_space: + data_processor_logger.warning( + f"Skip bad_words: <{prompt}>." + f"All token id values should be satisfying:" + f" 0 <= token_id < {self.tokenizer.vocab_size}." + f"Got token: {prompt_token_ids}." + ) + continue + + if prompt_token_ids not in token_ids: + token_ids.extend(prompt_token_ids) + return token_ids diff --git a/fastdeploy/input/ernie_vl_processor.py b/fastdeploy/input/ernie_vl_processor.py index 606844fc7..82f11acd0 100644 --- a/fastdeploy/input/ernie_vl_processor.py +++ b/fastdeploy/input/ernie_vl_processor.py @@ -208,6 +208,12 @@ class ErnieMoEVLProcessor(ErnieProcessor): request["stop_token_ids"] = stop_seqs request["stop_seqs_len"] = stop_seqs_len + bad_words = request.get("bad_words") + bad_words_token_ids = request.get("bad_words_token_ids") + if bad_words: + bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) + request["bad_words_token_ids"] = bad_words_token_ids + if request.get("prompt"): multimodal_data = request.get("multimodal_data") if multimodal_data is None: diff --git a/fastdeploy/input/qwen_vl_processor.py b/fastdeploy/input/qwen_vl_processor.py index 8f6a8a9d7..9e8afa3bf 100644 --- a/fastdeploy/input/qwen_vl_processor.py +++ b/fastdeploy/input/qwen_vl_processor.py @@ -212,6 +212,12 @@ class QwenVLProcessor(TextProcessor): request["stop_token_ids"] = stop_seqs request["stop_seqs_len"] = stop_seqs_len + bad_words = request.get("bad_words") + bad_words_token_ids = request.get("bad_words_token_ids") + if bad_words: + bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) + request["bad_words_token_ids"] = bad_words_token_ids + if request.get("prompt"): multimodal_data = request.get("multimodal_data") if multimodal_data is None: diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py index 551975760..7022f5cac 100644 --- a/fastdeploy/input/text_processor.py +++ b/fastdeploy/input/text_processor.py @@ -214,6 +214,12 @@ class DataProcessor(BaseDataProcessor): request.set("stop_token_ids", stop_seqs) request.set("stop_seqs_len", stop_seqs_len) + bad_words = request.get("bad_words") + bad_words_token_ids = request.get("bad_words_token_ids") + if bad_words: + bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) + request["bad_words_token_ids"] = bad_words_token_ids + if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0: if request.prompt is not None: request.prompt_token_ids = self.text2ids(request.prompt, max_model_len) @@ -270,6 +276,13 @@ class DataProcessor(BaseDataProcessor): request["stop_token_ids"] = stop_seqs request["stop_seqs_len"] = stop_seqs_len + # processing bad_words + bad_words = request.get("bad_words") + bad_words_token_ids = request.get("bad_words_token_ids") + if bad_words: + bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) + request["bad_words_token_ids"] = bad_words_token_ids + data_processor_logger.info(f"Processing request {request}") # processing prompt_token_ids if not request.get("prompt_token_ids"): @@ -652,3 +665,42 @@ class DataProcessor(BaseDataProcessor): stop_seqs, stop_seqs_len = self.pad_batch_data(stop_seqs, pad_id=-1, return_seq_len=True, return_array=False) data_processor_logger.debug(f"processed stop_seqs: {stop_seqs}, {stop_seqs_len}") return stop_seqs, stop_seqs_len + + def update_bad_words(self, bad_words, 
bad_words_token_ids): + """Support bad words""" + + token_ids = bad_words_token_ids + + if token_ids is None: + token_ids = [] + for bad_word in bad_words: + # To prohibit words both at the beginning + # and in the middle of text + # (related to add_prefix_space tokenizer parameter) + for add_prefix_space in [False, True]: + prefix = " " if add_prefix_space else "" + prompt = prefix + bad_word.lstrip() + prompt_token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt)) + + if len(prompt_token_ids) != 1: + if not add_prefix_space: + data_processor_logger.warning( + f"Skip bad_words: <{prompt}>." + f"Bad words should be a single token." + f"Got tokens: {prompt_token_ids}." + ) + continue + + if prompt_token_ids[0] > self.tokenizer.vocab_size: + if not add_prefix_space: + data_processor_logger.warning( + f"Skip bad_words: <{prompt}>." + f"All token id values should be satisfying:" + f" 0 <= token_id < {self.tokenizer.vocab_size}." + f"Got token: {prompt_token_ids}." + ) + continue + + if prompt_token_ids not in token_ids: + token_ids.extend(prompt_token_ids) + return token_ids diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 2f466e3eb..75a57f608 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -339,6 +339,16 @@ class GPUModelRunner(ModelRunnerBase): if request.get("seed") is not None: self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") + if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0: + bad_words_len = len(request.get("bad_words_token_ids")) + self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len + self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array( + request.get("bad_words_token_ids"), dtype="int64" + ) + else: + self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1 + self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64") + if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: stop_seqs_num = len(request.get("stop_seqs_len")) for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index a9aad7970..8b710923a 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -322,6 +322,16 @@ class MetaxModelRunner(ModelRunnerBase): if request.get("seed") is not None: self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") + if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0: + bad_words_len = len(request.get("bad_words_token_ids")) + self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len + self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array( + request.get("bad_words_token_ids"), dtype="int64" + ) + else: + self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1 + self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64") + if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: stop_seqs_num = len(request.get("stop_seqs_len")) for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index fef9ca127..55ac4beb5 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -455,6 +455,16 
@@ class XPUModelRunner(ModelRunnerBase): if request.get("seed") is not None: self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") + if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0: + bad_words_len = len(request.get("bad_words_token_ids")) + self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len + self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array( + request.get("bad_words_token_ids"), dtype="int64" + ) + else: + self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1 + self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64") + if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: stop_seqs_num = len(request.get("stop_seqs_len")) for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): diff --git a/tests/ci_use/EB_Lite/test_EB_Lite_serving.py b/tests/ci_use/EB_Lite/test_EB_Lite_serving.py index 3a771a19d..24d2c5896 100644 --- a/tests/ci_use/EB_Lite/test_EB_Lite_serving.py +++ b/tests/ci_use/EB_Lite/test_EB_Lite_serving.py @@ -847,7 +847,24 @@ def test_non_streaming_chat_with_bad_words(openai_client, capsys): assert hasattr(response_1.choices[0], "message") assert hasattr(response_1.choices[0].message, "completion_token_ids") assert isinstance(response_1.choices[0].message.completion_token_ids, list) + + response_2 = openai_client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=1, + top_p=0.0, + max_tokens=20, + extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True}, + stream=False, + ) + assert hasattr(response_2, "choices") + assert len(response_2.choices) > 0 + assert hasattr(response_2.choices[0], "message") + assert hasattr(response_2.choices[0].message, "completion_token_ids") + assert isinstance(response_2.choices[0].message.completion_token_ids, list) + assert not any(ids in response_1.choices[0].message.completion_token_ids for ids in bad_token_ids) + assert not any(ids in response_2.choices[0].message.completion_token_ids for ids in bad_token_ids) def test_streaming_chat_with_bad_words(openai_client, capsys): @@ -906,7 +923,34 @@ def test_streaming_chat_with_bad_words(openai_client, capsys): assert isinstance(chunk.choices[0].delta.completion_token_ids, list) output_tokens_1.append(chunk.choices[0].delta.content) output_ids_1.extend(chunk.choices[0].delta.completion_token_ids) + + response_2 = openai_client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=1, + top_p=0.0, + max_tokens=20, + extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True}, + stream=True, + ) + output_tokens_2 = [] + output_ids_2 = [] + is_first_chunk = True + for chunk in response_2: + assert hasattr(chunk, "choices") + assert len(chunk.choices) > 0 + assert hasattr(chunk.choices[0], "delta") + assert hasattr(chunk.choices[0].delta, "content") + assert hasattr(chunk.choices[0].delta, "completion_token_ids") + if is_first_chunk: + is_first_chunk = False + else: + assert isinstance(chunk.choices[0].delta.completion_token_ids, list) + output_tokens_2.append(chunk.choices[0].delta.content) + output_ids_2.extend(chunk.choices[0].delta.completion_token_ids) + assert not any(ids in output_ids_1 for ids in bad_token_ids) + assert not any(ids in output_ids_2 for ids in bad_token_ids) def test_non_streaming_completion_with_bad_words(openai_client, capsys): @@ 
-956,9 +1000,25 @@ def test_non_streaming_completion_with_bad_words(openai_client, capsys): ) assert hasattr(response_1, "choices") assert len(response_1.choices) > 0 - assert hasattr(response_0.choices[0], "completion_token_ids") - assert isinstance(response_0.choices[0].completion_token_ids, list) + assert hasattr(response_1.choices[0], "completion_token_ids") + assert isinstance(response_1.choices[0].completion_token_ids, list) + + response_2 = openai_client.completions.create( + model="default", + prompt="Hello, how are you?", + temperature=1, + top_p=0.0, + max_tokens=20, + extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True}, + stream=False, + ) + assert hasattr(response_2, "choices") + assert len(response_2.choices) > 0 + assert hasattr(response_2.choices[0], "completion_token_ids") + assert isinstance(response_2.choices[0].completion_token_ids, list) + assert not any(ids in response_1.choices[0].completion_token_ids for ids in bad_token_ids) + assert not any(ids in response_2.choices[0].completion_token_ids for ids in bad_token_ids) def test_streaming_completion_with_bad_words(openai_client, capsys): @@ -1013,7 +1073,32 @@ def test_streaming_completion_with_bad_words(openai_client, capsys): assert hasattr(chunk.choices[0], "completion_token_ids") output_tokens_1.append(chunk.choices[0].text) output_ids_1.extend(chunk.choices[0].completion_token_ids) + # add bad words token ids + response_2 = openai_client.completions.create( + model="default", + prompt="Hello, how are you?", + temperature=1, + top_p=0.0, + max_tokens=20, + extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True}, + stream=True, + ) + output_tokens_2 = [] + output_ids_2 = [] + is_first_chunk = True + for chunk in response_2: + if is_first_chunk: + is_first_chunk = False + else: + assert hasattr(chunk, "choices") + assert len(chunk.choices) > 0 + assert hasattr(chunk.choices[0], "text") + assert hasattr(chunk.choices[0], "completion_token_ids") + output_tokens_2.append(chunk.choices[0].text) + output_ids_2.extend(chunk.choices[0].completion_token_ids) + assert not any(ids in output_ids_1 for ids in bad_token_ids) + assert not any(ids in output_ids_2 for ids in bad_token_ids) def test_profile_reset_block_num(): diff --git a/tests/e2e/test_EB_Lite_serving.py b/tests/e2e/test_EB_Lite_serving.py index 62f40b571..452a809ca 100644 --- a/tests/e2e/test_EB_Lite_serving.py +++ b/tests/e2e/test_EB_Lite_serving.py @@ -842,7 +842,24 @@ def test_non_streaming_chat_with_bad_words(openai_client, capsys): assert hasattr(response_1.choices[0], "message") assert hasattr(response_1.choices[0].message, "completion_token_ids") assert isinstance(response_1.choices[0].message.completion_token_ids, list) + + response_2 = openai_client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=1, + top_p=0.0, + max_tokens=20, + extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True}, + stream=False, + ) + assert hasattr(response_2, "choices") + assert len(response_2.choices) > 0 + assert hasattr(response_2.choices[0], "message") + assert hasattr(response_2.choices[0].message, "completion_token_ids") + assert isinstance(response_2.choices[0].message.completion_token_ids, list) + assert not any(ids in response_1.choices[0].message.completion_token_ids for ids in bad_token_ids) + assert not any(ids in response_2.choices[0].message.completion_token_ids for ids in bad_token_ids) def 
test_streaming_chat_with_bad_words(openai_client, capsys): @@ -901,7 +918,34 @@ def test_streaming_chat_with_bad_words(openai_client, capsys): assert isinstance(chunk.choices[0].delta.completion_token_ids, list) output_tokens_1.append(chunk.choices[0].delta.content) output_ids_1.extend(chunk.choices[0].delta.completion_token_ids) + + response_2 = openai_client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=1, + top_p=0.0, + max_tokens=20, + extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True}, + stream=True, + ) + output_tokens_2 = [] + output_ids_2 = [] + is_first_chunk = True + for chunk in response_2: + assert hasattr(chunk, "choices") + assert len(chunk.choices) > 0 + assert hasattr(chunk.choices[0], "delta") + assert hasattr(chunk.choices[0].delta, "content") + assert hasattr(chunk.choices[0].delta, "completion_token_ids") + if is_first_chunk: + is_first_chunk = False + else: + assert isinstance(chunk.choices[0].delta.completion_token_ids, list) + output_tokens_2.append(chunk.choices[0].delta.content) + output_ids_2.extend(chunk.choices[0].delta.completion_token_ids) + assert not any(ids in output_ids_1 for ids in bad_token_ids) + assert not any(ids in output_ids_2 for ids in bad_token_ids) def test_non_streaming_completion_with_bad_words(openai_client, capsys): @@ -951,9 +995,25 @@ def test_non_streaming_completion_with_bad_words(openai_client, capsys): ) assert hasattr(response_1, "choices") assert len(response_1.choices) > 0 - assert hasattr(response_0.choices[0], "completion_token_ids") - assert isinstance(response_0.choices[0].completion_token_ids, list) + assert hasattr(response_1.choices[0], "completion_token_ids") + assert isinstance(response_1.choices[0].completion_token_ids, list) + + response_2 = openai_client.completions.create( + model="default", + prompt="Hello, how are you?", + temperature=1, + top_p=0.0, + max_tokens=20, + extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True}, + stream=False, + ) + assert hasattr(response_2, "choices") + assert len(response_2.choices) > 0 + assert hasattr(response_2.choices[0], "completion_token_ids") + assert isinstance(response_2.choices[0].completion_token_ids, list) + assert not any(ids in response_1.choices[0].completion_token_ids for ids in bad_token_ids) + assert not any(ids in response_2.choices[0].completion_token_ids for ids in bad_token_ids) def test_streaming_completion_with_bad_words(openai_client, capsys): @@ -1008,7 +1068,32 @@ def test_streaming_completion_with_bad_words(openai_client, capsys): assert hasattr(chunk.choices[0], "completion_token_ids") output_tokens_1.append(chunk.choices[0].text) output_ids_1.extend(chunk.choices[0].completion_token_ids) + # add bad words token ids + response_2 = openai_client.completions.create( + model="default", + prompt="Hello, how are you?", + temperature=1, + top_p=0.0, + max_tokens=20, + extra_body={"bad_words_token_ids": bad_token_ids, "return_token_ids": True}, + stream=True, + ) + output_tokens_2 = [] + output_ids_2 = [] + is_first_chunk = True + for chunk in response_2: + if is_first_chunk: + is_first_chunk = False + else: + assert hasattr(chunk, "choices") + assert len(chunk.choices) > 0 + assert hasattr(chunk.choices[0], "text") + assert hasattr(chunk.choices[0], "completion_token_ids") + output_tokens_2.append(chunk.choices[0].text) + output_ids_2.extend(chunk.choices[0].completion_token_ids) + assert not any(ids in output_ids_1 for ids in bad_token_ids) 
+ assert not any(ids in output_ids_2 for ids in bad_token_ids) def test_profile_reset_block_num():
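
For context on how the new `bad_words_token_ids` flow through the runners: each model runner hunk above copies the ids into `share_inputs["bad_tokens"]` and records their count in `share_inputs["bad_tokens_len"]`, using a single `-1` entry as the "no bad words" sentinel. The sampling kernel that consumes this buffer is not part of this diff, so the following is only an illustrative NumPy sketch under assumed names and shapes, showing how such a buffer typically masks logits so banned ids can never be sampled:

```python
# Illustrative sketch only: consuming a per-request bad-token buffer like
# share_inputs["bad_tokens"] / share_inputs["bad_tokens_len"] at sampling time.
# The real FastDeploy kernel is not shown in this diff; names/shapes are assumptions.
import numpy as np

vocab_size = 32000
max_bad_tokens = 8

logits = np.random.randn(2, vocab_size).astype("float32")    # [batch, vocab]
bad_tokens = np.full((2, max_bad_tokens), -1, dtype="int64")  # -1 padding = "no ban"
bad_tokens_len = np.ones((2, 1), dtype="int64")

# Request 0: ban token ids 1622 and 25062 (as in the docs examples).
bad_tokens[0, :2] = [1622, 25062]
bad_tokens_len[0] = 2

for i in range(logits.shape[0]):
    n = int(bad_tokens_len[i, 0])
    ids = bad_tokens[i, :n]
    ids = ids[ids >= 0]            # skip the -1 sentinel used for "no bad words"
    logits[i, ids] = -np.inf       # banned tokens can never win the argmax / sampling

print(np.argmax(logits[0]))  # never 1622 or 25062
```
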