diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index 0703274e2..66076bedf 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -372,19 +372,22 @@ class CompletionRequest(BaseModel):
         req_dict = {}
         if request_id is not None:
             req_dict["request_id"] = request_id
+
+        # parse request model into dict, priority: request > extra_body > suffix
         for key, value in self.dict().items():
             if value is not None:
                 req_dict[key] = value
+        if self.extra_body is not None:
+            for key, value in self.extra_body.items():
+                req_dict.setdefault(key, value)
         if self.suffix is not None:
             for key, value in self.suffix.items():
-                req_dict[key] = value
+                req_dict.setdefault(key, value)
+
         if prompt is not None:
             req_dict["prompt"] = prompt

-        if self.prompt_token_ids is not None or (
-            self.extra_body is not None and self.extra_body.get("prompt_token_ids") is not None
-        ):
-            req_dict["prompt_token_ids"] = self.prompt_token_ids
+        if "prompt_token_ids" in req_dict:
             if "prompt" in req_dict:
                 del req_dict["prompt"]
         else:
@@ -475,7 +478,7 @@ class ChatCompletionRequest(BaseModel):
     top_p: Optional[float] = None
     top_k: Optional[int] = None
     min_p: Optional[float] = None
-    user: Optional[str] = None
+    user: Optional[str] = None
     metadata: Optional[dict] = None
     extra_body: Optional[dict] = None
     return_token_ids: Optional[bool] = False
@@ -508,21 +511,21 @@
         req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
         req_dict["logprobs"] = self.top_logprobs if self.logprobs else None

+        # parse request model into dict, priority: request > extra_body > metadata
+        for key, value in self.dict().items():
+            if value is not None:
+                req_dict[key] = value
+        if self.extra_body is not None:
+            for key, value in self.extra_body.items():
+                req_dict.setdefault(key, value)
         if self.metadata is not None:
             assert (
                 "raw_request" not in self.metadata
             ), "The parameter `raw_request` is not supported now, please use completion api instead."
             for key, value in self.metadata.items():
-                req_dict[key] = value
+                req_dict.setdefault(key, value)

-        for key, value in self.dict().items():
-            if value is not None:
-                req_dict[key] = value
-
-        if self.prompt_token_ids is not None or (
-            self.extra_body is not None and self.extra_body.get("prompt_token_ids") is not None
-        ):
-            req_dict["prompt_token_ids"] = self.prompt_token_ids
+        if "prompt_token_ids" in req_dict:
             if "messages" in req_dict:
                 del req_dict["messages"]
         else:
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index bc4fd679e..611b0cb8d 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -330,6 +330,7 @@ class OpenAIServingChat:
         previous_num_tokens = 0
         current_waiting_time = 0
         logprob_contents = []
+        completion_token_ids = []
         while True:
             try:
                 raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
@@ -361,6 +362,7 @@
                 )
                 # api_server_logger.debug(f"Client {request_id} received: {data}")
                 previous_num_tokens += len(data["outputs"]["token_ids"])
+                completion_token_ids.extend(data["outputs"]["token_ids"])
                 # The logprob for handling the response
                 output = data["outputs"]
                 raw_top_logprobs = output["top_logprobs"]
@@ -394,7 +396,7 @@
                     reasoning_content=output.get("reasoning_content"),
                     tool_calls=output.get("tool_call_content"),
                     prompt_token_ids=prompt_token_ids if enable_return_token_ids else None,
-                    completion_token_ids=output.get("token_ids") if enable_return_token_ids else None,
+                    completion_token_ids=completion_token_ids if enable_return_token_ids else None,
                 )
                 logprobs_full_res = None
                 if logprob_contents:
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index 5ad554566..34f712409 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -151,6 +151,7 @@ class OpenAIServingCompletion:

            valid_results = [dict()] * num_choices
            output_tokens = [0] * num_choices
+            completion_batched_token_ids = [[] for _ in range(num_choices)]
            current_waiting_time = 0
            while num_choices > 0:
                try:
@@ -174,6 +175,7 @@

                self.engine_client.data_processor.process_response_dict(data, stream=False)
                output_tokens[rid] += len(data["outputs"]["token_ids"])
+                completion_batched_token_ids[rid].extend(data["outputs"]["token_ids"])
                if data.get("finished", False):
                    data["output_token_ids"] = output_tokens[rid]
                    valid_results[rid] = data
@@ -187,6 +189,7 @@
                created_time=created_time,
                model_name=model_name,
                prompt_batched_token_ids=prompt_batched_token_ids,
+                completion_batched_token_ids=completion_batched_token_ids,
            )
        except Exception as e:
            api_server_logger.error(f"Error in completion_full_generator: {e}", exc_info=True)
@@ -341,6 +344,7 @@
        created_time: int,
        model_name: str,
        prompt_batched_token_ids: list(),
+        completion_batched_token_ids: list()
    ) -> CompletionResponse:
        choices: List[CompletionResponseChoice] = []
        num_prompt_tokens = 0
@@ -352,6 +356,7 @@
            prompt_token_ids = prompt_batched_token_ids[idx]
            assert prompt_token_ids is not None
            prompt_text = final_res["prompt"]
+            completion_token_ids = completion_batched_token_ids[idx]
            output = final_res["outputs"]

            if request.echo:
@@ -371,7 +376,7 @@
                index=len(choices),
                text=output_text,
                prompt_token_ids=prompt_token_ids if enable_return_token_ids else None,
-                completion_token_ids=output["token_ids"] if enable_return_token_ids else None,
+                completion_token_ids=completion_token_ids if enable_return_token_ids else None,
                reasoning_content=output.get('reasoning_content'),
                tool_calls=output.get("tool_call_content"),
                logprobs=None,
diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py
index 6bb7b5011..872b52e08 100644
--- a/fastdeploy/input/ernie_processor.py
+++ b/fastdeploy/input/ernie_processor.py
@@ -138,14 +138,15 @@ class ErnieProcessor(BaseDataProcessor):
         request = self._apply_default_parameters(request)
         if not request.get("eos_token_ids"):
             request["eos_token_ids"] = self.eos_token_ids
-        # 处理stop_sequences
+
+        # processing stop_sequences
         stop_sequences = request.get("stop", [])
         if stop_sequences:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
             request["stop_token_ids"] = stop_seqs
             request["stop_seqs_len"] = stop_seqs_len

-        # 处理prompt_token_ids
+        # processing prompt_token_ids
         if not request.get("prompt_token_ids"):
             if request.get("prompt") is None and request.get("messages") is None:
                 raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
@@ -161,7 +162,7 @@ class ErnieProcessor(BaseDataProcessor):
             else:
                 request["prompt_token_ids"] = self.messages2ids(request)

-        # 截断超过长度限制的prompt
+        # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
             request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
         if request.get("max_tokens") is None: