Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-15 21:20:53 +08:00
[feat] completion api supports passing input token ids in either `prompt` or `prompt_token_ids` (#3311)

* [feat] completion api supports passing input token ids in either `prompt` or `prompt_token_ids`
* [fix] update comment
* [fix] fix type error
* [test] add a unittest file for serving api test
* [test] try to fix ci error
* [chore] rename test function names
* [test] try to fix ci error
* [test] try to fix ci error
* [test] add tests for qwen
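For context, a minimal client-side sketch of what this change enables, written in Python with requests. The base URL, model name, and token id values are illustrative assumptions, not taken from the commit; only the `prompt` and `prompt_token_ids` request fields come from the diff below.

import requests

BASE_URL = "http://localhost:8000/v1/completions"  # assumed local FastDeploy server, not from the commit

# 1) Plain text prompt (existing behaviour).
resp = requests.post(BASE_URL, json={"model": "default", "prompt": "Hello, my name is", "max_tokens": 16})

# 2) Pre-tokenized input via prompt_token_ids (the ids are made up for illustration).
#    A list[list[int]] payload would request batched completions instead of a single one.
resp = requests.post(BASE_URL, json={"model": "default", "prompt_token_ids": [101, 2023, 2003], "max_tokens": 16})
print(resp.json())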
@@ -81,25 +81,40 @@ class OpenAIServingCompletion:
             request_id = f"cmpl-{request.user}-{uuid.uuid4()}"
         else:
             request_id = f"cmpl-{uuid.uuid4()}"
-        api_server_logger.info(f"initialize request {request_id}")
+        api_server_logger.info(f"Initialize request {request_id}: {request}")
         request_prompt_ids = None
         request_prompts = None
+
+        # Handle prompt and prompt_token_ids
         try:
-            if isinstance(request.prompt, str):
-                request_prompts = [request.prompt]
-            elif isinstance(request.prompt, list) and all(isinstance(item, int) for item in request.prompt):
-                request_prompt_ids = [request.prompt]
-            elif isinstance(request.prompt, list) and all(isinstance(item, str) for item in request.prompt):
-                request_prompts = request.prompt
-            elif isinstance(request.prompt, list):
-                for item in request.prompt:
-                    if isinstance(item, list) and all(isinstance(x, int) for x in item):
-                        continue
-                    else:
-                        raise ValueError("Prompt must be a string, a list of strings or a list of integers.")
-                request_prompt_ids = request.prompt
+            if request.prompt_token_ids is not None:  # let `prompt_token_ids` support batch inference
+                assert len(request.prompt_token_ids) > 0, "prompt_token_ids should not be an empty list"
+                if isinstance(request.prompt_token_ids[0], list):
+                    request_prompt_ids = request.prompt_token_ids
+                elif isinstance(request.prompt_token_ids[0], int):
+                    request_prompt_ids = [request.prompt_token_ids]
+                else:
+                    raise ValueError(
+                        "If prompt_token_ids is provided, its type should be one of: list[int], list[list[int]]"
+                    )
+                # reset `prompt_token_ids` to avoid data processor directly using it; let data processor fill it
+                request.prompt_token_ids = None
             else:
-                raise ValueError("Prompt must be a string, a list of strings or a list of integers.")
+                if isinstance(request.prompt, str):
+                    request_prompts = [request.prompt]
+                elif isinstance(request.prompt, list) and all(isinstance(item, int) for item in request.prompt):
+                    request_prompt_ids = [request.prompt]
+                elif isinstance(request.prompt, list) and all(isinstance(item, str) for item in request.prompt):
+                    request_prompts = request.prompt
+                elif isinstance(request.prompt, list):
+                    for item in request.prompt:
+                        if isinstance(item, list) and all(isinstance(x, int) for x in item):
+                            continue
+                        else:
+                            raise ValueError("If prompt is a list, each item type must be one of: str, list[int]")
+                    request_prompt_ids = request.prompt
+                else:
+                    raise ValueError("Prompt type must be one of: str, list[str], list[int], list[list[int]]")
         except Exception as e:
             error_msg = f"OpenAIServingCompletion create_completion: {e}, {str(traceback.format_exc())}"
             api_server_logger.error(error_msg)
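To make the accepted input shapes explicit, here is a standalone sketch mirroring the normalization added above. The function name, the (prompts, prompt_ids) return convention, and the assert-style examples are illustrative and not part of the commit; in the handler the results land in request_prompts and request_prompt_ids.

def normalize_completion_inputs(prompt, prompt_token_ids=None):
    """Return (prompts, prompt_ids); exactly one of them is non-None."""
    if prompt_token_ids is not None:
        if len(prompt_token_ids) == 0:
            raise ValueError("prompt_token_ids should not be an empty list")
        if isinstance(prompt_token_ids[0], list):
            return None, prompt_token_ids            # already a batch: list[list[int]]
        if isinstance(prompt_token_ids[0], int):
            return None, [prompt_token_ids]          # single sequence: wrap into a batch
        raise ValueError("prompt_token_ids must be list[int] or list[list[int]]")
    if isinstance(prompt, str):
        return [prompt], None                        # single text prompt
    if isinstance(prompt, list) and all(isinstance(i, int) for i in prompt):
        return None, [prompt]                        # single token-id sequence
    if isinstance(prompt, list) and all(isinstance(i, str) for i in prompt):
        return prompt, None                          # batch of text prompts
    if isinstance(prompt, list) and all(
        isinstance(i, list) and all(isinstance(x, int) for x in i) for i in prompt
    ):
        return None, prompt                          # batch of token-id sequences
    raise ValueError("Prompt type must be one of: str, list[str], list[int], list[list[int]]")

# The four accepted prompt shapes plus the new prompt_token_ids path:
assert normalize_completion_inputs("hi") == (["hi"], None)
assert normalize_completion_inputs([1, 2, 3]) == (None, [[1, 2, 3]])
assert normalize_completion_inputs(["a", "b"]) == (["a", "b"], None)
assert normalize_completion_inputs([[1], [2, 3]]) == (None, [[1], [2, 3]])
assert normalize_completion_inputs(None, prompt_token_ids=[4, 5]) == (None, [[4, 5]])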
@@ -107,9 +122,9 @@ class OpenAIServingCompletion:

         if request_prompt_ids is not None:
             request_prompts = request_prompt_ids
-        num_choices = len(request_prompts)

-        api_server_logger.info(f"start inference for request {num_choices}")
+        num_choices = len(request_prompts)
+        api_server_logger.info(f"Start preprocessing request: req_id={request_id}), num_choices={num_choices}")
         prompt_batched_token_ids = []
         text_after_process_list = []
         try:
@@ -131,7 +146,7 @@ class OpenAIServingCompletion:
                 request_id_idx = f"{request_id}-{idx}"
                 current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
                 current_req_dict["arrival_time"] = time.time()
-                prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
+                prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)  # tokenize
                 if isinstance(prompt_token_ids, np.ndarray):
                     prompt_token_ids = prompt_token_ids.tolist()
                 text_after_process_list.append(current_req_dict.get("text_after_process"))
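The lines above run once per prompt index (idx): each prompt is turned into a request dict, tokenized through engine_client.format_and_add_data, and the resulting ids are normalized to a plain Python list since the processor may hand back a numpy array. A minimal illustration of that conversion follows; the stub tokenizer is invented for the example, and only the np.ndarray check mirrors the code above.

import numpy as np

def fake_tokenize(prompt: str) -> np.ndarray:
    # Stand-in for engine_client.format_and_add_data(); returns ids as a numpy array.
    return np.array([ord(c) for c in prompt], dtype=np.int64)

prompt_token_ids = fake_tokenize("hi")
if isinstance(prompt_token_ids, np.ndarray):
    # Plain lists serialize to JSON and append cleanly to prompt_batched_token_ids.
    prompt_token_ids = prompt_token_ids.tolist()

assert prompt_token_ids == [104, 105]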