Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
[Bug fix] Test td cache messager (#3242)
* support disable cache task in decode node
* fix bugs
* Update engine.py
* Update expert_service.py
* Update splitwise_connector.py
* Optimize log for debug
* Optimize log for debug
* fix bug

---------

Co-authored-by: ltd0924 <ltd0924@sina.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
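What the change does, in short: with FD_ENABLE_CACHE_TASK set to anything other than "1", the prefill node no longer assumes the decode node has pre-allocated KV cache blocks. Prefill waits for the decode node's cache_sync ack ("finished", or an error such as "Not enough resources") before proceeding, and surfaces failures as a RequestOutput with error_code=500. Below is a minimal, self-contained sketch of that handshake; it borrows the names current_request_ids and check_decode_allocated from the hunks below for illustration only and is not the FastDeploy implementation:

    import time

    # Pretend FD_ENABLE_CACHE_TASK != "1", i.e. the decode-side cache task is disabled.
    ENABLE_DECODE_CACHE_TASK = False

    current_request_ids = {}  # request_id -> "init" | "finished" | <error message>

    def on_cache_sync(reply):
        # The decode node acks "finished" on success, or carries an error_msg
        # (e.g. "Not enough resources") when its allocation failed.
        current_request_ids[reply["request_id"]] = reply.get("error_msg", "finished")

    def check_decode_allocated(request_id):
        if ENABLE_DECODE_CACHE_TASK:
            return True, ""  # legacy path: no handshake needed
        while current_request_ids[request_id] == "init":
            time.sleep(0.001)  # block until the decode node's ack arrives
        msg = current_request_ids.pop(request_id)
        return msg == "finished", "" if msg == "finished" else msg

    current_request_ids["req-0"] = "init"  # written when the task is dispatched
    on_cache_sync({"request_id": "req-0", "error_msg": "Not enough resources"})
    print(check_decode_allocated("req-0"))  # (False, 'Not enough resources')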
cache_messager.py
@@ -17,8 +17,9 @@
 import argparse
 import json
 import math
-import time
 import threading
+import time
 
 import numpy as np
 import paddle
 
@@ -196,7 +197,9 @@ class CacheMessager:
 
         self.gpu_id = gpu_id
         self.cache_info = dict()
-        self.rank_id = self.rank + local_data_parallel_id * self.nranks  # align with engine worker rank (paddle.distributed.launch)
+        self.rank_id = (
+            self.rank + local_data_parallel_id * self.nranks
+        )  # align with engine worker rank (paddle.distributed.launch)
 
         connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
         connect_rdma_thread.daemon = True
 
@@ -284,7 +287,7 @@ class CacheMessager:
             if not self.cache_info:
                 time.sleep(0.001)
                 continue
-            logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
+            logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
             for req_id, item in list(self.cache_info.items()):
                 if "status" not in item:
                     continue
 
@@ -364,7 +367,7 @@ class CacheMessager:
 
         except Exception as e:
             logger.info(f"prefill layerwise send cache thread has exception: {e}")
 
     def _handle_connect_task(self):
         while True:
             try:
 
@@ -465,7 +468,8 @@ def main():
 if __name__ == "__main__":
 
     args = parse_args()
-    logger = get_logger("cache_messager", "cache_messager.log")
+    rank_id = args.rank + args.local_data_parallel_id * args.mp_num
+    logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
 
     logger.info("create cache messager...")
     logger.info(f"{args}")
 
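Giving each messager its own log keeps concurrently launched ranks from clobbering a shared cache_messager.log. A worked example with hypothetical values (args.rank = 1, args.local_data_parallel_id = 2, args.mp_num = 4):

    rank_id = 1 + 2 * 4  # args.rank + args.local_data_parallel_id * args.mp_num = 9
    print(f"cache_messager_rank{rank_id}.log")  # -> cache_messager_rank9.log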
engine.py
@@ -113,6 +113,8 @@ class LLMEngine:
 
         self.start_queue_service()
 
+        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
+
         if envs.ENABLE_V1_KVCACHE_SCHEDULER:
             self.resource_manager = ResourceManagerV1(
                 cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
 
@@ -630,11 +632,15 @@ class LLMEngine:
                     if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                         self.insert_tasks([task])
                     else:
+                        if not self.enable_decode_cache_task:
+                            task.error_msg = "Not enough resources"
                         new_waiting.append(task)
 
                 if new_waiting:
-                    self.waiting_requests.extend(new_waiting)
-                    llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                    if not self.enable_decode_cache_task:
+                        self.split_connector.send_cache_infos(new_waiting, -1)
+                    else:
+                        self.waiting_requests.extend(new_waiting)
+                        llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
 
             else:
                 time.sleep(0.001)
 
@@ -805,6 +811,22 @@ class LLMEngine:
 
         for task in tasks:
             start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
+            if self.cfg.splitwise_role != "mixed":
+                status, msg = self.split_connector.check_decode_allocated(task)
+                if not status:
+                    llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
+                    self.scheduler.put_results(
+                        [
+                            RequestOutput(
+                                request_id=task.request_id,
+                                finished=True,
+                                error_code=500,
+                                error_msg=msg,
+                            )
+                        ]
+                    )
+                    tasks.remove(task)
+                    continue
             if task.sampling_params.bad_words is not None:
                 task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
 
@@ -1020,7 +1042,6 @@ class LLMEngine:
         except Exception as e:
             print(f"Error extracting sub services: {e}")
 
-
         for worker_queue in self.engine_worker_queue_server:
             worker_queue.cleanup()
         if hasattr(self, "send_response_server") and self.send_response_server is not None:
 
expert_service.py
@@ -26,6 +26,7 @@ from collections import deque
 
 import numpy as np
 
+from fastdeploy.engine.request import RequestOutput
 from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.inter_communicator import EngineWorkerQueue
 from fastdeploy.metrics.metrics import main_process_metrics
 
@@ -34,6 +35,7 @@ from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
 from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger
 
 
 class ExpertService:
     """
     Engine class responsible for managing the Large Language Model (LLM) operations.
 
@@ -146,7 +148,7 @@ class ExpertService:
 
         # Start TokenProcessor thread
         os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
+        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK
         self.token_processor.run()
 
         self.cfg.init_cache_info()
 
@@ -262,11 +264,15 @@ class ExpertService:
                     if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                         self.insert_tasks([task])
                     else:
+                        if not self.enable_decode_cache_task:
+                            task.error_msg = "Not enough resources"
                         new_waiting.append(task)
 
                 if new_waiting:
-                    self.waiting_requests.extend(new_waiting)
-                    self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                    if not self.enable_decode_cache_task:
+                        self.split_connector.send_cache_infos(new_waiting, -1)
+                    else:
+                        self.waiting_requests.extend(new_waiting)
+                        self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
 
             else:
                 time.sleep(0.001)
 
@@ -310,8 +316,24 @@ class ExpertService:
         if not isinstance(tasks, list):
             tasks = [tasks]
 
-        for item in tasks:
-            item.schedule_start_time = time.time()
+        for task in tasks:
+            if self.cfg.splitwise_role != "mixed":
+                status, msg = self.split_connector.check_decode_allocated(task)
+                if not status:
+                    self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
+                    self.scheduler.put_results(
+                        [
+                            RequestOutput(
+                                request_id=task.request_id,
+                                finished=True,
+                                error_code=500,
+                                error_msg=msg,
+                            )
+                        ]
+                    )
+                    tasks.remove(task)
+                    continue
+            task.schedule_start_time = time.time()
 
         available_batch = np.sum(self.resource_manager.stop_flags)
         if len(tasks) > available_batch:
 
envs.py
@@ -90,6 +90,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
     # Whether to use PLUGINS.
     "FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","),
+    # Whether to enable cache task in decode node
+    "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),
 }
 
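The toggle is read once at engine startup and defaults to "1", which preserves the existing behavior. To opt into the new allocation handshake, set it before launching; a minimal sketch (engine.py and splitwise_connector.py compare the raw string against "1"):

    import os

    # Any value other than "1" disables the decode-side cache task.
    os.environ["FD_ENABLE_CACHE_TASK"] = "0"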
splitwise_connector.py
@@ -26,8 +26,6 @@ from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput
 from fastdeploy.inter_communicator import EngineWorkerQueue
 from fastdeploy.utils import get_logger
 
-logger = get_logger("splitwise_connector", "splitwise_connector.log")
-
 
 class SplitwiseConnector:
     """
 
@@ -45,6 +43,12 @@ class SplitwiseConnector:
             resource_manager (object): Resource manager object.
         """
         self.cfg = cfg
+        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
+            self.logger = get_logger(
+                "splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
+            )
+        else:
+            self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
         self.scheduler = scheduler
         self.engine_worker_queue = worker_queue
         self.resource_manager = resource_manager
 
@@ -52,6 +56,7 @@
         self.temp_cache_info = dict()
         self.current_request_ids = dict()
         self.splitwise_queue = splitwise_queue
+        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
 
         if self.cfg.cache_config.pd_comm_port is not None:
             self.zmq_ctx = zmq.Context()
 
@@ -70,7 +75,7 @@
             self.router_socket.setsockopt(zmq.SNDHWM, 1000)
             self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
             self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
-            logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
+            self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port[0]}")
 
             self.poller = zmq.Poller()
             self.poller.register(self.router_socket, zmq.POLLIN)
 
@@ -89,16 +94,16 @@
                 if not socks:
                     continue
                 else:
-                    logger.debug(f"receive {socks}")
+                    self.logger.debug(f"receive {socks}")
 
                 frames = self.router_socket.recv_multipart()
-                logger.debug(f"frames: {frames}")
+                self.logger.debug(f"frames: {frames}")
                 message = frames[-1]
                 self.io_executor.submit(self._process_message, message)
                 time.sleep(0.001)
 
             except Exception as e:
-                logger.error(f"Receiver error: {e}")
+                self.logger.error(f"Receiver error: {e}")
                 time.sleep(1)
 
     def _get_push_socket(self, addr):
 
@@ -110,7 +115,7 @@
             return sock
 
         try:
-            logger.info(f"Establishing new connection to {addr}")
+            self.logger.info(f"Establishing new connection to {addr}")
             sock = self.zmq_ctx.socket(zmq.DEALER)
 
             # Set connection parameters
 
@@ -129,7 +134,7 @@
             return sock
 
         except zmq.ZMQError as e:
-            logger.error(f"Connection to {addr} failed: {e}")
+            self.logger.error(f"Connection to {addr} failed: {e}")
 
             raise ConnectionError(f"Failed to connect to {addr}") from e
 
@@ -138,7 +143,7 @@
             return
 
         try:
-            logger.info(f"Sent {msg_type} to {addr}")
+            self.logger.info(f"Sent {msg_type} to {addr}")
             message = self._serialize_message(msg_type, payload)
 
             try:
 
@@ -146,18 +151,18 @@
                 sock = self._get_push_socket(addr)
                 sock.send_multipart([b"", message])
 
-                logger.info(f"Sent {msg_type} to {addr}")
+                self.logger.info(f"Sent {msg_type} to {addr}")
 
             except ConnectionError:
-                logger.warning(f"Connection to {addr} not established")
+                self.logger.warning(f"Connection to {addr} not established")
             except zmq.Again:
-                logger.warning(f"Send queue full for {addr}")
+                self.logger.warning(f"Send queue full for {addr}")
             except Exception as e:
-                logger.error(f"Send to {addr} failed: {e}")
+                self.logger.error(f"Send to {addr} failed: {e}")
                 self._close_connection(addr)
 
         except Exception as e:
-            logger.error(f"Message preparation failed: {e}")
+            self.logger.error(f"Message preparation failed: {e}")
 
     def _close_connection(self, addr):
         """
 
@@ -262,7 +267,7 @@
                     f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
                     + f"{task.disaggregate_info['cache_info']['rdma']['port']}"
                 )
-                logger.info(f"send splitwise tasks to port {addr} decode")
+                self.logger.info(f"send splitwise tasks to port {addr} decode")
                 self.current_request_ids[task.request_id] = "init"
                 decode_diagg = task.disaggregate_info["cache_info"]
                 task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
 
@@ -290,7 +295,7 @@
             self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
             for task in tasks:
                 task.disaggregate_info["cache_info"]["ipc"]["port"] = port
-            logger.info(f"send splitwise tasks to port {port} decode")
+            self.logger.info(f"send splitwise tasks to port {port} decode")
             current_port = port
         return current_port
 
@@ -300,7 +305,7 @@
         """
         if not isinstance(tasks_list, list):
             tasks_list = [tasks_list]
-        logger.info("send first token to port decode")
+        self.logger.info("send first token to port decode")
         if prefill_msg["transfer_protocol"] == "ipc":
             port = prefill_msg["cache_info"]["ipc"]["port"]
             if port not in self.connect_innode_instances:
 
@@ -308,7 +313,7 @@
             self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
         else:
             node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
-            logger.info(f"send first token to port {node} decode")
+            self.logger.info(f"send first token to port {node} decode")
             self._send_message(node, "decode", tasks_list)
 
     def create_connection(self, port):
 
@@ -324,6 +329,22 @@
                 client_id=0,
             )
 
+    def check_decode_allocated(self, task):
+        if task.disaggregate_info is None:
+            return True, ""
+        if self.enable_decode_cache_task:
+            return True, ""
+        if task.disaggregate_info["role"] != "prefill":
+            return True, ""
+        while self.current_request_ids[task.request_id] == "init":
+            time.sleep(0.001)
+        msg = self.current_request_ids[task.request_id]
+        del self.current_request_ids[task.request_id]
+        if msg == "finished":
+            return True, ""
+        self.logger.error(f"Receive_decode_allocated error: {msg}")
+        return False, msg
+
     def send_cache_infos(self, tasks, current_id):
         """
         Send cache information to specific port.
 
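check_decode_allocated busy-waits on the "init" sentinel that was written when the task was dispatched (see the @@ -262,7 +267,7 @@ hunk above); the ZMQ receiver thread flips it once the decode node's cache_sync reply is processed. A toy cross-thread harness, with a hypothetical Connector stand-in and made-up timing, to show the hand-off:

    import threading
    import time

    class Connector:  # stand-in for SplitwiseConnector; handshake state only
        def __init__(self):
            self.current_request_ids = {"req-0": "init"}  # written at dispatch time

        def check_decode_allocated(self, request_id):
            while self.current_request_ids[request_id] == "init":
                time.sleep(0.001)  # poll until the receiver thread posts a status
            msg = self.current_request_ids.pop(request_id)
            return msg == "finished", "" if msg == "finished" else msg

    c = Connector()
    # Simulate _process_message handling a successful cache_sync reply after 10 ms.
    threading.Timer(0.01, c.current_request_ids.update, [{"req-0": "finished"}]).start()
    print(c.check_decode_allocated("req-0"))  # (True, '')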
@@ -340,15 +361,21 @@
         for i in range(len(tasks)):
             if tasks[i].disaggregate_info is None:
                 continue
-            logger.info(f"{tasks[i].disaggregate_info}")
+            self.logger.info(f"{tasks[i].disaggregate_info}")
             if tasks[i].disaggregate_info["role"] == "decode":
                 if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
-                    cache_info = {
-                        "request_id": tasks[i].request_id,
-                        "device_ids": self.cfg.device_ids.split(","),
-                        "transfer_protocol": "ipc",
-                        "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
-                    }
+                    if tasks[i].get("error_msg", None) is not None:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "error_msg": tasks[i].get("error_msg"),
+                        }
+                    else:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "device_ids": self.cfg.device_ids.split(","),
+                            "transfer_protocol": "ipc",
+                            "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
+                        }
                     if tasks[i].disaggregate_info["cache_info"]["ipc"]["port"] not in temp_cache_info:
                         temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]] = []
                     temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]].append(cache_info)
 
@@ -357,14 +384,20 @@
                         f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:"
                         + f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}"
                     )
-                    cache_info = {
-                        "request_id": tasks[i].request_id,
-                        "device_ids": self.cfg.device_ids.split(","),
-                        "ip": self.cfg.host_ip,
-                        "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
-                        "transfer_protocol": "rdma",
-                        "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
-                    }
+                    if tasks[i].get("error_msg", None) is not None:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "error_msg": tasks[i].get("error_msg"),
+                        }
+                    else:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "device_ids": self.cfg.device_ids.split(","),
+                            "ip": self.cfg.host_ip,
+                            "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
+                            "transfer_protocol": "rdma",
+                            "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
+                        }
                     if addr not in temp_cache_info:
                         temp_cache_info[addr] = []
 
@@ -391,7 +424,7 @@
         else:
             if len(temp_cache_info):
                 for k, v in temp_cache_info.items():
-                    logger.info(f"{k} {v}")
+                    self.logger.info(f"{k} {v}")
                     if ":" in str(k):
                         self._send_message(k, "cache_sync", v)
                     else:
 
@@ -408,7 +441,7 @@
             payload = [output.to_dict() for output in payload]
 
         req_ids = [task["request_id"] for task in payload]
-        logger.info(f"send message {msg_type} {req_ids}")
+        self.logger.info(f"send message {msg_type} {req_ids}")
 
         json_data = msgpack.packb({"type": msg_type, "payload": payload})
 
@@ -419,7 +452,7 @@
         # JSON deserialization
         message = msgpack.unpackb(data)
         req_ids = [task["request_id"] for task in message["payload"]]
-        logger.info(f"send message {message['type']} {req_ids}")
+        self.logger.info(f"recv message type {message['type']} for {req_ids}")
         return message["type"], message["payload"]
 
     def _process_message(self, message: bytes):
 
@@ -428,7 +461,7 @@
         """
         try:
             msg_type, payload = self._deserialize_message(message)
-            logger.info(f"{msg_type}")
+            self.logger.info(f"{msg_type}")
 
             if msg_type == "prefill":
                 self._handle_prefill(payload)
 
@@ -436,11 +469,16 @@
                 self._handle_decode(payload)
             elif msg_type == "cache_sync":
                 for task in payload:
-                    del self.current_request_ids[task["request_id"]]
-                self.engine_worker_queue.put_cache_info(payload)
+                    self.logger.info(f"cache_sync task: {task}")
+                    current_status = task.get("error_msg", "finished")
+                    self.current_request_ids[task["request_id"]] = current_status
+                    if self.enable_decode_cache_task:
+                        del self.current_request_ids[task["request_id"]]
+                if current_status == "finished":
+                    self.engine_worker_queue.put_cache_info(payload)
 
         except Exception as e:
-            logger.error(f"Message processing failed: {e}")
+            self.logger.error(f"Message processing failed: {e}")
 
     def _handle_prefill(self, tasks):
         """
 
@@ -450,7 +488,7 @@
         tasks_data = [Request.from_dict(task) for task in tasks]
         req_ids = [task["request_id"] for task in tasks]
         self.splitwise_queue.append(("decode", tasks_data))
-        logger.debug(f"{req_ids} received prefill data")
+        self.logger.info(f"{req_ids} received prefill data")
 
     def _handle_decode(self, payload):
         """
 
@@ -471,4 +509,4 @@
         )
         req_ids = [task["request_id"] for task in payload]
         self.splitwise_queue.append(("decode", tasks))
-        logger.debug(f"{req_ids} received decode data")
+        self.logger.info(f"{req_ids} received decode data")