mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
[LLM] support send batch data and aggregate data (#2860)
* [LLM] support send batch data and aggregate data * [LLM] fix ci bugs * [LLM] fix ci bugs * [LLM] fix ci bugs * [LLM] fix ci bugs * [LLM] update
This commit is contained in:
@@ -20,6 +20,7 @@ import threading
|
||||
import time
|
||||
|
||||
import zmq
|
||||
import msgpack
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.utils import llm_logger
|
||||
@@ -37,6 +38,7 @@ class ZmqClient:
|
||||
self.router_path = f"/dev/shm/router_{name}.ipc"
|
||||
|
||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||
|
||||
self.mutex = threading.Lock()
|
||||
self.req_dict = dict()
|
||||
@@ -93,6 +95,16 @@ class ZmqClient:
|
||||
"""
|
||||
return self.socket.recv_pyobj()
|
||||
|
||||
def pack_aggregated_data(self, data):
|
||||
"""
|
||||
Aggregate multiple responses into one and send them to the client.
|
||||
"""
|
||||
result = data[0]
|
||||
if len(data) > 1:
|
||||
for response in data[1:]:
|
||||
result.add(response)
|
||||
result = msgpack.packb([result.to_dict()])
|
||||
return result
|
||||
def send_multipart(self, req_id, data):
|
||||
"""
|
||||
Send a multipart message to the router socket.
|
||||
@@ -116,14 +128,22 @@ class ZmqClient:
|
||||
break
|
||||
|
||||
try:
|
||||
result = json.dumps(data.to_dict()).encode('utf-8')
|
||||
start_send = time.time()
|
||||
if self.aggregate_send:
|
||||
result = self.pack_aggregated_data(data)
|
||||
else:
|
||||
result = msgpack.packb([response.to_dict() for response in data])
|
||||
self.router.send_multipart([self.req_dict[req_id], b'', result])
|
||||
llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Send result to zmq client failed: {e}")
|
||||
|
||||
if data.finished:
|
||||
if data[-1].finished:
|
||||
with self.mutex:
|
||||
self.req_dict.pop(data.request_id, None)
|
||||
self.req_dict.pop(req_id, None)
|
||||
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
|
||||
|
||||
|
||||
def receive_json_once(self, block=False):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user