Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[Optimize] Improve perf for fd response token with internal adapter (#4947)

* [Optimize] Improve perf for fd response token with internal adapter
* [Optimize] Improve perf for fd response token with internal adapter
* fix
* fix

In short: with FD_ENABLE_INTERNAL_ADAPTER=1 the engine now returns response tokens once per engine step as a single batched ZMQ message instead of one message per request, gives each data-parallel rank its own response port (the new FD_ZMQ_*_SERVER_PORTS variables) and its own result queue, and threads ic_req_data / prompt_token_ids_len through Request and RequestOutput so adapter metadata survives serialization.
@@ -721,7 +721,7 @@ class EngineService:
             self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
             self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
             self.internal_adapter = InternalAdapter(
-                cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node
+                cfg=self.cfg, engine=self, dp_rank=self.cfg.parallel_config.local_data_parallel_id
             )
         else:
             self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
@@ -802,7 +802,10 @@ class EngineService:
                 )
                 # Since the request is not in scheduler
                 # Send result by zmq directly
-                self.send_response_server.send_response(request_id, [error_result])
+                if envs.FD_ENABLE_INTERNAL_ADAPTER:
+                    self.send_response_server.send_response(None, [[error_result]])
+                else:
+                    self.send_response_server.send_response(request_id, [error_result])
         except Exception as e:
             self.llm_logger.error(
                 f"Error happened while receiving new request from zmq, details={e}, "
@@ -819,8 +822,11 @@ class EngineService:
                 if len(results) == 0:
                     time.sleep(0.005)
                     continue
-                for request_id, contents in results.items():
-                    self.send_response_server.send_response(request_id, contents)
+                if envs.FD_ENABLE_INTERNAL_ADAPTER:
+                    self.send_response_server.send_response(None, results)
+                else:
+                    for request_id, contents in results.items():
+                        self.send_response_server.send_response(request_id, contents)

             except Exception as e:
                 self.llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
@@ -890,6 +896,8 @@ class EngineService:
                         )
                         del self.resource_manager.requests[task.request_id]
                         del self.resource_manager.req_dict[task.request_id]
+                        task.finished = True
+                        self.scheduler.put_results([task])
                         continue
                     if task.error_code != 200:
                         cur_task = self.resource_manager.requests[task.request_id]
@@ -904,6 +912,7 @@ class EngineService:
                         )
                         continue
                     self.token_processor.tokens_counter[task.request_id] = 1
+                    self.scheduler.put_results([task])
                     self.resource_manager.insert_task_for_decoding(task)

                 else:
@@ -162,6 +162,10 @@ class LLMEngine:
         if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
             self.launched_cache_manager_signal.value[0] = 1

+        if self.cfg.scheduler_config.splitwise_role != "mixed" and envs.FD_ENABLE_INTERNAL_ADAPTER:
+            envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT = envs.FD_ZMQ_RECV_REQUEST_SERVER_PORTS.split(",")[0]
+            envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT = envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORTS.split(",")[0]
+
         if api_server_pid is not None:
             llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}")
             self.engine.start_zmq_service(api_server_pid)
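Note: the plural `*_PORTS` variables hold one comma-separated port per data-parallel rank; each process then narrows the singular variable down to its own entry, index 0 for the main engine here and `local_data_parallel_id` in ExpertService below. A small sketch of that selection, with invented port values:

```python
# Invented example values; the shipped defaults are just "8200"/"8201".
FD_ZMQ_RECV_REQUEST_SERVER_PORTS = "8200,8210,8220,8230"  # one port per DP rank

def pick_port(ports_csv: str, local_data_parallel_id: int) -> str:
    """Select this rank's port from the comma-separated per-rank list."""
    return ports_csv.split(",")[local_data_parallel_id]

assert pick_port(FD_ZMQ_RECV_REQUEST_SERVER_PORTS, 0) == "8200"  # main engine (rank 0)
assert pick_port(FD_ZMQ_RECV_REQUEST_SERVER_PORTS, 2) == "8220"  # expert service, DP rank 2
```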
@@ -646,11 +650,12 @@ class LLMEngine:
             self.engine.scheduler.start(role, host_ip, disaggregate)
         elif self.cfg.scheduler_config.name == "dp":
             request_queues_for_dp_ipc = []
-            result_queue_for_dp_ipc = multiprocessing.Queue()
+            result_queues_for_dp_ipc = []
             for i in range(self.cfg.parallel_config.data_parallel_size):
                 request_queues_for_dp_ipc.append(multiprocessing.Queue())
+                result_queues_for_dp_ipc.append(multiprocessing.Queue())
             self.engine.scheduler.start(
-                self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
+                self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queues_for_dp_ipc
             )

         if not envs.FD_ENABLE_MULTI_API_SERVER:
@@ -683,7 +688,7 @@ class LLMEngine:
                             i,
                             None,
                             request_queues_for_dp_ipc,
-                            result_queue_for_dp_ipc,
+                            result_queues_for_dp_ipc,
                         ),
                     )
                 )
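Note: previously all DP ranks shared one `result_queue_for_dp_ipc`, so any consumer could dequeue another rank's results; with one `multiprocessing.Queue` per rank, each rank reads back exactly its own output via `result_queues[dp_rank]`. A self-contained sketch of the pattern, with made-up payloads:

```python
import multiprocessing

def run_rank(dp_rank: int, result_queues: list) -> None:
    # Each rank writes only to its own queue, so results can never be
    # picked up by (or starve) a different rank's consumer.
    result_queues[dp_rank].put({"rank": dp_rank, "tokens": [dp_rank * 10]})

if __name__ == "__main__":
    data_parallel_size = 4
    result_queues = [multiprocessing.Queue() for _ in range(data_parallel_size)]
    procs = [
        multiprocessing.Process(target=run_rank, args=(i, result_queues))
        for i in range(data_parallel_size)
    ]
    for p in procs:
        p.start()
    for i in range(data_parallel_size):
        print(result_queues[i].get())  # rank i's results, deterministically
    for p in procs:
        p.join()
```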
@@ -27,7 +27,6 @@ import numpy as np

 from fastdeploy.engine.common_engine import EngineService
 from fastdeploy.inter_communicator import IPCSignal
-from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.utils import console_logger, envs, llm_logger

@@ -53,6 +52,13 @@ class ExpertService:
         end_pos = start_pos + self.cfg.parallel_config.tensor_parallel_size
         if cfg.scheduler_config.splitwise_role != "mixed":
             self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
+        if envs.FD_ENABLE_INTERNAL_ADAPTER:
+            envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT = envs.FD_ZMQ_RECV_REQUEST_SERVER_PORTS.split(",")[
+                local_data_parallel_id
+            ]
+            envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT = envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORTS.split(",")[
+                local_data_parallel_id
+            ]
         self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
         llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")
         self.cfg.disaggregate_info = None
@@ -70,11 +76,9 @@ class ExpertService:
             self.engine.scheduler.reset_nodeid(f"{self.engine.scheduler.infer.nodeid}_{local_data_parallel_id!s}")

         self._finalizer = weakref.finalize(self, self._exit_sub_services)
-        if envs.FD_ENABLE_INTERNAL_ADAPTER:
-            self.internal_adapter = InternalAdapter(cfg=self.cfg, engine=self.engine, dp_rank=local_data_parallel_id)

     def start(
-        self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
+        self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queues_for_dp_ipc=None
     ):
         """
         Initializes the engine and starts its sub-services.
@@ -87,14 +91,15 @@ class ExpertService:
         self.engine.start()
         if self.cfg.scheduler_config.name == "dp":
             self.cfg.init_cache_info()
-            assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
-            self.engine.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
+            assert (request_queues_for_dp_ipc is not None) and (result_queues_for_dp_ipc is not None)
+            self.engine.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queues_for_dp_ipc)

         if ipc_signal_suffix is not None:
             self.api_server_pid = ipc_signal_suffix
             self.engine.start_zmq_service(ipc_signal_suffix)
         else:
             ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
+            self.engine.start_zmq_service(self.cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id])

         llm_logger.info(f"start expert service {local_data_parallel_id}")
         if self.cfg.scheduler_config.splitwise_role != "mixed":
@@ -155,7 +160,7 @@


 def start_data_parallel_service(
-    cfg, local_data_parallel_id, ipc_signal_suffix=None, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
+    cfg, local_data_parallel_id, ipc_signal_suffix=None, request_queues_for_dp_ipc=None, result_queues_for_dp_ipc=None
 ):
     """
     Start expert service
@@ -164,7 +169,7 @@ def start_data_parallel_service(

     try:
         expert_service.start(
-            ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
+            ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queues_for_dp_ipc
         )

     def deamon_thread():
@@ -85,6 +85,8 @@ class Request:
         prefill_start_index: int = 0,
         prefill_end_index: int = 0,
         num_computed_tokens: int = 0,
+        # for internal adapter
+        ic_req_data: Optional[dict] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -150,6 +152,7 @@ class Request:
         self.extend_block_tables = []
         # dp
         self.dp_rank = dp_rank
+        self.ic_req_data = ic_req_data

     @classmethod
     def from_dict(cls, d: dict):
@@ -194,6 +197,7 @@ class Request:
             video_end=d.get("video_end", 0),
             audio_end=d.get("audio_end", 0),
             dp_rank=d.get("dp_rank", None),
+            ic_req_data=d.get("ic_req_data", None),
         )

     @property
@@ -244,6 +248,7 @@ class Request:
             "image_end": self.image_end,
             "video_end": self.video_end,
             "audio_end": self.audio_end,
+            "ic_req_data": self.ic_req_data,
         }
         add_params = [
             "guided_json",
@@ -430,6 +435,9 @@ class RequestOutput:
         num_cached_tokens: Optional[int] = 0,
         error_code: Optional[int] = 200,
         error_msg: Optional[str] = None,
+        # for internal adapter
+        ic_req_data: Optional[dict] = None,
+        prompt_token_ids_len: Optional[int] = 0,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -440,6 +448,8 @@ class RequestOutput:
         self.num_cached_tokens = num_cached_tokens
         self.error_code = error_code
         self.error_msg = error_msg
+        self.ic_req_data = ic_req_data
+        self.prompt_token_ids_len = prompt_token_ids_len

         if prompt_token_ids is None:
             self.prompt_token_ids = []
@@ -494,4 +504,6 @@ class RequestOutput:
             "num_cached_tokens": self.num_cached_tokens,
             "error_code": self.error_code,
             "error_msg": self.error_msg,
+            "ic_req_data": self.ic_req_data,
+            "prompt_token_ids_len": self.prompt_token_ids_len,
         }
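Note: `ic_req_data` and `prompt_token_ids_len` ride along in `to_dict()`/`from_dict`, which is what lets the SplitwiseConnector change at the bottom of this diff collapse its hand-built `RequestOutput(...)` into a single `RequestOutput.from_dict(task)`. A toy round-trip, using a dataclass stand-in rather than the real classes:

```python
from dataclasses import dataclass
from typing import Optional

# Toy stand-in for fastdeploy.engine.request.RequestOutput, just to show the
# symmetric to_dict()/from_dict round-trip the new fields rely on.
@dataclass
class ToyRequestOutput:
    request_id: str
    ic_req_data: Optional[dict] = None   # internal-adapter payload
    prompt_token_ids_len: int = 0

    def to_dict(self) -> dict:
        return {
            "request_id": self.request_id,
            "ic_req_data": self.ic_req_data,
            "prompt_token_ids_len": self.prompt_token_ids_len,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "ToyRequestOutput":
        return cls(
            request_id=d["request_id"],
            ic_req_data=d.get("ic_req_data", None),
            prompt_token_ids_len=d.get("prompt_token_ids_len", 0),
        )

out = ToyRequestOutput("req-1", ic_req_data={"trace": "abc"}, prompt_token_ids_len=17)
assert ToyRequestOutput.from_dict(out.to_dict()) == out  # fields survive the wire
```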
@@ -44,7 +44,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Whether to use HuggingFace tokenizer.
     "FD_USE_HF_TOKENIZER": lambda: bool(int(os.getenv("FD_USE_HF_TOKENIZER", "0"))),
     # Set the high watermark (HWM) for receiving data during ZMQ initialization
-    "FD_ZMQ_SNDHWM": lambda: os.getenv("FD_ZMQ_SNDHWM", 64000),
+    "FD_ZMQ_SNDHWM": lambda: os.getenv("FD_ZMQ_SNDHWM", 0),
     # cache kv quant params directory
     "FD_CACHE_PARAMS": lambda: os.getenv("FD_CACHE_PARAMS", "none"),
     # Set attention backend. "NATIVE_ATTN", "APPEND_ATTN"
@@ -107,6 +107,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"),
     # LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1
     "FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
+    # LLMEngine recieve requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1
+    "FD_ZMQ_RECV_REQUEST_SERVER_PORTS": os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORTS", "8200"),
+    # LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1
+    "FD_ZMQ_SEND_RESPONSE_SERVER_PORTS": os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORTS", "8201"),
     # LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
     "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
     # Whether to enable cache task in decode node
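Note: in ZeroMQ a high-water mark of 0 means "no limit", so changing the `FD_ZMQ_SNDHWM` default from 64000 to 0 lifts the cap on the outbound queue that the batched per-step messages could otherwise hit. A minimal pyzmq sketch (the endpoint is illustrative):

```python
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.ROUTER)
sock.setsockopt(zmq.SNDHWM, 0)   # 0 = unlimited outbound buffering (ZMQ convention)
sock.bind("tcp://*:8201")        # illustrative; matches the default response port
print(sock.getsockopt(zmq.SNDHWM))  # -> 0
sock.close()
ctx.term()
```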
@@ -35,6 +35,9 @@ class ZmqServerBase(ABC):
     def __init__(self):
         self.cached_results = defaultdict(list)
         self.response_token_lock = threading.Lock()
+        self.response_handle_per_step = None
+        self.response_handle_name_per_step = None
+        self.batch_id_per_step = 0

     @abstractmethod
     def _create_socket(self):
@@ -125,16 +128,20 @@
                 with self.response_token_lock:
                     client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
                 req_id_str = request_id.decode("utf-8")
-                need_send_after_finished_inference = False
-                with self.mutex:
-                    self.req_dict[req_id_str] = client
-                    if req_id_str in self.cached_results:
-                        if self.cached_results[req_id_str][-1][-1].finished:
-                            need_send_after_finished_inference = True
-                if need_send_after_finished_inference:
-                    self.send_response(req_id_str, [])
-                    llm_logger.info(f"send_multipart finished, req_id: {req_id_str}")
-                    self.req_dict.pop(req_id_str, None)
+                if envs.FD_ENABLE_INTERNAL_ADAPTER:
+                    with self.mutex:
+                        self.response_handle_per_step = client
+                else:
+                    need_send_after_finished_inference = False
+                    with self.mutex:
+                        self.req_dict[req_id_str] = client
+                        if req_id_str in self.cached_results:
+                            if self.cached_results[req_id_str][-1][-1].finished:
+                                need_send_after_finished_inference = True
+                    if need_send_after_finished_inference:
+                        self.send_response(req_id_str, [])
+                        llm_logger.info(f"send_multipart finished, req_id: {req_id_str}")
+                        self.req_dict.pop(req_id_str, None)

             except zmq.Again:
                 time.sleep(0.001)
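Note: the per-step branch above only records the client handle; delivery happens in `_send_response_per_step` below, which parks results under the `"data"` key of `cached_results` until a handle exists and then flushes the backlog plus the current step in one go. A toy model of that buffering, without the ZMQ and locking details:

```python
from collections import defaultdict

# Toy model of the cache-then-flush behavior; the real code guards this
# state with self.mutex and ships the result over a ROUTER socket.
cached_results = defaultdict(list)
response_handle = None  # becomes the client identity once a client polls

def send_per_step(step_data):
    if response_handle is None:
        cached_results["data"].extend(step_data)  # no client yet: park the batch
        return None
    pending = cached_results["data"] + step_data  # backlog first, then this step
    cached_results["data"] = []
    return pending  # in the real server: msgpack.packb(...) + send_multipart

print(send_per_step([["t0"]]))           # None -> cached, nothing sent
response_handle = b"client-1"            # a client polled; handle recorded
print(send_per_step([["t1"], ["t2"]]))   # [['t0'], ['t1'], ['t2']] -> backlog flushed
```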
@@ -143,7 +150,39 @@
                 llm_logger.error(f"recv_result_handle get unknown exception: {e}")
                 continue

-    def send_response(self, req_id, data):
+    def _send_response_per_step(self, batch_id, data):
+        """
+        Send generated token result to client.
+        """
+        self._ensure_socket()
+        if self.socket is None:
+            raise RuntimeError("Router socket not created. Call create_router() first.")
+        need_send_data = []
+        with self.mutex:
+            if self.response_handle_per_step is None:
+                self.cached_results["data"].extend(data)
+            else:
+                need_send_data = self.cached_results["data"]
+                self.cached_results["data"] = []
+        if self.response_handle_per_step is not None:
+            try:
+                if data:
+                    need_send_data.extend(data)
+                start_send = time.time()
+                result = msgpack.packb(
+                    [[response.to_dict() for response in send_data_list] for send_data_list in need_send_data]
+                )
+                with self.response_token_lock:
+                    self.socket.send_multipart([self.response_handle_per_step, b"", result])
+                llm_logger.info(
+                    f"send_multipart result: step {self.batch_id_per_step} lens {len(need_send_data)} elapse: {time.time()-start_send}"
+                )
+                self.batch_id_per_step += 1
+
+            except Exception as e:
+                llm_logger.error(f"Send result to zmq client failed: {e}")
+
+    def _send_response_per_query(self, req_id, data):
         """
         Send generated token result to client.
         """
@@ -187,6 +226,12 @@ class ZmqServerBase(ABC):
             llm_logger.info(f"send_multipart finished, req_id: {req_id}")
             self.req_dict.pop(req_id, None)

+    def send_response(self, req_id, data):
+        if envs.FD_ENABLE_INTERNAL_ADAPTER:
+            self._send_response_per_step(req_id, data)
+        else:
+            self._send_response_per_query(req_id, data)
+
     @abstractmethod
     def close(self):
         pass
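Note: on the wire, each per-step message is msgpack, a list of per-request lists of `RequestOutput` dicts, pushed through the ROUTER socket to whichever client identity last polled. A hypothetical DEALER-side consumer, where the endpoint, framing assumptions, and field names are illustrative rather than a documented client API:

```python
import msgpack
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.DEALER)
sock.connect("tcp://127.0.0.1:8201")  # FD_ZMQ_SEND_RESPONSE_SERVER_PORT default

# Poll once so the ROUTER records this socket's identity as the
# per-step response handle (the request-id frame is ignored in this mode).
sock.send_multipart([b"", b"poll"])

while True:
    _, payload = sock.recv_multipart()     # [empty delimiter, msgpack body]
    step_batch = msgpack.unpackb(payload)  # list[list[RequestOutput-as-dict]]
    for per_request in step_batch:
        for out in per_request:
            print(out["request_id"], out.get("finished"))
```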
@@ -201,6 +246,7 @@ class ZmqIpcServer(ZmqServerBase):
     """

     def __init__(self, name, mode):
+        super(ZmqIpcServer, self).__init__()
         self.name = name
         self.mode = mode
         self.cached_results = defaultdict(list)
@@ -261,6 +307,7 @@ class ZmqTcpServer(ZmqServerBase):
     """

     def __init__(self, port, mode):
+        super(ZmqTcpServer, self).__init__()
         self.mode = mode
         self.port = port
         self.cached_results = defaultdict(list)
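Note: the new `super().__init__()` calls are not cosmetic. `ZmqServerBase.__init__` is where the per-step state (`response_handle_per_step`, `batch_id_per_step`, the shared lock) now lives, and a subclass `__init__` that skips it would leave those attributes undefined. A two-class illustration of the failure mode:

```python
class Base:
    def __init__(self):
        self.batch_id_per_step = 0  # state the send path relies on

class Forgets(Base):
    def __init__(self):            # no super().__init__()
        self.mode = "pull"

class Calls(Base):
    def __init__(self):
        super().__init__()
        self.mode = "pull"

print(hasattr(Forgets(), "batch_id_per_step"))  # False -> AttributeError at send time
print(hasattr(Calls(), "batch_id_per_step"))    # True
```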
@@ -232,6 +232,7 @@ class TokenProcessor:
                     ),
                     finished=False,
                     metrics=metrics,
+                    ic_req_data=task.ic_req_data,
                 )

                 if self.tokens_counter[task_id] == 0:
@@ -399,9 +400,15 @@ class TokenProcessor:
             if task_id in self.resource_manager.req_dict:
                 del self.resource_manager.req_dict[task_id]

-            num_blocks_used_by_tasks = sum([len(task.block_tables) if task else 0 for task in self.resource_manager.tasks_list])
-            main_process_metrics.set_value("available_gpu_block_num", self.resource_manager.total_block_number() - num_blocks_used_by_tasks)
-            main_process_metrics.set_value("batch_size", self.resource_manager.max_num_seqs - self.resource_manager.available_batch())
+            num_blocks_used_by_tasks = sum(
+                [len(task.block_tables) if task else 0 for task in self.resource_manager.tasks_list]
+            )
+            main_process_metrics.set_value(
+                "available_gpu_block_num", self.resource_manager.total_block_number() - num_blocks_used_by_tasks
+            )
+            main_process_metrics.set_value(
+                "batch_size", self.resource_manager.max_num_seqs - self.resource_manager.available_batch()
+            )
             main_process_metrics.set_value("available_batch_size", self.resource_manager.available_batch())

             if task_id in self.tokens_counter:
@@ -535,6 +542,8 @@ class TokenProcessor:
                     ),
                     finished=False,
                     metrics=metrics,
+                    ic_req_data=task.ic_req_data,
+                    prompt_token_ids_len=task.prompt_token_ids_len,
                 )
                 if self.tokens_counter[task_id] == 0:
                     if task.messages is not None:
@@ -223,10 +223,10 @@ class DPScheduler:
             splitwise_role,
         )

-    def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue):
+    def start(self, dp_rank: int, request_queues: List[Queue], result_queues: Queue):
         self.dp_rank = dp_rank
         self.request_queues = request_queues
-        self.result_queue = result_queue
+        self.result_queues = result_queues
         self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log")
         self._scheduler.scheduler_logger = self.scheduler_logger
         threading.Thread(target=self._put_requests_to_local).start()
@@ -252,7 +252,7 @@ class DPScheduler:
             results = self._scheduler.get_results()
             if len(results) == 0:
                 continue
-            self.result_queue.put(results)
+            self.result_queues[self.dp_rank].put(results)

     def get_requests(
         self,
@@ -273,4 +273,4 @@ class DPScheduler:
         self._scheduler.put_results(results)

     def get_results(self) -> Dict[str, List[RequestOutput]]:
-        return self.result_queue.get()
+        return self.result_queues[self.dp_rank].get()
@@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple

 from fastdeploy.engine.request import Request, RequestOutput
 from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
-from fastdeploy.utils import scheduler_logger
+from fastdeploy.utils import envs, scheduler_logger


 class LocalScheduler:
@@ -79,6 +79,7 @@ class LocalScheduler:

         self.requests: Dict[str, ScheduledRequest] = dict()
         self.responses: Dict[str, List[ScheduledResponse]] = dict()
+        self.batch_responses_per_step: List[List[ScheduledResponse]] = list()

         self.wait_request_timeout = 10
         self.wait_response_timeout = 0.001
@@ -298,6 +299,7 @@ class LocalScheduler:
             scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")

         with self.mutex:
+            self.batch_responses_per_step.append([response.raw for response in responses])
             for response in responses:
                 if response.request_id not in self.requests:
                     scheduler_logger.warning(f"Scheduler has received a expired response: {[response.request_id]}")
@@ -336,11 +338,15 @@ class LocalScheduler:

         def _get_results():
             responses = self.responses
+            batch_responses_per_step = self.batch_responses_per_step
             self.responses = dict()
-            return responses
+            self.batch_responses_per_step = list()
+            return responses, batch_responses_per_step

         with self.responses_not_empty:
-            responses = self.responses_not_empty.wait_for(_get_results, self.wait_response_timeout)
+            responses, batch_responses_per_step = self.responses_not_empty.wait_for(
+                _get_results, self.wait_response_timeout
+            )

         results = dict()
         for request_id, resps in responses.items():
@@ -353,4 +359,7 @@
             if finished:
                 self._recycle(request_id)
                 scheduler_logger.info(f"Scheduler has pulled a finished response: {[request_id]}")
-        return results
+        if envs.FD_ENABLE_INTERNAL_ADAPTER:
+            return batch_responses_per_step
+        else:
+            return results
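Note the net effect on `get_results()`: its return shape now depends on the flag. Per-query mode keeps the `{request_id: [responses]}` dict; internal-adapter mode returns the accumulated `batch_responses_per_step`, one inner list per engine step, which is exactly what the batched ZMQ send upstream consumes. A shape-only illustration with invented tokens:

```python
# Per-query mode: keyed by request id; step boundaries are not preserved.
results = {
    "req-1": ["tok-1a", "tok-1b"],
    "req-2": ["tok-2a"],
}

# Internal-adapter mode: one inner list per engine step, all requests mixed.
batch_responses_per_step = [
    ["tok-1a", "tok-2a"],  # step N
    ["tok-1b"],            # step N + 1
]
```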
@@ -23,7 +23,7 @@ from typing import Dict
 import zmq

 from fastdeploy import envs
-from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput
+from fastdeploy.engine.request import Request, RequestOutput
 from fastdeploy.inter_communicator import EngineWorkerQueue
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.utils import get_logger
@@ -500,19 +500,5 @@ class SplitwiseConnector:
         """
         tasks = []
         for task in payload:
-            tasks.append(
-                RequestOutput(
-                    request_id=task["request_id"],
-                    outputs=CompletionOutput(
-                        index=task["outputs"]["index"],
-                        send_idx=0,
-                        token_ids=task["outputs"]["token_ids"],
-                        draft_token_ids=task["outputs"]["draft_token_ids"],
-                    ),
-                    finished=True,
-                    num_cached_tokens=task["num_cached_tokens"],
-                    error_code=task["error_code"],
-                    error_msg=task["error_msg"],
-                )
-            )
+            tasks.append(RequestOutput.from_dict(task))
         self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))