[feat] completion api supports passing input token ids in either prompt or prompt_token_ids (#3311)

* [feat] completion api supports passing input token ids in either `prompt` or `prompt_token_ids`

* [fix] update comment

* [fix] fix type error

* [test] add a unittest file for serving api test

* [test] try to fix ci error

* [chore] rename test function names

* [test] try to fix ci error

* [test] try to fix ci error

* [test] add tests for qwen
Author: 李泳桦
Date: 2025-08-29 14:19:42 +08:00
Committed by: GitHub
Parent: 17b414c2df
Commit: 88297240e7
6 changed files with 343 additions and 70 deletions

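In user-facing terms, the `/v1/completions` endpoint now accepts a pre-tokenized prompt as well as plain text. A minimal client-side sketch, assuming an OpenAI-compatible server; the base URL, model name, and token ids below are placeholders:

from openai import OpenAI

# Assumed OpenAI-compatible server; base_url, model, and ids are placeholders.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Plain-text prompt: the server tokenizes it internally via text2ids().
resp_text = client.completions.create(model="default", prompt="Hello, world")

# Pre-tokenized prompt: the list of token ids is used as-is.
resp_ids = client.completions.create(model="default", prompt=[101, 2023, 2003])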

@@ -204,25 +204,37 @@ class DataProcessor(BaseDataProcessor):
             bool: Whether preprocessing is successful
             str: error message
         """
+        data_processor_logger.info(f"Start processing request: {request}")
         request.chat_template = kwargs.get("chat_template")
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
         # processing stop_sequences
         stop_sequences = request.get("stop", [])
         if stop_sequences is not None and len(stop_sequences) != 0:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
             request.set("stop_token_ids", stop_seqs)
             request.set("stop_seqs_len", stop_seqs_len)
         # processing bad_words
         bad_words = request.get("bad_words")
         bad_words_token_ids = request.get("bad_words_token_ids")
         if bad_words:
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request["bad_words_token_ids"] = bad_words_token_ids
+
+        # processing prompt_token_ids
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
             if request.prompt is not None:
-                request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
+                prompt = request.prompt
+                assert isinstance(prompt, str) or (
+                    isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
+                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+                if isinstance(prompt, list):  # if prompt is a token id list
+                    request.prompt_token_ids = prompt
+                else:
+                    request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
             elif request.messages is not None:
                 if self.tokenizer.chat_template is None:
                     raise ValueError("This model does not support chat_template.")
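The added branch is a type dispatch on `prompt`: a string is tokenized, a list of ints is trusted as ready-made token ids. The same logic as a standalone sketch (the helper name and `tokenize` parameter are mine, not part of the patch):

from typing import Callable, List, Union

def resolve_prompt_token_ids(prompt: Union[str, List[int]], tokenize: Callable[[str], List[int]]) -> List[int]:
    # Accept either raw text or a ready-made token id list, as in the patch.
    assert isinstance(prompt, str) or (
        isinstance(prompt, list) and all(isinstance(t, int) for t in prompt)
    ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
    if isinstance(prompt, list):  # already token ids: pass through unchanged
        return prompt
    return tokenize(prompt)  # plain text: run the tokenizer

Note that an empty list passes the assert; the non-empty validation happens later, after tokenization has also had its chance to produce ids.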
@@ -239,19 +251,22 @@ class DataProcessor(BaseDataProcessor):
                 request.prompt_token_ids = self.messages2ids(task)
             else:
                 raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
+
+        if len(request.prompt_token_ids) == 0:
+            raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
+
+        # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
             request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
         if request.get("max_tokens") is None:
-            request.set(
-                "max_tokens",
-                max(1, max_model_len - len(request.prompt_token_ids)),
-            )
+            request.set("max_tokens", max(1, max_model_len - len(request.prompt_token_ids)))
         if request.get("temperature") < _SAMPLING_EPS:
             # zero temperature is equivalent to greedy sampling
             request.set("temperature", 1)
         if request.get("top_p") < _SAMPLING_EPS:
             request.set("top_p", _SAMPLING_EPS)
-        data_processor_logger.info(f"Processed request {request}")
+        data_processor_logger.info(f"Processed request: {request}")
         return request

     def process_request_dict(self, request, max_model_len=None, **kwargs):
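The tail of the method enforces the length budget: reject empty prompts, truncate over-long ones, and derive a default `max_tokens` from the remaining context. The same arithmetic as a self-contained sketch (function name is mine; like the patched code, it assumes `max_model_len` is set whenever `max_tokens` must be derived):

def apply_length_budget(prompt_token_ids, max_model_len, max_tokens=None):
    if len(prompt_token_ids) == 0:
        raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
    # Truncate over-long prompts, leaving at least one slot for generation.
    if max_model_len is not None and len(prompt_token_ids) > max_model_len:
        prompt_token_ids = prompt_token_ids[: max_model_len - 1]
    # Default max_tokens to the room left in the context window.
    if max_tokens is None:
        max_tokens = max(1, max_model_len - len(prompt_token_ids))
    return prompt_token_ids, max_tokens

For example, with max_model_len=8 a 10-token prompt is cut to 7 tokens and max_tokens defaults to max(1, 8 - 7) = 1.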
@@ -265,6 +280,7 @@ class DataProcessor(BaseDataProcessor):
             bool: Whether preprocessing is successful
             str: error message
         """
+        data_processor_logger.info(f"Start processing request dict: {request}")
         request = self._apply_default_parameters(request)
         if not request.get("eos_token_ids"):
             request["eos_token_ids"] = self.eos_token_ids
@@ -283,13 +299,18 @@ class DataProcessor(BaseDataProcessor):
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request["bad_words_token_ids"] = bad_words_token_ids
-        data_processor_logger.info(f"Processing request {request}")
+
+        # processing prompt_token_ids
         if not request.get("prompt_token_ids"):
-            if "prompt" in request:
-                request["text_after_process"] = request["prompt"]
-                request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
-            elif "messages" in request:
+            if request.get("prompt"):
+                prompt = request.get("prompt")
+                assert isinstance(prompt, str) or (
+                    isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
+                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+                if isinstance(prompt, list):  # if prompt is a token id list
+                    request["prompt_token_ids"] = prompt
+                else:
+                    request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
+            elif request.get("messages"):
                 if self.tokenizer.chat_template is None:
                     raise ValueError("This model does not support chat_template.")
                 chat_template_kwargs = request.get("chat_template_kwargs")
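The dict-based path applies the same dispatch to plain request dicts. Two inputs that now take the same route (values are placeholders, not real token ids for the string shown):

request_from_text = {"prompt": "Hello, world", "max_tokens": 16}
request_from_ids = {"prompt": [101, 2023, 2003], "max_tokens": 16}

Switching the check from `"prompt" in request` to `request.get("prompt")` also means an explicitly null or empty prompt now falls through to the `messages` branch rather than being handed to the tokenizer.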
@@ -304,8 +325,13 @@ class DataProcessor(BaseDataProcessor):
                 request["prompt_token_ids"] = self.messages2ids(request)
             else:
                 raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
+
+        if len(request["prompt_token_ids"]) == 0:
+            raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
+
+        # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
             request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
         if request.get("max_tokens") is None:
             request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
         if request.get("temperature") < _SAMPLING_EPS:
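The sampling clamp that follows is unchanged by this patch but worth reading alongside it; a sketch with an assumed epsilon, since `_SAMPLING_EPS` is defined elsewhere in the module:

_SAMPLING_EPS = 1e-5  # assumed value; the real constant lives elsewhere

def clamp_sampling_params(temperature: float, top_p: float):
    if temperature < _SAMPLING_EPS:
        temperature = 1  # zero temperature is treated as greedy sampling
    if top_p < _SAMPLING_EPS:
        top_p = _SAMPLING_EPS  # keep the nucleus non-empty
    return temperature, top_p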
@@ -313,7 +339,8 @@ class DataProcessor(BaseDataProcessor):
             request["temperature"] = 1
         if request.get("top_p") < _SAMPLING_EPS:
             request["top_p"] = _SAMPLING_EPS
-        data_processor_logger.info(f"Processed request {request}")
+        data_processor_logger.info(f"Processed request dict: {request}")
         return request

     def process_logprob_response(self, token_ids, **kwargs):
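Finally, the commit message mentions a new unittest file for the serving API. A sketch of the shape such a test could take, assuming an OpenAI-compatible test server; the URL, model name, and ids are illustrative, not the commit's actual test code:

from openai import OpenAI

def test_completion_accepts_token_ids():
    # Assumes a server is already running at this placeholder address.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    resp = client.completions.create(model="default", prompt=[101, 2023, 2003], max_tokens=8)
    assert resp.choices and resp.choices[0].text is not None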