mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[feat] completion api supports passing input token ids in either prompt
or prompt_token_ids
(#3311)
* [feat] completion api supports passing input token ids in either `prompt` or `prompt_token_ids` * [fix] update comment * [fix] fix type error * [test] add a unittest file for serving api test * [test] try to fix ci error * [chore] rename test function names * [test] try to fix ci error * [test] try to fix ci error * [test] add tests for qwen
This commit is contained in:
@@ -204,25 +204,37 @@ class DataProcessor(BaseDataProcessor):
|
||||
bool: Whether preprocessing is successful
|
||||
str: error message
|
||||
"""
|
||||
data_processor_logger.info(f"Start processing request: {request}")
|
||||
request.chat_template = kwargs.get("chat_template")
|
||||
request = self._apply_default_parameters(request)
|
||||
if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
|
||||
request.eos_token_ids = self.eos_token_ids
|
||||
|
||||
# processing stop_sequences
|
||||
stop_sequences = request.get("stop", [])
|
||||
if stop_sequences is not None and len(stop_sequences) != 0:
|
||||
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
|
||||
request.set("stop_token_ids", stop_seqs)
|
||||
request.set("stop_seqs_len", stop_seqs_len)
|
||||
|
||||
# processing bad_words
|
||||
bad_words = request.get("bad_words")
|
||||
bad_words_token_ids = request.get("bad_words_token_ids")
|
||||
if bad_words:
|
||||
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
|
||||
request["bad_words_token_ids"] = bad_words_token_ids
|
||||
|
||||
# processing prompt_token_ids
|
||||
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
|
||||
if request.prompt is not None:
|
||||
request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
|
||||
prompt = request.prompt
|
||||
assert isinstance(prompt, str) or (
|
||||
isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
|
||||
), f"prompt must be a string or a list of integers, but got {type(prompt)}"
|
||||
if isinstance(prompt, list): # if prompt is a token id list
|
||||
request.prompt_token_ids = prompt
|
||||
else:
|
||||
request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
|
||||
elif request.messages is not None:
|
||||
if self.tokenizer.chat_template is None:
|
||||
raise ValueError("This model does not support chat_template.")
|
||||
@@ -239,19 +251,22 @@ class DataProcessor(BaseDataProcessor):
|
||||
request.prompt_token_ids = self.messages2ids(task)
|
||||
else:
|
||||
raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
|
||||
|
||||
if len(request.prompt_token_ids) == 0:
|
||||
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
|
||||
|
||||
# truncate prompts that exceed the length limit
|
||||
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
|
||||
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
|
||||
if request.get("max_tokens") is None:
|
||||
request.set(
|
||||
"max_tokens",
|
||||
max(1, max_model_len - len(request.prompt_token_ids)),
|
||||
)
|
||||
request.set("max_tokens", max(1, max_model_len - len(request.prompt_token_ids)))
|
||||
if request.get("temperature") < _SAMPLING_EPS:
|
||||
# zero temperature is equivalent to greedy sampling
|
||||
request.set("temperature", 1)
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request.set("top_p", _SAMPLING_EPS)
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
|
||||
data_processor_logger.info(f"Processed request: {request}")
|
||||
return request
|
||||
|
||||
def process_request_dict(self, request, max_model_len=None, **kwargs):
|
||||
@@ -265,6 +280,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
bool: Whether preprocessing is successful
|
||||
str: error message
|
||||
"""
|
||||
data_processor_logger.info(f"Start processing request dict: {request}")
|
||||
request = self._apply_default_parameters(request)
|
||||
if not request.get("eos_token_ids"):
|
||||
request["eos_token_ids"] = self.eos_token_ids
|
||||
@@ -283,13 +299,18 @@ class DataProcessor(BaseDataProcessor):
|
||||
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
|
||||
request["bad_words_token_ids"] = bad_words_token_ids
|
||||
|
||||
data_processor_logger.info(f"Processing request {request}")
|
||||
# processing prompt_token_ids
|
||||
if not request.get("prompt_token_ids"):
|
||||
if "prompt" in request:
|
||||
request["text_after_process"] = request["prompt"]
|
||||
request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
|
||||
elif "messages" in request:
|
||||
if request.get("prompt"):
|
||||
prompt = request.get("prompt")
|
||||
assert isinstance(prompt, str) or (
|
||||
isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
|
||||
), f"prompt must be a string or a list of integers, but got {type(prompt)}"
|
||||
if isinstance(prompt, list): # if prompt is a token id list
|
||||
request["prompt_token_ids"] = prompt
|
||||
else:
|
||||
request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
|
||||
elif request.get("messages"):
|
||||
if self.tokenizer.chat_template is None:
|
||||
raise ValueError("This model does not support chat_template.")
|
||||
chat_template_kwargs = request.get("chat_template_kwargs")
|
||||
@@ -304,8 +325,13 @@ class DataProcessor(BaseDataProcessor):
|
||||
request["prompt_token_ids"] = self.messages2ids(request)
|
||||
else:
|
||||
raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
|
||||
|
||||
if len(request["prompt_token_ids"]) == 0:
|
||||
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
|
||||
|
||||
# truncate prompts that exceed the length limit
|
||||
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
|
||||
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
|
||||
if request.get("max_tokens") is None:
|
||||
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
|
||||
if request.get("temperature") < _SAMPLING_EPS:
|
||||
@@ -313,7 +339,8 @@ class DataProcessor(BaseDataProcessor):
|
||||
request["temperature"] = 1
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
|
||||
data_processor_logger.info(f"Processed request dict: {request}")
|
||||
return request
|
||||
|
||||
def process_logprob_response(self, token_ids, **kwargs):
|
||||
|
Reference in New Issue
Block a user