mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
fix response processors (#3826)
* fix response processors * fix ci * fix ut --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -286,7 +286,8 @@ class CompletionOutput:
|
|||||||
|
|
||||||
index: int
|
index: int
|
||||||
send_idx: int
|
send_idx: int
|
||||||
token_ids: list[int]
|
token_ids: list[Any]
|
||||||
|
decode_type: int = 0
|
||||||
logprob: Optional[float] = None
|
logprob: Optional[float] = None
|
||||||
top_logprobs: Optional[LogprobsLists] = None
|
top_logprobs: Optional[LogprobsLists] = None
|
||||||
logprobs: Optional[SampleLogprobs] = None
|
logprobs: Optional[SampleLogprobs] = None
|
||||||
|
@@ -17,6 +17,7 @@
|
|||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from fastdeploy.input.tokenzier_client import AsyncTokenizerClient, ImageDecodeRequest
|
from fastdeploy.input.tokenzier_client import AsyncTokenizerClient, ImageDecodeRequest
|
||||||
|
from fastdeploy.utils import api_server_logger
|
||||||
|
|
||||||
|
|
||||||
class ChatResponseProcessor:
|
class ChatResponseProcessor:
|
||||||
@@ -41,6 +42,8 @@ class ChatResponseProcessor:
|
|||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
if decoder_base_url is not None:
|
if decoder_base_url is not None:
|
||||||
self.decoder_client = AsyncTokenizerClient(base_url=decoder_base_url)
|
self.decoder_client = AsyncTokenizerClient(base_url=decoder_base_url)
|
||||||
|
else:
|
||||||
|
self.decoder_client = None
|
||||||
self._mm_buffer: List[Any] = [] # Buffer for accumulating image token_ids
|
self._mm_buffer: List[Any] = [] # Buffer for accumulating image token_ids
|
||||||
self._end_image_code_request_output: Optional[Any] = None
|
self._end_image_code_request_output: Optional[Any] = None
|
||||||
self._multipart_buffer = []
|
self._multipart_buffer = []
|
||||||
@@ -74,6 +77,7 @@ class ChatResponseProcessor:
|
|||||||
include_stop_str_in_output: Whether or not to include stop strings in the output.
|
include_stop_str_in_output: Whether or not to include stop strings in the output.
|
||||||
"""
|
"""
|
||||||
for request_output in request_outputs:
|
for request_output in request_outputs:
|
||||||
|
api_server_logger.debug(f"request_output {request_output}")
|
||||||
if not self.enable_mm_output:
|
if not self.enable_mm_output:
|
||||||
yield self.data_processor.process_response_dict(
|
yield self.data_processor.process_response_dict(
|
||||||
response_dict=request_output,
|
response_dict=request_output,
|
||||||
@@ -112,7 +116,7 @@ class ChatResponseProcessor:
|
|||||||
yield request_output
|
yield request_output
|
||||||
|
|
||||||
elif decode_type == 1:
|
elif decode_type == 1:
|
||||||
self._mm_buffer.extend(token_ids)
|
self._mm_buffer.append(token_ids)
|
||||||
self._end_image_code_request_output = request_output
|
self._end_image_code_request_output = request_output
|
||||||
else:
|
else:
|
||||||
self.accumulate_token_ids(request_output)
|
self.accumulate_token_ids(request_output)
|
||||||
|
@@ -80,7 +80,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
|||||||
image_part = results[1]["outputs"]["multipart"][0]
|
image_part = results[1]["outputs"]["multipart"][0]
|
||||||
self.assertEqual(image_part["type"], "image")
|
self.assertEqual(image_part["type"], "image")
|
||||||
self.assertEqual(image_part["url"], "http://image.url/test.png")
|
self.assertEqual(image_part["url"], "http://image.url/test.png")
|
||||||
self.assertEqual(results[1]["outputs"]["token_ids"], [[11, 22]])
|
self.assertEqual(results[1]["outputs"]["token_ids"], [[[11, 22]]])
|
||||||
|
|
||||||
# 第三个 yield:text
|
# 第三个 yield:text
|
||||||
text_part = results[2]["outputs"]["multipart"][0]
|
text_part = results[2]["outputs"]["multipart"][0]
|
||||||
@@ -99,7 +99,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
self.assertEqual(results, [])
|
self.assertEqual(results, [])
|
||||||
self.assertEqual(self.processor_mm._mm_buffer, [[33, 44]])
|
self.assertEqual(self.processor_mm._mm_buffer, [[[33, 44]]])
|
||||||
|
|
||||||
async def test_non_streaming_accumulate_and_emit(self):
|
async def test_non_streaming_accumulate_and_emit(self):
|
||||||
"""非流式模式:等 eos_token_id 才输出 multipart(text+image)"""
|
"""非流式模式:等 eos_token_id 才输出 multipart(text+image)"""
|
||||||
|
Reference in New Issue
Block a user