Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-31 03:46:40 +08:00)
	Completion add raw_prediction/text_after_process (#3356)
@@ -126,6 +126,8 @@ class ChatMessage(BaseModel):
    tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
    prompt_token_ids: Optional[List[int]] = None
    completion_token_ids: Optional[List[int]] = None
    text_after_process: Optional[str] = None
    raw_prediction: Optional[str] = None


class ChatCompletionResponseChoice(BaseModel):
@@ -183,6 +185,8 @@ class DeltaMessage(BaseModel):
    completion_token_ids: Optional[List[int]] = None
    reasoning_content: Optional[str] = None
    tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
    text_after_process: Optional[str] = None
    raw_prediction: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
@@ -219,6 +223,8 @@ class CompletionResponseChoice(BaseModel):
    text: str
    prompt_token_ids: Optional[List[int]] = None
    completion_token_ids: Optional[List[int]] = None
    text_after_process: Optional[str] = None
    raw_prediction: Optional[str] = None
    arrival_time: Optional[float] = None
    logprobs: Optional[CompletionLogprobs] = None
    reasoning_content: Optional[str] = None
@@ -261,6 +267,8 @@ class CompletionResponseStreamChoice(BaseModel):
    logprobs: Optional[CompletionLogprobs] = None
    prompt_token_ids: Optional[List[int]] = None
    completion_token_ids: Optional[List[int]] = None
    text_after_process: Optional[str] = None
    raw_prediction: Optional[str] = None
    reasoning_content: Optional[str] = None
    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
    tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None

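The protocol change is uniform: text_after_process and raw_prediction are declared as Optional[str] with a None default on every chat and completion choice/delta model, so they serialize as null for existing clients and are only populated on request. A minimal standalone sketch of the pattern; the model name and example value below are illustrative, not part of the FastDeploy protocol module:

from typing import Optional

from pydantic import BaseModel


class ExampleChoiceFields(BaseModel):
    # The same two optional fields the commit adds to ChatMessage, DeltaMessage,
    # CompletionResponseChoice and CompletionResponseStreamChoice.
    text_after_process: Optional[str] = None  # prompt text as it looked after request preprocessing (see serving code below)
    raw_prediction: Optional[str] = None      # model output before post-processing


msg = ExampleChoiceFields(raw_prediction="<raw model output>")
assert msg.text_after_process is None  # defaults keep older clients unaffected
print(msg.raw_prediction)
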
@@ -83,11 +83,12 @@ class OpenAIServingChat:
        else:
            request_id = f"chatcmpl-{uuid.uuid4()}"
        api_server_logger.info(f"create chat completion request: {request_id}")

        text_after_process = None
        try:
            current_req_dict = request.to_dict_for_infer(request_id)
            current_req_dict["arrival_time"] = time.time()
            prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
            text_after_process = current_req_dict.get("text_after_process")
            if isinstance(prompt_token_ids, np.ndarray):
                prompt_token_ids = prompt_token_ids.tolist()
        except Exception as e:
@@ -104,10 +105,14 @@ class OpenAIServingChat:
            return ErrorResponse(code=408, message=f"Request queued time exceed {self.max_waiting_time}")

        if request.stream:
            return self.chat_completion_stream_generator(request, request_id, request.model, prompt_token_ids)
            return self.chat_completion_stream_generator(
                request, request_id, request.model, prompt_token_ids, text_after_process
            )
        else:
            try:
                return await self.chat_completion_full_generator(request, request_id, request.model, prompt_token_ids)
                return await self.chat_completion_full_generator(
                    request, request_id, request.model, prompt_token_ids, text_after_process
                )
            except Exception as e:
                return ErrorResponse(code=400, message=str(e))

@@ -124,6 +129,7 @@ class OpenAIServingChat:
        request_id: str,
        model_name: str,
        prompt_token_ids: list(),
        text_after_process: str,
    ):
        """
        Streaming chat completion generator.
@@ -216,6 +222,7 @@ class OpenAIServingChat:
                            )
                            if request.return_token_ids:
                                choice.delta.prompt_token_ids = list(prompt_token_ids)
                                choice.delta.text_after_process = text_after_process
                            chunk = ChatCompletionStreamResponse(
                                id=request_id,
                                object=chunk_object_type,
@@ -279,6 +286,7 @@ class OpenAIServingChat:

                    if request.return_token_ids:
                        choice.delta.completion_token_ids = list(output["token_ids"])
                        choice.delta.raw_prediction = output.get("raw_prediction")
                    if include_continuous_usage:
                        chunk.usage = UsageInfo(
                            prompt_tokens=num_prompt_tokens,
@@ -329,6 +337,7 @@ class OpenAIServingChat:
        request_id: str,
        model_name: str,
        prompt_token_ids: list(),
        text_after_process: str,
    ):
        """
        Full chat completion generator.
@@ -406,6 +415,8 @@ class OpenAIServingChat:
            tool_calls=output.get("tool_call_content"),
            prompt_token_ids=prompt_token_ids if request.return_token_ids else None,
            completion_token_ids=completion_token_ids if request.return_token_ids else None,
            text_after_process=text_after_process if request.return_token_ids else None,
            raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
        )
        logprobs_full_res = None
        if logprob_contents:

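In the chat serving path, text_after_process is read back from current_req_dict right after engine_client.format_and_add_data and threaded into both the streaming and the full generator; raw_prediction is taken from each output dict. Both are echoed to the client only when request.return_token_ids is set. A hedged client-side sketch of reading them with the OpenAI Python SDK; the server URL and model name are placeholders, and it is assumed here that return_token_ids can be sent via extra_body to a FastDeploy OpenAI-compatible endpoint:

from openai import OpenAI

# Placeholder endpoint for a FastDeploy OpenAI-compatible API server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello"}],
    extra_body={"return_token_ids": True},  # gate for the new debug fields
)

message = resp.choices[0].message
# Non-standard response fields; None if the server did not populate them.
print(getattr(message, "text_after_process", None))  # prompt as it went into the engine
print(getattr(message, "raw_prediction", None))      # output before reasoning/tool post-processing

In streaming mode the same data arrives on the deltas: per the hunks above, the first chunk carries prompt_token_ids and text_after_process, and later chunks carry completion_token_ids and raw_prediction.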
@@ -100,6 +100,7 @@ class OpenAIServingCompletion:

        api_server_logger.info(f"start inference for request {num_choices}")
        prompt_batched_token_ids = []
        text_after_process_list = []
        try:
            for idx, prompt in enumerate(request_prompts):
                request_id_idx = f"{request_id}-{idx}"
@@ -109,6 +110,7 @@ class OpenAIServingCompletion:
                    prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
                    if isinstance(prompt_token_ids, np.ndarray):
                        prompt_token_ids = prompt_token_ids.tolist()
                    text_after_process_list.append(current_req_dict.get("text_after_process"))
                    prompt_batched_token_ids.append(prompt_token_ids)
                except Exception as e:
                    return ErrorResponse(message=str(e), code=400)
@@ -131,6 +133,7 @@ class OpenAIServingCompletion:
                    created_time=created_time,
                    model_name=request.model,
                    prompt_batched_token_ids=prompt_batched_token_ids,
                    text_after_process_list=text_after_process_list,
                )
            else:
                try:
@@ -141,6 +144,7 @@ class OpenAIServingCompletion:
                        created_time=created_time,
                        model_name=request.model,
                        prompt_batched_token_ids=prompt_batched_token_ids,
                        text_after_process_list=text_after_process_list,
                    )
                except Exception as e:
                    return ErrorResponse(code=400, message=str(e))
@@ -156,6 +160,7 @@ class OpenAIServingCompletion:
        created_time: int,
        model_name: str,
        prompt_batched_token_ids: list(),
        text_after_process_list: list(),
    ):
        """
        Process the full completion request with multiple choices.
@@ -225,6 +230,7 @@ class OpenAIServingCompletion:
                model_name=model_name,
                prompt_batched_token_ids=prompt_batched_token_ids,
                completion_batched_token_ids=completion_batched_token_ids,
                text_after_process_list=text_after_process_list,
            )
        except Exception as e:
            api_server_logger.error(f"Error in completion_full_generator: {e}", exc_info=True)
@@ -251,6 +257,7 @@ class OpenAIServingCompletion:
        created_time: int,
        model_name: str,
        prompt_batched_token_ids: list(),
        text_after_process_list: list(),
    ):
        """
        Process the stream completion request.
@@ -309,6 +316,7 @@ class OpenAIServingCompletion:
                                        index=idx,
                                        text="",
                                        prompt_token_ids=list(prompt_batched_token_ids[idx]),
                                        text_after_process=text_after_process_list[idx],
                                        completion_token_ids=None,
                                    )
                                ],
@@ -337,6 +345,7 @@ class OpenAIServingCompletion:
                            text=output["text"],
                            prompt_token_ids=None,
                            completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
                            raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
                            tool_calls=output.get("tool_call_content"),
                            reasoning_content=output.get("reasoning_content"),
                            arrival_time=arrival_time,
@@ -398,6 +407,7 @@ class OpenAIServingCompletion:
        model_name: str,
        prompt_batched_token_ids: list(),
        completion_batched_token_ids: list(),
        text_after_process_list: list(),
    ) -> CompletionResponse:
        choices: List[CompletionResponseChoice] = []
        num_prompt_tokens = 0
@@ -444,6 +454,8 @@ class OpenAIServingCompletion:
                text=output_text,
                prompt_token_ids=prompt_token_ids if request.return_token_ids else None,
                completion_token_ids=completion_token_ids if request.return_token_ids else None,
                raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
                text_after_process=text_after_process_list[idx] if request.return_token_ids else None,
                reasoning_content=output.get("reasoning_content"),
                tool_calls=output.get("tool_call_content"),
                logprobs=aggregated_logprobs,

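The completion path is the batched analogue: one text_after_process entry is appended per prompt in submission order and later indexed by the choice index, while raw_prediction is copied from each output dict, again only when request.return_token_ids is true. A hedged sketch of the /v1/completions side using only the standard library; the URL, model name, and prompts are placeholders:

import json
import urllib.request

payload = {
    "model": "default",
    "prompt": ["Write a haiku about the sea.", "Write a haiku about rain."],
    "max_tokens": 32,
    "return_token_ids": True,  # required for text_after_process / raw_prediction
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as r:
    body = json.load(r)

for choice in body["choices"]:
    # text_after_process is aligned with this choice's prompt by index;
    # raw_prediction is the unprocessed model output for this choice.
    print(choice["index"], choice.get("text_after_process"), choice.get("raw_prediction"))
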
memoryCoderC