mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 09:07:10 +08:00
[Bug fix] Fix bug for d blocks not enough (#3479)
* Support batched tokens for EP * Support batched tokens for EP * Support batched tokens for EP * Support batched tokens for EP * Support batched tokens for EP and fix bug * Support batched tokens for EP and fix bug * Support batched tokens for EP and fix bug * Support batched tokens for EP and fix bug * Fix bug for memory allocation * Fix bug for D blocks not enough * fix bug when d blocks not enough * fix bug when d blocks not enough * fix cache message recycle step * fix cache message recycle step * Fix step_idx recycle
This commit is contained in:
@@ -18,6 +18,7 @@ import os
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
import msgpack
|
||||
import zmq
|
||||
@@ -32,7 +33,8 @@ class ZmqServerBase(ABC):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
self.cached_results = defaultdict(list)
|
||||
self.response_token_lock = threading.Lock()
|
||||
|
||||
@abstractmethod
|
||||
def _create_socket(self):
|
||||
@@ -89,6 +91,21 @@ class ZmqServerBase(ABC):
|
||||
llm_logger.warning(f"{e}")
|
||||
return str(e), None
|
||||
|
||||
def recv_result_handle(self):
|
||||
while True:
|
||||
try:
|
||||
with self.response_token_lock:
|
||||
client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
|
||||
req_id_str = request_id.decode("utf-8")
|
||||
with self.mutex:
|
||||
self.req_dict[req_id_str] = client
|
||||
except zmq.Again:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
except Exception as e:
|
||||
llm_logger.error(f"recv_result_handle get unknown exception: {e}")
|
||||
continue
|
||||
|
||||
def send_response(self, req_id, data):
|
||||
"""
|
||||
Send generated token result to client.
|
||||
@@ -96,36 +113,46 @@ class ZmqServerBase(ABC):
|
||||
self._ensure_socket()
|
||||
if self.socket is None:
|
||||
raise RuntimeError("Router socket not created. Call create_router() first.")
|
||||
|
||||
while self.running:
|
||||
with self.mutex:
|
||||
if req_id not in self.req_dict:
|
||||
try:
|
||||
client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
|
||||
req_id_str = request_id.decode("utf-8")
|
||||
self.req_dict[req_id_str] = client
|
||||
except zmq.Again:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
try:
|
||||
start_send = time.time()
|
||||
if self.aggregate_send:
|
||||
result = self.pack_aggregated_data(data)
|
||||
new_data = []
|
||||
has_result_handle = False
|
||||
with self.mutex:
|
||||
if req_id not in self.req_dict:
|
||||
self.cached_results[req_id].append(data)
|
||||
else:
|
||||
result = msgpack.packb([response.to_dict() for response in data])
|
||||
self.socket.send_multipart([self.req_dict[req_id], b"", result])
|
||||
llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
|
||||
has_result_handle = True
|
||||
if req_id in self.cached_results:
|
||||
for history_data in self.cached_results[req_id]:
|
||||
new_data.extend(history_data)
|
||||
llm_logger.info(
|
||||
f"get request {req_id} result handle after cached result, total cached length {len(self.cached_results[req_id])}"
|
||||
)
|
||||
del self.cached_results[req_id]
|
||||
if has_result_handle:
|
||||
try:
|
||||
new_data.extend(data)
|
||||
start_send = time.time()
|
||||
if self.aggregate_send:
|
||||
result = self.pack_aggregated_data(new_data)
|
||||
else:
|
||||
result = msgpack.packb([response.to_dict() for response in new_data])
|
||||
with self.response_token_lock:
|
||||
self.socket.send_multipart([self.req_dict[req_id], b"", result])
|
||||
llm_logger.debug(
|
||||
f"send_multipart result: {req_id} len {len(new_data)} elapse: {time.time()-start_send}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Send result to zmq client failed: {e}")
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Send result to zmq client failed: {e}")
|
||||
|
||||
if data[-1].finished:
|
||||
with self.mutex:
|
||||
if req_id not in self.req_dict:
|
||||
llm_logger.warning(f"req_id {req_id} finished but no result handle, drop it")
|
||||
if req_id in self.cached_results:
|
||||
del self.cached_results[req_id]
|
||||
else:
|
||||
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
|
||||
self.req_dict.pop(req_id, None)
|
||||
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
|
||||
|
||||
@abstractmethod
|
||||
def close(self):
|
||||
@@ -143,6 +170,7 @@ class ZmqIpcServer(ZmqServerBase):
|
||||
def __init__(self, name, mode):
|
||||
self.name = name
|
||||
self.mode = mode
|
||||
self.cached_results = defaultdict(list)
|
||||
if mode == zmq.PULL:
|
||||
self.file_name = f"/dev/shm/{name}.socket"
|
||||
elif mode == zmq.ROUTER:
|
||||
@@ -150,6 +178,7 @@ class ZmqIpcServer(ZmqServerBase):
|
||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||
self.mutex = threading.Lock()
|
||||
self.response_token_lock = threading.Lock()
|
||||
self.req_dict = dict()
|
||||
self.running = True
|
||||
self.context = zmq.Context()
|
||||
@@ -201,6 +230,7 @@ class ZmqTcpServer(ZmqServerBase):
|
||||
def __init__(self, port, mode):
|
||||
self.mode = mode
|
||||
self.port = port
|
||||
self.cached_results = defaultdict(list)
|
||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||
|
||||
@@ -209,6 +239,8 @@ class ZmqTcpServer(ZmqServerBase):
|
||||
self.running = True
|
||||
self.context = zmq.Context()
|
||||
self._create_socket()
|
||||
self.mutex = threading.Lock()
|
||||
self.response_token_lock = threading.Lock()
|
||||
|
||||
def _create_socket(self):
|
||||
"""create and return a ZeroMQ socket."""
|
||||
|
Reference in New Issue
Block a user