[LLM] support send batch data and aggregate data (#2860)

* [LLM] support send batch data and aggregate data

* [LLM] fix ci bugs

* [LLM] fix ci bugs

* [LLM] fix ci bugs

* [LLM] fix ci bugs

* [LLM] update
This commit is contained in:
ltd0924
2025-07-16 23:42:20 +08:00
committed by GitHub
parent 63d6e7ce06
commit d245d1ca6c
11 changed files with 267 additions and 208 deletions
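
Note on the wire-format change in this commit: results pushed over the ZMQ dealer socket are now msgpack-encoded lists of response dicts (optionally aggregated into a single merged dict) instead of one JSON object per frame. A minimal client-side sketch of decoding such a frame follows; the helper name is illustrative and not part of this PR.

import msgpack

def decode_batched_frame(raw_frame: bytes):
    """Illustrative helper: unpack one ZMQ frame into response dicts."""
    responses = msgpack.unpackb(raw_frame)  # always a list, even with aggregate send
    for res in responses:
        if res.get("error_code", 200) != 200:
            raise ValueError(res["error_msg"])
        yield res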

View File

@@ -263,10 +263,11 @@ class LLMEngine(object):
             try:
                 results = self.scheduler.get_results()
                 if len(results) == 0:
-                    time.sleep(0.001)
+                    time.sleep(0.005)
+                    continue
                 for request_id, contents in results.items():
-                    for result in contents:
-                        self.zmq_server.send_multipart(request_id, result)
+                    self.zmq_server.send_multipart(request_id, contents)
             except Exception as e:
                 llm_logger.error("Unexcepted error happend: {}, {}".format(
                     e, str(traceback.format_exc())))

View File

@@ -20,7 +20,7 @@ import time
 from dataclasses import asdict, dataclass, fields
 from typing import Any, Dict, Optional, Union
 
-import numpy
+import numpy as np
 
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.utils import data_processor_logger
@@ -181,7 +181,7 @@ class Request:
                 f"sampling_params={self.sampling_params})")
 
-@dataclass
+@dataclass(slots=True)
 class CompletionOutput:
     """The output data of one completion output of a request.
@@ -235,7 +235,7 @@ class CompletionOutput:
                 f"reasoning_content={self.reasoning_content!r}")
 
-@dataclass
+@dataclass(slots=True)
 class RequestMetrics:
     """Metrics associated with a request.
@@ -310,6 +310,10 @@ class RequestOutput:
             None if decoder-only.
         num_cached_tokens: The number of tokens with prefix cache hit.
     """
+    __slots__ = (
+        'request_id', 'prompt', 'prompt_token_ids', 'outputs',
+        'finished', 'metrics', 'num_cached_tokens', 'error_code', 'error_msg'
+    )
 
     def __init__(
         self,
@@ -333,6 +337,12 @@ class RequestOutput:
         self.error_code = error_code
         self.error_msg = error_msg
 
+        if prompt_token_ids is None:
+            self.prompt_token_ids = []
+        elif isinstance(self.prompt_token_ids, np.ndarray):
+            self.prompt_token_ids = self.prompt_token_ids.tolist()
+
     def add(self, next_output: "RequestOutput") -> None:
         """Merge RequestOutput into this one"""
@@ -365,11 +375,6 @@ class RequestOutput:
     def to_dict(self):
         """convert RequestOutput into a serializable dict """
-        if self.prompt_token_ids is None:
-            self.prompt_token_ids = []
-
-        if type(self.prompt_token_ids) is numpy.ndarray:
-            self.prompt_token_ids = self.prompt_token_ids.tolist()
-
         return {
             "request_id": self.request_id,

View File

@@ -169,6 +169,8 @@ class LLM:
         # get output
         outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
+        for i in range(len(outputs)):
+            outputs[i].prompt = prompts[i]
         return outputs
 
     def chat(

View File

@@ -21,6 +21,7 @@ import traceback
 import uuid
 from typing import List, Optional
 
+import msgpack
 import aiozmq
 from aiozmq import zmq
@@ -143,6 +144,8 @@ class OpenAIServingChat:
             dealer.write([b"", request_id.encode('utf-8')])
             choices = []
             current_waiting_time = 0
+            if request.metadata is not None:
+                enable_thinking = request.metadata.get("enable_thinking")
             while num_choices > 0:
                 try:
                     raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
@@ -158,14 +161,13 @@ class OpenAIServingChat:
                             raise ValueError(f"Engine is not healthy: {msg}")
                         else:
                             current_waiting_time = 0
-                    await asyncio.sleep(0.1)
+                    await asyncio.sleep(0.01)
                     continue
 
-                res = json.loads(raw_data[-1].decode('utf-8'))
-                if res.get("error_code", 200) != 200:
-                    raise ValueError("{}".format(res["error_msg"]))
-                if request.metadata is not None:
-                    enable_thinking = request.metadata.get("enable_thinking")
+                response = msgpack.unpackb(raw_data[-1])
+                for res in response:
+                    if res.get("error_code", 200) != 200:
+                        raise ValueError("{}".format(res["error_msg"]))
 
-                self.engine_client.data_processor.process_response_dict(
-                    res, stream=True, enable_thinking=enable_thinking)
+                    self.engine_client.data_processor.process_response_dict(
+                        res, stream=True, enable_thinking=enable_thinking)
@@ -258,6 +260,11 @@ class OpenAIServingChat:
                         yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
                         choices = []
 
+            if choices:
+                chunk.choices = choices
+                yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
+                choices = []
+
             if include_usage:
                 completion_tokens = previous_num_tokens
@@ -321,7 +328,9 @@ class OpenAIServingChat:
                     await asyncio.sleep(0.1)
                     continue
 
-                data = json.loads(raw_data[-1].decode('utf-8'))
-                if data.get("error_code", 200) != 200:
-                    raise ValueError("{}".format(data["error_msg"]))
-                if request.metadata is not None:
+                response = msgpack.unpackb(raw_data[-1])
+                task_is_finished = False
+                for data in response:
+                    if data.get("error_code", 200) != 200:
+                        raise ValueError("{}".format(data["error_msg"]))
+                    if request.metadata is not None:
@@ -348,6 +357,9 @@ class OpenAIServingChat:
-                    logprob_contents.extend(logprobs_res.content)
-                if data["finished"]:
-                    final_res = data
-                    break
+                        logprob_contents.extend(logprobs_res.content)
+                    if data["finished"]:
+                        final_res = data
+                        task_is_finished = True
+                        break
+                if task_is_finished:
+                    break
             finally:
                 dealer.close()

View File

@@ -17,6 +17,7 @@
 import asyncio
 import aiozmq
 import json
+import msgpack
 from aiozmq import zmq
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
 import time
@@ -179,7 +180,8 @@ class OpenAIServingCompletion:
                         current_waiting_time = 0
                     await asyncio.sleep(0.1)
                     continue
-                data = json.loads(raw_data[-1].decode("utf-8"))
-                rid = int(data["request_id"].split("-")[-1])
-                if data.get("error_code", 200) != 200:
-                    raise ValueError("{}".format(data["error_msg"]))
+                response = msgpack.unpackb(raw_data[-1])
+                for data in response:
+                    rid = int(data["request_id"].split("-")[-1])
+                    if data.get("error_code", 200) != 200:
+                        raise ValueError("{}".format(data["error_msg"]))
@@ -191,6 +193,7 @@ class OpenAIServingCompletion:
-                data["output_token_ids"] = output_tokens[rid]
-                valid_results[rid] = data
-                num_choices -= 1
+                    data["output_token_ids"] = output_tokens[rid]
+                    valid_results[rid] = data
+                    num_choices -= 1
+                    break
 
             return self.request_output_to_completion_response(
                 final_res_batch=valid_results,
@@ -238,6 +241,12 @@ class OpenAIServingCompletion:
             if request.suffix is not None and request.suffix.get("max_streaming_response_tokens", 1) > 1:
                 max_streaming_response_tokens = request.suffix["max_streaming_response_tokens"]
             choices = []
+            chunk = CompletionStreamResponse(
+                id=request_id,
+                created=created_time,
+                model=model_name,
+                choices=choices
+            )
 
             current_waiting_time = 0
             while num_choices > 0:
@@ -256,7 +265,8 @@ class OpenAIServingCompletion:
                     continue
 
-                res = json.loads(raw_data[-1].decode('utf-8'))
-                idx = int(res["request_id"].split("-")[-1])
-                if res.get("error_code", 200) != 200:
-                    raise ValueError("{}".format(res["error_msg"]))
+                response = msgpack.unpackb(raw_data[-1])
+                for res in response:
+                    idx = int(res["request_id"].split("-")[-1])
+                    if res.get("error_code", 200) != 200:
+                        raise ValueError("{}".format(res["error_msg"]))
@@ -284,7 +294,6 @@ class OpenAIServingCompletion:
                         inference_start_time[idx] = res['metrics']['inference_start_time']
                     else:
                         arrival_time = res['metrics']['arrival_time'] - inference_start_time[idx]
-                # api_server_logger.info(f"{arrival_time}")
 
                     output = res["outputs"]
@@ -314,9 +323,9 @@ class OpenAIServingCompletion:
                             model=model_name,
                             choices=choices
                         )
-                        yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
                         choices = []
+                        yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
 
                     if res["finished"]:
                         num_choices -= 1
@@ -332,6 +341,10 @@ class OpenAIServingCompletion:
                             )
                         )
                         yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
+            if choices:
+                chunk.choices = choices
+                yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
+                choices = []
 
         except Exception as e:

View File

@@ -101,6 +101,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Whether to use DeepGemm for FP8 blockwise MoE.
     "FD_USE_DEEP_GEMM":
     lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
+
+    # Whether to use aggregate send.
+    "FD_USE_AGGREGATE_SEND":
+    lambda: bool(int(os.getenv("FD_USE_AGGREGATE_SEND", "0"))),
 }
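
For reference, a rough sketch of how the new flag resolves, mirroring the lambda above; it assumes the variable is set in the environment before FastDeploy reads envs (the in-process assignment here is only for illustration).

import os

os.environ["FD_USE_AGGREGATE_SEND"] = "1"   # default is "0" (aggregate send off)
enabled = bool(int(os.getenv("FD_USE_AGGREGATE_SEND", "0")))
print(enabled)  # True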

View File

@@ -20,6 +20,7 @@ import threading
 import time
 
 import zmq
+import msgpack
 
 from fastdeploy import envs
 from fastdeploy.utils import llm_logger
@@ -37,6 +38,7 @@ class ZmqClient:
         self.router_path = f"/dev/shm/router_{name}.ipc"
         self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
+        self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
 
         self.mutex = threading.Lock()
         self.req_dict = dict()
@@ -93,6 +95,16 @@ class ZmqClient:
         """
         return self.socket.recv_pyobj()
 
+    def pack_aggregated_data(self, data):
+        """
+        Aggregate multiple responses into one and send them to the client.
+        """
+        result = data[0]
+        if len(data) > 1:
+            for response in data[1:]:
+                result.add(response)
+        result = msgpack.packb([result.to_dict()])
+        return result
+
     def send_multipart(self, req_id, data):
         """
         Send a multipart message to the router socket.
@@ -116,14 +128,22 @@ class ZmqClient:
                 break
 
             try:
-                result = json.dumps(data.to_dict()).encode('utf-8')
+                start_send = time.time()
+                if self.aggregate_send:
+                    result = self.pack_aggregated_data(data)
+                else:
+                    result = msgpack.packb([response.to_dict() for response in data])
                 self.router.send_multipart([self.req_dict[req_id], b'', result])
+                llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
             except Exception as e:
                 llm_logger.error(f"Send result to zmq client failed: {e}")
 
-            if data.finished:
+            if data[-1].finished:
                 with self.mutex:
-                    self.req_dict.pop(data.request_id, None)
+                    self.req_dict.pop(req_id, None)
+                llm_logger.info(f"send_multipart finished, req_id: {req_id}")
 
     def receive_json_once(self, block=False):
         """

View File

@@ -505,8 +505,6 @@ class TokenProcessor(object):
                 result.outputs.token_ids.append(token_id)
                 if token_id in task.eos_token_ids or is_prefill or recovery_stop:
                     result.finished = True
-                    result.prompt = task.prompt
-                    result.prompt_token_ids = task.prompt_token_ids
                     if recovery_stop:
                         result.error_msg = "Recover is not supported, the result is incomplete!"
                     llm_logger.info(

View File

@@ -29,9 +29,11 @@ triton==3.3
 use-triton-in-paddle
 crcmod
 fastsafetensors==0.1.14
+msgpack
 opentelemetry-api>=1.24.0
 opentelemetry-sdk>=1.24.0
 opentelemetry-instrumentation-redis
 opentelemetry-instrumentation-mysql
 opentelemetry-distro
 opentelemetry-exporter-otlp

View File

@@ -27,3 +27,4 @@ moviepy
 use-triton-in-paddle
 crcmod
 fastsafetensors==0.1.14
+msgpack

View File

@@ -27,3 +27,4 @@ moviepy
 use-triton-in-paddle
 crcmod
 fastsafetensors==0.1.14
+msgpack