[feat] completion api supports passing input token ids in either prompt or prompt_token_ids (#3311)

* [feat] completion api supports passing input token ids in either `prompt` or `prompt_token_ids`

* [fix] update comment

* [fix] fix type error

* [test] add a unittest file for serving api test

* [test] try to fix ci error

* [chore] rename test function names

* [test] try to fix ci error

* [test] try to fix ci error

* [test] add tests for qwen
This commit is contained in:
李泳桦
2025-08-29 14:19:42 +08:00
committed by GitHub
parent 17b414c2df
commit 88297240e7
6 changed files with 343 additions and 70 deletions

View File

@@ -81,25 +81,40 @@ class OpenAIServingCompletion:
request_id = f"cmpl-{request.user}-{uuid.uuid4()}"
else:
request_id = f"cmpl-{uuid.uuid4()}"
api_server_logger.info(f"initialize request {request_id}")
api_server_logger.info(f"Initialize request {request_id}: {request}")
request_prompt_ids = None
request_prompts = None
# Handle prompt and prompt_token_ids
try:
if isinstance(request.prompt, str):
request_prompts = [request.prompt]
elif isinstance(request.prompt, list) and all(isinstance(item, int) for item in request.prompt):
request_prompt_ids = [request.prompt]
elif isinstance(request.prompt, list) and all(isinstance(item, str) for item in request.prompt):
request_prompts = request.prompt
elif isinstance(request.prompt, list):
for item in request.prompt:
if isinstance(item, list) and all(isinstance(x, int) for x in item):
continue
else:
raise ValueError("Prompt must be a string, a list of strings or a list of integers.")
request_prompt_ids = request.prompt
if request.prompt_token_ids is not None: # let `prompt_token_ids` support batch inference
assert len(request.prompt_token_ids) > 0, "prompt_token_ids should not be an empty list"
if isinstance(request.prompt_token_ids[0], list):
request_prompt_ids = request.prompt_token_ids
elif isinstance(request.prompt_token_ids[0], int):
request_prompt_ids = [request.prompt_token_ids]
else:
raise ValueError(
"If prompt_token_ids is provided, its type should be one of: list[int], list[list[int]]"
)
# reset `prompt_token_ids` to avoid data processor directly using it; let data processor fill it
request.prompt_token_ids = None
else:
raise ValueError("Prompt must be a string, a list of strings or a list of integers.")
if isinstance(request.prompt, str):
request_prompts = [request.prompt]
elif isinstance(request.prompt, list) and all(isinstance(item, int) for item in request.prompt):
request_prompt_ids = [request.prompt]
elif isinstance(request.prompt, list) and all(isinstance(item, str) for item in request.prompt):
request_prompts = request.prompt
elif isinstance(request.prompt, list):
for item in request.prompt:
if isinstance(item, list) and all(isinstance(x, int) for x in item):
continue
else:
raise ValueError("If prompt is a list, each item type must be one of: str, list[int]")
request_prompt_ids = request.prompt
else:
raise ValueError("Prompt type must be one of: str, list[str], list[int], list[list[int]]")
except Exception as e:
error_msg = f"OpenAIServingCompletion create_completion: {e}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
@@ -107,9 +122,9 @@ class OpenAIServingCompletion:
if request_prompt_ids is not None:
request_prompts = request_prompt_ids
num_choices = len(request_prompts)
api_server_logger.info(f"start inference for request {num_choices}")
num_choices = len(request_prompts)
api_server_logger.info(f"Start preprocessing request: req_id={request_id}), num_choices={num_choices}")
prompt_batched_token_ids = []
text_after_process_list = []
try:
@@ -131,7 +146,7 @@ class OpenAIServingCompletion:
request_id_idx = f"{request_id}-{idx}"
current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
current_req_dict["arrival_time"] = time.time()
prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict) # tokenize
if isinstance(prompt_token_ids, np.ndarray):
prompt_token_ids = prompt_token_ids.tolist()
text_after_process_list.append(current_req_dict.get("text_after_process"))