diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index 43766f27f..e2a1158d3 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -234,6 +234,15 @@ class OpenAIServingCompletion:
         if dealer is not None:
             dealer.close()
 
+    def calc_finish_reason(self, max_tokens, token_num, output):
+        if max_tokens is None or token_num != max_tokens:
+            if self.engine_client.reasoning_parser == "ernie_x1" and output.get("finish_reason", "") == "tool_calls":
+                return "tool_calls"
+            else:
+                return "stop"
+        else:
+            return "length"
+
     async def completion_stream_generator(
         self,
         request: CompletionRequest,
@@ -334,19 +343,13 @@ class OpenAIServingCompletion:
                                 logprobs=logprobs_res,
                             )
                         )
-                    if res["finished"]:
-                        if request.max_tokens is None or output_tokens[idx] + 1 != request.max_tokens:
-                            chunk.choices[0].finish_reason = "stop"
-                            if (
-                                self.engine_client.reasoning_parser == "ernie_x1"
-                                and output.get("finish_reason", "") == "tool_calls"
-                            ):
-                                chunk.choices[0].finish_reason = "tool_calls"
-                        else:
-                            chunk.choices[0].finish_reason = "length"
-
                     output_tokens[idx] += 1
+
+                    if res["finished"]:
+                        choices[-1].finish_reason = self.calc_finish_reason(
+                            request.max_tokens, output_tokens[idx], output
+                        )
+
                     if len(choices) == max_streaming_response_tokens or res["finished"]:
                         chunk = CompletionStreamResponse(
                             id=request_id,
@@ -433,6 +436,8 @@ class OpenAIServingCompletion:
             token_ids = output["token_ids"]
             output_text = output["text"]
 
+            finish_reason = self.calc_finish_reason(request.max_tokens, final_res["output_token_ids"], output)
+
             choice_data = CompletionResponseChoice(
                 token_ids=token_ids,
                 index=len(choices),
@@ -442,7 +447,7 @@ class OpenAIServingCompletion:
                 reasoning_content=output.get("reasoning_content"),
                 tool_calls=output.get("tool_call_content"),
                 logprobs=aggregated_logprobs,
-                finish_reason=None,
+                finish_reason=finish_reason,
             )
             choices.append(choice_data)
 
diff --git a/test/entrypoints/openai/test_serving_completion.py b/test/entrypoints/openai/test_serving_completion.py
new file mode 100644
index 000000000..8d1a4eb66
--- /dev/null
+++ b/test/entrypoints/openai/test_serving_completion.py
@@ -0,0 +1,111 @@
+import unittest
+from typing import List
+from unittest.mock import Mock
+
+from fastdeploy.entrypoints.openai.serving_completion import (
+    CompletionRequest,
+    OpenAIServingCompletion,
+    RequestOutput,
+)
+
+
+class TestOpenAIServingCompletion(unittest.TestCase):
+
+    def test_calc_finish_reason_tool_calls(self):
+        # Create a mock engine_client with reasoning_parser set to "ernie_x1"
+        engine_client = Mock()
+        engine_client.reasoning_parser = "ernie_x1"
+        # Create an OpenAIServingCompletion instance
+        serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips", 360)
+        # Build a mock output with finish_reason set to "tool_calls"
+        output = {"finish_reason": "tool_calls"}
+        # Call the calc_finish_reason method
+        result = serving_completion.calc_finish_reason(None, 100, output)
+        # Assert the result is "tool_calls"
+        assert result == "tool_calls"
+
+    def test_calc_finish_reason_stop(self):
+        # Create a mock engine_client with reasoning_parser set to "ernie_x1"
+        engine_client = Mock()
+        engine_client.reasoning_parser = "ernie_x1"
+        # Create an OpenAIServingCompletion instance
+        serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips", 360)
+        # Build a mock output with finish_reason set to some other value
+        output = {"finish_reason": "other_reason"}
+        # Call the calc_finish_reason method
+        result = serving_completion.calc_finish_reason(None, 100, output)
+        # Assert the result is "stop"
+        assert result == "stop"
+
+    def test_calc_finish_reason_length(self):
+        # Create a mock engine_client
+        engine_client = Mock()
+        # Create an OpenAIServingCompletion instance
+        serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips", 360)
+        # Build a mock output
+        output = {}
+        # Call calc_finish_reason with token_num equal to max_tokens
+        result = serving_completion.calc_finish_reason(100, 100, output)
+        # Assert the result is "length"
+        assert result == "length"
+
+    def test_request_output_to_completion_response(self):
+        engine_client = Mock()
+        # Create an OpenAIServingCompletion instance
+        openai_serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips", 360)
+        final_res_batch: List[RequestOutput] = [
+            {
+                "prompt": "Hello, world!",
+                "outputs": {
+                    "token_ids": [1, 2, 3],
+                    "text": " world!",
+                    "top_logprobs": {
+                        "a": 0.1,
+                        "b": 0.2,
+                    },
+                },
+                "output_token_ids": 3,
+            },
+            {
+                "prompt": "Hello, world!",
+                "outputs": {
+                    "token_ids": [4, 5, 6],
+                    "text": " world!",
+                    "top_logprobs": {
+                        "a": 0.3,
+                        "b": 0.4,
+                    },
+                },
+                "output_token_ids": 3,
+            },
+        ]
+
+        request: CompletionRequest = Mock()
+        request_id = "test_request_id"
+        created_time = 1655136000
+        model_name = "test_model"
+        prompt_batched_token_ids = [[1, 2, 3], [4, 5, 6]]
+        completion_batched_token_ids = [[7, 8, 9], [10, 11, 12]]
+
+        completion_response = openai_serving_completion.request_output_to_completion_response(
+            final_res_batch=final_res_batch,
+            request=request,
+            request_id=request_id,
+            created_time=created_time,
+            model_name=model_name,
+            prompt_batched_token_ids=prompt_batched_token_ids,
+            completion_batched_token_ids=completion_batched_token_ids,
+        )
+
+        assert completion_response.id == request_id
+        assert completion_response.created == created_time
+        assert completion_response.model == model_name
+        assert len(completion_response.choices) == 2
+
+        # Verify the text attribute of each choice
+        assert completion_response.choices[0].text == "Hello, world! world!"
+        assert completion_response.choices[1].text == "Hello, world! world!"
+
+
+if __name__ == "__main__":
+    unittest.main()