mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 08:16:42 +08:00
fix response processsors (#3826)
* fix response processsors * fix ci * fix ut --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -286,7 +286,8 @@ class CompletionOutput:
|
||||
|
||||
index: int
|
||||
send_idx: int
|
||||
token_ids: list[int]
|
||||
token_ids: list[Any]
|
||||
decode_type: int = 0
|
||||
logprob: Optional[float] = None
|
||||
top_logprobs: Optional[LogprobsLists] = None
|
||||
logprobs: Optional[SampleLogprobs] = None
|
||||
|
@@ -17,6 +17,7 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from fastdeploy.input.tokenzier_client import AsyncTokenizerClient, ImageDecodeRequest
|
||||
from fastdeploy.utils import api_server_logger
|
||||
|
||||
|
||||
class ChatResponseProcessor:
|
||||
@@ -41,6 +42,8 @@ class ChatResponseProcessor:
|
||||
self.eos_token_id = eos_token_id
|
||||
if decoder_base_url is not None:
|
||||
self.decoder_client = AsyncTokenizerClient(base_url=decoder_base_url)
|
||||
else:
|
||||
self.decoder_client = None
|
||||
self._mm_buffer: List[Any] = [] # Buffer for accumulating image token_ids
|
||||
self._end_image_code_request_output: Optional[Any] = None
|
||||
self._multipart_buffer = []
|
||||
@@ -74,6 +77,7 @@ class ChatResponseProcessor:
|
||||
include_stop_str_in_output: Whether or not to include stop strings in the output.
|
||||
"""
|
||||
for request_output in request_outputs:
|
||||
api_server_logger.debug(f"request_output {request_output}")
|
||||
if not self.enable_mm_output:
|
||||
yield self.data_processor.process_response_dict(
|
||||
response_dict=request_output,
|
||||
@@ -112,7 +116,7 @@ class ChatResponseProcessor:
|
||||
yield request_output
|
||||
|
||||
elif decode_type == 1:
|
||||
self._mm_buffer.extend(token_ids)
|
||||
self._mm_buffer.append(token_ids)
|
||||
self._end_image_code_request_output = request_output
|
||||
else:
|
||||
self.accumulate_token_ids(request_output)
|
||||
|
@@ -80,7 +80,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
image_part = results[1]["outputs"]["multipart"][0]
|
||||
self.assertEqual(image_part["type"], "image")
|
||||
self.assertEqual(image_part["url"], "http://image.url/test.png")
|
||||
self.assertEqual(results[1]["outputs"]["token_ids"], [[11, 22]])
|
||||
self.assertEqual(results[1]["outputs"]["token_ids"], [[[11, 22]]])
|
||||
|
||||
# 第三个 yield:text
|
||||
text_part = results[2]["outputs"]["multipart"][0]
|
||||
@@ -99,7 +99,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
|
||||
self.assertEqual(results, [])
|
||||
self.assertEqual(self.processor_mm._mm_buffer, [[33, 44]])
|
||||
self.assertEqual(self.processor_mm._mm_buffer, [[[33, 44]]])
|
||||
|
||||
async def test_non_streaming_accumulate_and_emit(self):
|
||||
"""非流式模式:等 eos_token_id 才输出 multipart(text+image)"""
|
||||
|
Reference in New Issue
Block a user