Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)

[Feature] support audio tts (#5333)

@@ -416,6 +416,7 @@ class CompletionOutput:
             f"send_idx={self.send_idx}, "
             f"text={self.text!r}, "
             f"token_ids={self.token_ids}, "
+            f"decode_type={self.decode_type}, "
             f"draft_token_ids={self.draft_token_ids}, "
             f"reasoning_content={self.reasoning_content!r}, "
             f"logprobs={self.logprobs}, "

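Throughout this diff, decode_type is a bare int on the output dict: 0 marks text, 1 image, and 2 audio (the inline comments added further down confirm the mapping). A hypothetical reader's aid, not part of the commit:

from enum import IntEnum

class DecodeType(IntEnum):
    """Hypothetical names for the bare decode_type ints used in this diff."""
    TEXT = 0
    IMAGE = 1
    AUDIO = 2
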
@@ -210,6 +210,7 @@ class ChatMessage(BaseModel):
     content: Optional[str] = None
     multimodal_content: Optional[List[Any]] = None
     reasoning_content: Optional[str] = None
+    audio_content: Optional[str] = None
     tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
     prompt_token_ids: Optional[List[int]] = None
     completion_token_ids: Optional[List[int]] = None

@@ -272,6 +273,7 @@ class DeltaMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
     multimodal_content: Optional[List[Any]] = None
+    audio_content: Optional[str] = None
     prompt_token_ids: Optional[List[int]] = None
     completion_token_ids: Optional[List[int]] = None
     reasoning_content: Optional[str] = None

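With audio_content on both ChatMessage and DeltaMessage, a streamed chunk can carry audio alongside text. The shape below is illustrative only; the diff does not include a sample payload, and the values here are made up:

# Illustrative streamed delta once audio_content is populated (hypothetical values):
delta = {
    "role": "assistant",
    "content": "",                       # text part; may be empty for audio-only chunks
    "audio_content": "<decoded audio>",  # str per the model field; encoding not specified by this diff
}
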
@@ -14,7 +14,8 @@
 # limitations under the License.
 """

-from typing import Any, List, Optional
+import inspect
+from typing import Any, Dict, List, Optional

 from fastdeploy.entrypoints.openai.usage_calculator import count_tokens
 from fastdeploy.input.tokenzier_client import AsyncTokenizerClient, ImageDecodeRequest

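The new import inspect exists so the processor can call process_response_dict whether it is implemented as a plain function or as a coroutine. A minimal sketch of that dispatch pattern, with a generic process callable standing in for the real method:

import inspect

async def call_maybe_async(process, **kwargs):
    # Await coroutine functions, call plain functions directly.
    if inspect.iscoroutinefunction(process):
        return await process(**kwargs)
    return process(**kwargs)
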
@@ -34,12 +35,14 @@ class ChatResponseProcessor:
         data_processor,
         enable_mm_output: Optional[bool] = False,
         eoi_token_id: Optional[int] = 101032,
+        eoa_token_id: Optional[int] = 2048,
         eos_token_id: Optional[int] = 2,
         decoder_base_url: Optional[str] = None,
     ):
         self.data_processor = data_processor
         self.enable_mm_output = enable_mm_output
         self.eoi_token_id = eoi_token_id
+        self.eoa_token_id = eoa_token_id
         self.eos_token_id = eos_token_id
         if decoder_base_url is not None:
             self.decoder_client = AsyncTokenizerClient(base_url=decoder_base_url)

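Wiring the constructor might look like the sketch below; the token ids shown are simply the defaults from the signature above, and data_processor is assumed to be built elsewhere in the serving stack:

# from ... import ChatResponseProcessor  (import path omitted; not shown in this diff)
data_processor = ...  # stand-in; a real, model-specific data processor is assumed

processor = ChatResponseProcessor(
    data_processor,
    enable_mm_output=False,
    eoa_token_id=2048,       # end-of-audio sentinel (signature default)
    eos_token_id=2,          # end-of-sequence sentinel (signature default)
    decoder_base_url=None,   # None leaves self.decoder_client unset
)
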
@@ -47,6 +50,7 @@ class ChatResponseProcessor:
             self.decoder_client = None
         self._mm_buffer: List[Any] = []  # Buffer for accumulating image token_ids
         self._end_image_code_request_output: Optional[Any] = None
+        self._audio_buffer: Dict[str, List[Any]] = {}  # Buffer for per-request audio token_ids
         self._multipart_buffer = []

     def enable_multimodal_content(self):

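_audio_buffer maps a request id to the list of audio token_id chunks received so far. Its lifecycle, condensed from the processing code below (illustrative values; the real code uses an explicit if/else rather than setdefault):

buf = {}
buf.setdefault("req1", []).append([11, 22])  # first audio chunk arrives
buf.setdefault("req1", []).append([33, 44])  # second audio chunk arrives
flushed = buf.pop("req1", [])                # a text chunk ending in eos flushes the buffer
assert flushed == [[11, 22], [33, 44]]
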
@@ -80,16 +84,54 @@ class ChatResponseProcessor:
         for request_output in request_outputs:
             api_server_logger.debug(f"request_output {request_output}")
             if not self.enable_mm_output:
-                yield self.data_processor.process_response_dict(
-                    response_dict=request_output,
-                    stream=stream,
-                    enable_thinking=enable_thinking,
-                    include_stop_str_in_output=include_stop_str_in_output,
-                )
+                outputs = request_output.get("outputs", None)
+                token_ids = outputs.get("token_ids", None) if outputs is not None else None
+                req_id = request_output.get("request_id", None)
+                if outputs is not None and token_ids is not None and req_id is not None:
+                    decode_type = request_output["outputs"].get("decode_type", 0) or 0
+                    if decode_type == 0:  # text
+                        tts = req_id in self._audio_buffer
+                        if token_ids[-1] == self.eos_token_id:
+                            all_audio_tokens = self._audio_buffer.pop(req_id, [])
+                        else:
+                            all_audio_tokens = None
+                        if inspect.iscoroutinefunction(self.data_processor.process_response_dict):
+                            response = await self.data_processor.process_response_dict(
+                                response_dict=request_output,
+                                stream=stream,
+                                enable_thinking=enable_thinking,
+                                include_stop_str_in_output=include_stop_str_in_output,
+                                audio_tokens=all_audio_tokens,
+                                tts=tts,
+                            )
+                        else:
+                            response = self.data_processor.process_response_dict(
+                                response_dict=request_output,
+                                stream=stream,
+                                enable_thinking=enable_thinking,
+                                include_stop_str_in_output=include_stop_str_in_output,
+                                audio_tokens=all_audio_tokens,
+                                tts=tts,
+                            )
+                        yield response
+                    elif decode_type == 2:  # audio
+                        if self.eoa_token_id is not None and self.eoa_token_id in token_ids:
+                            continue
+                        if req_id in self._audio_buffer:
+                            self._audio_buffer[req_id].append(token_ids)
+                        else:
+                            self._audio_buffer[req_id] = [token_ids]
+                else:
+                    yield self.data_processor.process_response_dict(
+                        response_dict=request_output,
+                        stream=stream,
+                        enable_thinking=enable_thinking,
+                        include_stop_str_in_output=include_stop_str_in_output,
+                    )
             elif stream:
                 decode_type = request_output["outputs"].get("decode_type", 0)
                 token_ids = request_output["outputs"]["token_ids"]
-                if decode_type == 0:
+                if decode_type == 0:  # text
                     if self.eoi_token_id and self.eoi_token_id in token_ids:
                         if self._mm_buffer:
                             all_tokens = self._mm_buffer

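A condensed restatement of the routing added above, as a standalone sketch: audio outputs (decode_type 2) are buffered per request and never yielded, the end-of-audio chunk is dropped outright, and a text output whose last token is eos flushes the buffer so the accumulated audio tokens can be handed to process_response_dict. Only the flush decision is shown here; the tts flag and the actual yield are omitted, and the function and constant names are mine, not the repo's:

from typing import Any, Dict, List, Optional

EOA, EOS = 2048, 2  # sentinel token ids, mirroring the constructor defaults

def route_audio(output: Dict[str, Any], buffer: Dict[str, List[List[int]]]) -> Optional[List[List[int]]]:
    """Return the flushed audio chunks when text hits eos, else None."""
    req_id = output["request_id"]
    decode_type = output["outputs"].get("decode_type", 0) or 0
    token_ids = output["outputs"]["token_ids"]
    if decode_type == 2:                  # audio: buffer silently, yield nothing
        if EOA in token_ids:              # drop the end-of-audio marker chunk
            return None
        buffer.setdefault(req_id, []).append(token_ids)
        return None
    if decode_type == 0 and token_ids and token_ids[-1] == EOS:
        return buffer.pop(req_id, [])     # flush everything buffered so far
    return None
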
@@ -118,7 +160,7 @@ class ChatResponseProcessor:
                             request_output["outputs"]["multipart"] = [text]
                         yield request_output

-                elif decode_type == 1:
+                elif decode_type == 1:  # image
                     self._mm_buffer.append(token_ids)
                     self._end_image_code_request_output = request_output
             else:

@@ -329,6 +329,9 @@ class OpenAIServingChat:
                     else:
                         choice.delta.content = ""

+                    if res["outputs"].get("audio_content", None) is not None:
+                        choice.delta.audio_content = res["outputs"]["audio_content"]
+
                     if request.return_token_ids:
                         choice.delta.prompt_token_ids = list(prompt_token_ids)
                         choice.delta.prompt_tokens = prompt_tokens

@@ -389,6 +392,10 @@ class OpenAIServingChat:
                     delta_message.multimodal_content = output["multipart"]
                 else:
                     delta_message.content = output["text"]
+
+                if output.get("audio_content", None) is not None:
+                    delta_message.audio_content = output["audio_content"]
+
                 if not res["finished"] and "delta_message" in output:
                     delta_message_output = output["delta_message"]
                     if delta_message_output is None:

@@ -689,6 +696,9 @@ class OpenAIServingChat:
                 else:
                     message.content = output["text"]

+                if output.get("audio_content", None) is not None:
+                    message.audio_content = output["audio_content"]
+
                 logprobs_full_res = None
                 draft_logprobs_full_res = None
                 prompt_logprobs_full_res = None

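On the client side, message.audio_content arrives as a string. The diff does not specify its encoding; the sketch below assumes a base64-encoded audio payload, which is an assumption of mine, not something this commit establishes:

import base64

def save_audio(audio_content: str, path: str = "reply.wav") -> None:
    # Assumes a base64 payload; adjust if the server uses another encoding.
    with open(path, "wb") as f:
        f.write(base64.b64decode(audio_content))
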
@@ -56,6 +56,28 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
         self.assertEqual(results[0]["processed"], True)
         self.assertEqual(results[0]["raw"]["outputs"]["text"], "hello")

+    async def test_audio_tts(self):
+        """multimodal output disabled; responses go straight through data_processor"""
+        processor = ChatResponseProcessor(self.mock_data_processor)
+        request_outputs = [
+            {"request_id": "req1", "outputs": {"decode_type": 2, "token_ids": [[11, 22]]}},
+            {"request_id": "req1", "outputs": {"decode_type": 0, "token_ids": [1]}},
+            {"request_id": "req1", "outputs": {"decode_type": 2, "token_ids": [[11, 22]]}},
+            {"request_id": "req1", "outputs": {"decode_type": 0, "token_ids": [2]}},
+        ]
+
+        results = [
+            r
+            async for r in processor.process_response_chat(
+                request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
+            )
+        ]
+
+        self.assertEqual(results[0]["processed"], True)
+        self.assertEqual(results[0]["raw"]["outputs"]["token_ids"], [1])
+        self.assertEqual(results[1]["processed"], True)
+        self.assertEqual(results[1]["raw"]["outputs"]["token_ids"], [2])
+
     async def test_streaming_text_and_image(self):
         """Streaming mode: text → image → text"""
         request_outputs = [