Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 00:33:03 +08:00
[feat] add disable_chat_template in chat api as a substitute for previous raw_request (#3020)

* [feat] add disable_chat_template in chat api as a substitute for previous raw_request
* [fix] pre-commit code check
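A minimal usage sketch of the new option, mirroring the test added in this commit: the client passes disable_chat_template through extra_body and supplies an already-templated prompt as the single message. The base URL, API key, and prompt markup below are placeholders for a locally deployed FastDeploy service, not values defined by this commit.

# Sketch only: base URL, API key, and the prompt markup are assumptions for a local deployment.
import openai

client = openai.OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")

# The prompt must already carry the chat-template markup the model expects,
# since the server will not apply its chat template when the flag is set.
prompt = "<|begin_of_sentence|>User: Hello, how are you?\nAssistant: "

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": prompt}],
    max_tokens=10,
    extra_body={"disable_chat_template": True},  # skip server-side chat-template rendering
    stream=False,
)
print(response.choices[0].message.content)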
@@ -60,6 +60,7 @@ class Request:
         preprocess_end_time: Optional[float] = None,
         multimodal_inputs: Optional[dict] = None,
         multimodal_data: Optional[dict] = None,
+        disable_chat_template: bool = False,
         disaggregate_info: Optional[dict] = None,
         draft_token_ids: Optional[list[int]] = None,
         guided_json: Optional[Any] = None,
@@ -87,6 +88,7 @@ class Request:
         self.arrival_time = arrival_time
         self.preprocess_start_time = preprocess_start_time
         self.preprocess_end_time = preprocess_end_time
+        self.disable_chat_template = disable_chat_template
         self.disaggregate_info = disaggregate_info
 
         # speculative method in disaggregate-mode
@@ -136,6 +138,7 @@ class Request:
             preprocess_end_time=d.get("preprocess_end_time"),
             multimodal_inputs=d.get("multimodal_inputs"),
             multimodal_data=d.get("multimodal_data"),
+            disable_chat_template=d.get("disable_chat_template"),
             disaggregate_info=d.get("disaggregate_info"),
             draft_token_ids=d.get("draft_token_ids"),
             guided_json=d.get("guided_json", None),
@@ -180,6 +183,7 @@ class Request:
             "preprocess_end_time": self.preprocess_end_time,
             "multimodal_inputs": self.multimodal_inputs,
             "multimodal_data": self.multimodal_data,
+            "disable_chat_template": self.disable_chat_template,
             "disaggregate_info": self.disaggregate_info,
             "draft_token_ids": self.draft_token_ids,
             "enable_thinking": self.enable_thinking,
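The four hunks above thread the new flag through the Request object the same way the surrounding optional fields are handled: a constructor parameter, an instance attribute, a lookup when a request is built from a dict, and an entry when it is serialized back to one. Below is a stripped-down sketch of that round-trip pattern; the class and method names are illustrative, not FastDeploy's actual ones, and only disable_chat_template comes from this commit.

# Hypothetical miniature of the field round-trip; not FastDeploy's actual Request class.
class MiniRequest:
    def __init__(self, request_id: str, disable_chat_template: bool = False):
        self.request_id = request_id
        self.disable_chat_template = disable_chat_template

    @classmethod
    def from_dict(cls, d: dict) -> "MiniRequest":
        # Mirrors the d.get(...) lookups in the diff: a missing key falls back to a default.
        return cls(
            request_id=d["request_id"],
            disable_chat_template=bool(d.get("disable_chat_template", False)),
        )

    def to_dict(self) -> dict:
        return {
            "request_id": self.request_id,
            "disable_chat_template": self.disable_chat_template,
        }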
@@ -483,6 +483,7 @@ class ChatCompletionRequest(BaseModel):
     extra_body: Optional[dict] = None
     return_token_ids: Optional[bool] = False
     prompt_token_ids: Optional[List[int]] = None
+    disable_chat_template: Optional[bool] = False
 
     response_format: Optional[AnyResponseFormat] = None
     guided_json: Optional[Union[str, dict, BaseModel]] = None
@@ -531,6 +532,11 @@ class ChatCompletionRequest(BaseModel):
         else:
             assert len(self.messages) > 0
 
+            # If disable_chat_template is set, then the first message in messages will be used as the prompt.
+            if self.disable_chat_template:
+                req_dict["prompt"] = req_dict["messages"][0]["content"]
+                del req_dict["messages"]
+
         guided_json_object = None
         if self.response_format is not None:
             if self.response_format.type == "json_object":
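In effect, the branch added above rewrites a single-message chat request into a plain prompt request before inference: the first message's content becomes the prompt and the messages list is dropped, so the server applies no chat template. A simplified illustration of that transformation on a bare dict follows; the dict shape is reduced to the keys visible in the diff.

# Simplified stand-in for the converted request dict; only keys shown in the diff are used.
req_dict = {
    "messages": [{"role": "user", "content": "<|begin_of_sentence|>User: Hi\nAssistant: "}],
    "disable_chat_template": True,
}

if req_dict.get("disable_chat_template"):
    # The first message's content is taken verbatim as the prompt; messages are removed.
    req_dict["prompt"] = req_dict["messages"][0]["content"]
    del req_dict["messages"]

# req_dict is now {"disable_chat_template": True, "prompt": "<|begin_of_sentence|>User: Hi\nAssistant: "}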
@@ -662,3 +662,37 @@ def test_streaming_completion_with_prompt_token_ids(openai_client, capsys):
     else:
         assert hasattr(chunk.usage, "prompt_tokens")
         assert chunk.usage.prompt_tokens == 9
+
+
+def test_non_streaming_chat_completion_disable_chat_template(openai_client, capsys):
+    """
+    Test disable_chat_template option in chat functionality with the local service.
+    """
+    enabled_response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        max_tokens=10,
+        temperature=0.0,
+        top_p=0,
+        extra_body={"disable_chat_template": False},
+        stream=False,
+    )
+    assert hasattr(enabled_response, "choices")
+    assert len(enabled_response.choices) > 0
+
+    # from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
+    # tokenizer = ErnieBotTokenizer.from_pretrained("PaddlePaddle/ERNIE-4.5-0.3B-Paddle", trust_remote_code=True)
+    # prompt = tokenizer.apply_chat_template([{"role": "user", "content": "Hello, how are you?"}], tokenize=False)
+    prompt = "<|begin_of_sentence|>User: Hello, how are you?\nAssistant: "
+    disabled_response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": prompt}],
+        max_tokens=10,
+        temperature=0,
+        top_p=0,
+        extra_body={"disable_chat_template": True},
+        stream=False,
+    )
+    assert hasattr(disabled_response, "choices")
+    assert len(disabled_response.choices) > 0
+    assert enabled_response.choices[0].message.content == disabled_response.choices[0].message.content