Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[feat] completion api supports passing input token ids in either prompt or prompt_token_ids (#3311)

* [feat] completion api supports passing input token ids in either `prompt` or `prompt_token_ids`
* [fix] update comment
* [fix] fix type error
* [test] add a unittest file for serving api test
* [test] try to fix ci error
* [chore] rename test function names
* [test] try to fix ci error
* [test] try to fix ci error
* [test] add tests for qwen
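For context, a rough sketch (not part of this commit) of what the change enables, exercising a FastDeploy OpenAI-compatible completion endpoint; the host, port, and model name below are placeholder assumptions:

# Sketch only: URL, port, and model name are assumptions, not values from this commit.
import requests

token_ids = [101, 2054, 2003, 102]  # any pre-tokenized prompt

# Token ids may now be passed in `prompt` ...
resp = requests.post(
    "http://localhost:8188/v1/completions",
    json={"model": "default", "prompt": token_ids},
)
print(resp.json())

# ... or, equivalently, in the `prompt_token_ids` extra parameter.
resp = requests.post(
    "http://localhost:8188/v1/completions",
    json={"model": "default", "prompt_token_ids": token_ids},
)
print(resp.json())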
@@ -438,7 +438,7 @@ class CompletionRequest(BaseModel):
     max_streaming_response_tokens: Optional[int] = None
     return_token_ids: Optional[bool] = None
-    prompt_token_ids: Optional[List[int]] = None
+    prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None
     # doc: end-completion-extra-params

     def to_dict_for_infer(self, request_id=None, prompt=None):
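A standalone sketch (not FastDeploy code) of what the widened field type accepts: the Union lets one request carry either a single token-id sequence or a batch of sequences.

# Minimal sketch of the new field type, using a hypothetical model name.
from typing import List, Optional, Union

from pydantic import BaseModel

class Req(BaseModel):
    prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None

print(Req(prompt_token_ids=[1, 2, 3]).prompt_token_ids)            # one prompt
print(Req(prompt_token_ids=[[1, 2], [3, 4, 5]]).prompt_token_ids)  # batch of prompts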
@@ -463,11 +463,11 @@ class CompletionRequest(BaseModel):
         if prompt is not None:
             req_dict["prompt"] = prompt

-        if "prompt_token_ids" in req_dict:
-            if "prompt" in req_dict:
-                del req_dict["prompt"]
-        else:
-            assert len(prompt) > 0
+        # if "prompt_token_ids" in req_dict:
+        #     if "prompt" in req_dict:
+        #         del req_dict["prompt"]
+        # else:
+        #     assert len(prompt) > 0

         guided_json_object = None
         if self.response_format is not None:
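A simplified illustration of the behavioral change in this hunk: the old logic dropped the "prompt" key whenever "prompt_token_ids" was present, whereas the request dict now passes through intact. The dict contents here are made-up sample values.

# Old behavior (now commented out above): "prompt" was deleted when
# "prompt_token_ids" was present.
req_dict = {"prompt": "what is AI?", "prompt_token_ids": [101, 2054, 2003, 102]}

# if "prompt_token_ids" in req_dict:
#     if "prompt" in req_dict:
#         del req_dict["prompt"]

# New behavior: both keys survive, and downstream code may read token ids
# from either field.
print(req_dict)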
@@ -572,7 +572,7 @@ class ChatCompletionRequest(BaseModel):
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     # doc: end-chat-completion-sampling-params

-    # doc: start-completion-extra-params
+    # doc: start-chat-completion-extra-params
     chat_template_kwargs: Optional[dict] = None
     chat_template: Optional[str] = None
     reasoning_max_tokens: Optional[int] = None