diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py
index 17dc539d1..1c3462913 100644
--- a/fastdeploy/cache_manager/cache_messager.py
+++ b/fastdeploy/cache_manager/cache_messager.py
@@ -17,8 +17,9 @@
 import argparse
 import json
 import math
-import time
 import threading
+import time
+
 import numpy as np
 import paddle
 
@@ -196,7 +197,9 @@ class CacheMessager:
 
         self.gpu_id = gpu_id
         self.cache_info = dict()
-        self.rank_id = self.rank + local_data_parallel_id * self.nranks  # align with engine worker rank (paddle.distributed.launch)
+        self.rank_id = (
+            self.rank + local_data_parallel_id * self.nranks
+        )  # align with engine worker rank (paddle.distributed.launch)
 
         connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
         connect_rdma_thread.daemon = True
@@ -284,7 +287,7 @@
                 if not self.cache_info:
                     time.sleep(0.001)
                     continue
-                logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
+                logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
                 for req_id, item in list(self.cache_info.items()):
                     if "status" not in item:
                         continue
@@ -364,7 +367,7 @@
 
         except Exception as e:
             logger.info(f"prefill layerwise send cache thread has exception: {e}")
-
+
     def _handle_connect_task(self):
         while True:
             try:
@@ -465,7 +468,8 @@ def main():
 
 if __name__ == "__main__":
     args = parse_args()
-    logger = get_logger("cache_messager", "cache_messager.log")
+    rank_id = args.rank + args.local_data_parallel_id * args.mp_num
+    logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
 
     logger.info("create cache messager...")
     logger.info(f"{args}")
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 51f17a454..fbd4005c9 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -113,6 +113,8 @@
 
         self.start_queue_service()
 
+        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
+
         if envs.ENABLE_V1_KVCACHE_SCHEDULER:
             self.resource_manager = ResourceManagerV1(
                 cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
@@ -630,11 +632,15 @@
                         if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                             self.insert_tasks([task])
                         else:
+                            if not self.enable_decode_cache_task:
+                                task.error_msg = "Not enough resources"
                             new_waiting.append(task)
-
                     if new_waiting:
-                        self.waiting_requests.extend(new_waiting)
-                        llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                        if not self.enable_decode_cache_task:
+                            self.split_connector.send_cache_infos(new_waiting, -1)
+                        else:
+                            self.waiting_requests.extend(new_waiting)
+                            llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
 
                 else:
                     time.sleep(0.001)
@@ -805,6 +811,22 @@
 
         for task in tasks:
             start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
+            if self.cfg.splitwise_role != "mixed":
+                status, msg = self.split_connector.check_decode_allocated(task)
+                if not status:
+                    llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
+                    self.scheduler.put_results(
+                        [
+                            RequestOutput(
+                                request_id=task.request_id,
+                                finished=True,
+                                error_code=500,
+                                error_msg=msg,
+                            )
+                        ]
+                    )
+                    tasks.remove(task)
+                    continue
             if task.sampling_params.bad_words is not None:
                 task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
@@ -1020,7 +1042,6 @@
         except Exception as e:
             print(f"Error extracting sub services: {e}")
 
-
         for worker_queue in self.engine_worker_queue_server:
             worker_queue.cleanup()
         if hasattr(self, "send_response_server") and self.send_response_server is not None:
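Note on the log-file rename in cache_messager.py: with several messager processes on one machine, each process now computes the same global rank the engine workers use and writes to its own file. A minimal sketch of the naming scheme (the helper function is ours for illustration; the inputs mirror `args.rank`, `args.local_data_parallel_id`, and `args.mp_num` above):

```python
# Sketch: per-process log name, derived the same way as rank_id above.
def messager_log_name(rank: int, local_data_parallel_id: int, mp_num: int) -> str:
    # global rank = TP rank + DP group index * ranks-per-group
    rank_id = rank + local_data_parallel_id * mp_num
    return f"cache_messager_rank{rank_id}.log"

# e.g. TP rank 2 in DP group 1 with 4 ranks per group -> global rank 6
assert messager_log_name(2, 1, 4) == "cache_messager_rank6.log"
```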
             worker_queue.cleanup()
         if hasattr(self, "send_response_server") and self.send_response_server is not None:
diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py
index 3892deca9..6013efccf 100644
--- a/fastdeploy/engine/expert_service.py
+++ b/fastdeploy/engine/expert_service.py
@@ -26,6 +26,7 @@ from collections import deque
 
 import numpy as np
 
+from fastdeploy.engine.request import RequestOutput
 from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.inter_communicator import EngineWorkerQueue
 from fastdeploy.metrics.metrics import main_process_metrics
@@ -34,6 +35,7 @@
 from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
 from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger
 
+
 class ExpertService:
     """
     Engine class responsible for managing the Large Language Model (LLM) operations.
@@ -146,7 +148,7 @@ class ExpertService:
 
         # Start TokenProcessor thread
         os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
-
+        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
         self.token_processor.run()
 
         self.cfg.init_cache_info()
@@ -262,11 +264,15 @@
                         if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                             self.insert_tasks([task])
                         else:
+                            if not self.enable_decode_cache_task:
+                                task.error_msg = "Not enough resources"
                             new_waiting.append(task)
-
                     if new_waiting:
-                        self.waiting_requests.extend(new_waiting)
-                        self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                        if not self.enable_decode_cache_task:
+                            self.split_connector.send_cache_infos(new_waiting, -1)
+                        else:
+                            self.waiting_requests.extend(new_waiting)
+                            self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
 
                 else:
                     time.sleep(0.001)
@@ -310,8 +316,24 @@
         if not isinstance(tasks, list):
             tasks = [tasks]
 
-        for item in tasks:
-            item.schedule_start_time = time.time()
+        for task in tasks:
+            if self.cfg.splitwise_role != "mixed":
+                status, msg = self.split_connector.check_decode_allocated(task)
+                if not status:
+                    self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
+                    self.scheduler.put_results(
+                        [
+                            RequestOutput(
+                                request_id=task.request_id,
+                                finished=True,
+                                error_code=500,
+                                error_msg=msg,
+                            )
+                        ]
+                    )
+                    tasks.remove(task)
+                    continue
+            task.schedule_start_time = time.time()
 
         available_batch = np.sum(self.resource_manager.stop_flags)
         if len(tasks) > available_batch:
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index dd95de1df..25c4b0f83 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -90,6 +90,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
     # Whether to use PLUGINS.
     "FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","),
+    # Whether to enable cache task in decode node
+    "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),
 }
 
 
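Note on `FD_ENABLE_CACHE_TASK`: the lambda returns the raw environment string, so call sites must compare against `"1"` explicitly (as engine.py does); a bare truthiness check would treat `"0"` as enabled too. A self-contained sketch of the lazy-lambda pattern used in fastdeploy/envs.py (simplified; the assumption here is that the real module resolves these entries through a module-level `__getattr__`):

```python
import os

# Simplified model of the fastdeploy/envs.py pattern above: each variable
# is a lambda, so the value is read on access rather than at import time.
environment_variables = {
    "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),
}

def __getattr__(name):
    # PEP 562 hook, called for attribute access like envs.FD_ENABLE_CACHE_TASK
    # (assumed to mirror how the real envs module dispatches into the dict).
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(name)

# The value is a *string*: "0" is truthy, so compare against "1" explicitly.
enable_decode_cache_task = environment_variables["FD_ENABLE_CACHE_TASK"]() == "1"
print(enable_decode_cache_task)  # True unless FD_ENABLE_CACHE_TASK=0 is set
```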
"FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","), + # Whether to enable cache task in decode node + "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"), } diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index c5ac6534c..619fa1410 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -26,8 +26,6 @@ from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput from fastdeploy.inter_communicator import EngineWorkerQueue from fastdeploy.utils import get_logger -logger = get_logger("splitwise_connector", "splitwise_connector.log") - class SplitwiseConnector: """ @@ -45,6 +43,12 @@ class SplitwiseConnector: resource_manager (object): Resource manager object. """ self.cfg = cfg + if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: + self.logger = get_logger( + "splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log" + ) + else: + self.logger = get_logger("splitwise_connector", "splitwise_connector.log") self.scheduler = scheduler self.engine_worker_queue = worker_queue self.resource_manager = resource_manager @@ -52,6 +56,7 @@ class SplitwiseConnector: self.temp_cache_info = dict() self.current_request_ids = dict() self.splitwise_queue = splitwise_queue + self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1" if self.cfg.cache_config.pd_comm_port is not None: self.zmq_ctx = zmq.Context() @@ -70,7 +75,7 @@ class SplitwiseConnector: self.router_socket.setsockopt(zmq.SNDHWM, 1000) self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1) self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}") - logger.info(f"bind {self.cfg.cache_config.pd_comm_port}") + self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port[0]}") self.poller = zmq.Poller() self.poller.register(self.router_socket, zmq.POLLIN) @@ -89,16 +94,16 @@ class SplitwiseConnector: if not socks: continue else: - logger.debug(f"receive {socks}") + self.logger.debug(f"receive {socks}") frames = self.router_socket.recv_multipart() - logger.debug(f"frames: {frames}") + self.logger.debug(f"frames: {frames}") message = frames[-1] self.io_executor.submit(self._process_message, message) time.sleep(0.001) except Exception as e: - logger.error(f"Receiver error: {e}") + self.logger.error(f"Receiver error: {e}") time.sleep(1) def _get_push_socket(self, addr): @@ -110,7 +115,7 @@ class SplitwiseConnector: return sock try: - logger.info(f"Establishing new connection to {addr}") + self.logger.info(f"Establishing new connection to {addr}") sock = self.zmq_ctx.socket(zmq.DEALER) # 设置连接参数 @@ -129,7 +134,7 @@ class SplitwiseConnector: return sock except zmq.ZMQError as e: - logger.error(f"Connection to {addr} failed: {e}") + self.logger.error(f"Connection to {addr} failed: {e}") raise ConnectionError(f"Failed to connect to {addr}") from e @@ -138,7 +143,7 @@ class SplitwiseConnector: return try: - logger.info(f"Sent {msg_type} to {addr}") + self.logger.info(f"Sent {msg_type} to {addr}") message = self._serialize_message(msg_type, payload) try: @@ -146,18 +151,18 @@ class SplitwiseConnector: sock = self._get_push_socket(addr) sock.send_multipart([b"", message]) - logger.info(f"Sent {msg_type} to {addr}") + self.logger.info(f"Sent {msg_type} to {addr}") except ConnectionError: - logger.warning(f"Connection to {addr} not established") + 
self.logger.warning(f"Connection to {addr} not established") except zmq.Again: - logger.warning(f"Send queue full for {addr}") + self.logger.warning(f"Send queue full for {addr}") except Exception as e: - logger.error(f"Send to {addr} failed: {e}") + self.logger.error(f"Send to {addr} failed: {e}") self._close_connection(addr) except Exception as e: - logger.error(f"Message preparation failed: {e}") + self.logger.error(f"Message preparation failed: {e}") def _close_connection(self, addr): """ @@ -262,7 +267,7 @@ class SplitwiseConnector: f"{task.disaggregate_info['cache_info']['rdma']['ip']}:" + f"{task.disaggregate_info['cache_info']['rdma']['port']}" ) - logger.info(f"send splitwise tasks to port {addr} decode") + self.logger.info(f"send splitwise tasks to port {addr} decode") self.current_request_ids[task.request_id] = "init" decode_diagg = task.disaggregate_info["cache_info"] task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"] @@ -290,7 +295,7 @@ class SplitwiseConnector: self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks)) for task in tasks: task.disaggregate_info["cache_info"]["ipc"]["port"] = port - logger.info(f"send splitwise tasks to port {port} decode") + self.logger.info(f"send splitwise tasks to port {port} decode") current_port = port return current_port @@ -300,7 +305,7 @@ class SplitwiseConnector: """ if not isinstance(tasks_list, list): tasks_list = [tasks_list] - logger.info("send first token to port decode") + self.logger.info("send first token to port decode") if prefill_msg["transfer_protocol"] == "ipc": port = prefill_msg["cache_info"]["ipc"]["port"] if port not in self.connect_innode_instances: @@ -308,7 +313,7 @@ class SplitwiseConnector: self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list)) else: node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}" - logger.info(f"send first token to port {node} decode") + self.logger.info(f"send first token to port {node} decode") self._send_message(node, "decode", tasks_list) def create_connection(self, port): @@ -324,6 +329,22 @@ class SplitwiseConnector: client_id=0, ) + def check_decode_allocated(self, task): + if task.disaggregate_info is None: + return True, "" + if self.enable_decode_cache_task: + return True, "" + if task.disaggregate_info["role"] != "prefill": + return True, "" + while self.current_request_ids[task.request_id] == "init": + time.sleep(0.001) + msg = self.current_request_ids[task.request_id] + del self.current_request_ids[task.request_id] + if msg == "finished": + return True, "" + self.logger.error(f"Receive_decode_allocated error: {msg}") + return False, msg + def send_cache_infos(self, tasks, current_id): """ Send cache information to specific port. 
@@ -340,15 +361,21 @@ class SplitwiseConnector:
         for i in range(len(tasks)):
             if tasks[i].disaggregate_info is None:
                 continue
-            logger.info(f"{tasks[i].disaggregate_info}")
+            self.logger.info(f"{tasks[i].disaggregate_info}")
             if tasks[i].disaggregate_info["role"] == "decode":
                 if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
-                    cache_info = {
-                        "request_id": tasks[i].request_id,
-                        "device_ids": self.cfg.device_ids.split(","),
-                        "transfer_protocol": "ipc",
-                        "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
-                    }
+                    if tasks[i].get("error_msg", None) is not None:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "error_msg": tasks[i].get("error_msg"),
+                        }
+                    else:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "device_ids": self.cfg.device_ids.split(","),
+                            "transfer_protocol": "ipc",
+                            "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
+                        }
                     if tasks[i].disaggregate_info["cache_info"]["ipc"]["port"] not in temp_cache_info:
                         temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]] = []
                     temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]].append(cache_info)
@@ -357,14 +384,20 @@
                         f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:"
                         + f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}"
                     )
-                    cache_info = {
-                        "request_id": tasks[i].request_id,
-                        "device_ids": self.cfg.device_ids.split(","),
-                        "ip": self.cfg.host_ip,
-                        "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
-                        "transfer_protocol": "rdma",
-                        "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
-                    }
+                    if tasks[i].get("error_msg", None) is not None:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "error_msg": tasks[i].get("error_msg"),
+                        }
+                    else:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "device_ids": self.cfg.device_ids.split(","),
+                            "ip": self.cfg.host_ip,
+                            "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
+                            "transfer_protocol": "rdma",
+                            "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
+                        }
                     if addr not in temp_cache_info:
                         temp_cache_info[addr] = []
                     temp_cache_info[addr].append(cache_info)
@@ -391,7 +424,7 @@
         else:
             if len(temp_cache_info):
                 for k, v in temp_cache_info.items():
-                    logger.info(f"{k} {v}")
+                    self.logger.info(f"{k} {v}")
                     if ":" in str(k):
                         self._send_message(k, "cache_sync", v)
                     else:
@@ -408,7 +441,7 @@
             payload = [output.to_dict() for output in payload]
 
         req_ids = [task["request_id"] for task in payload]
-        logger.info(f"send message {msg_type} {req_ids}")
+        self.logger.info(f"send message {msg_type} {req_ids}")
 
         json_data = msgpack.packb({"type": msg_type, "payload": payload})
@@ -419,7 +452,7 @@
         # JSON deserialization
         message = msgpack.unpackb(data)
         req_ids = [task["request_id"] for task in message["payload"]]
-        logger.info(f"send message {message['type']} {req_ids}")
+        self.logger.info(f"recv message type {message['type']} for {req_ids}")
         return message["type"], message["payload"]
@@ -428,7 +461,7 @@
         """
         try:
             msg_type, payload = self._deserialize_message(message)
-            logger.info(f"{msg_type}")
+            self.logger.info(f"{msg_type}")
 
             if msg_type == "prefill":
                 self._handle_prefill(payload)
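With the error branch above, `send_cache_infos` now emits one of two payload shapes per task: a minimal error record when the decode node must be told that allocation failed, or the usual transfer descriptor. For reference, the RDMA variant side by side (all field values here are illustrative, not taken from the PR):

```python
# Error path: decode could not allocate, so only the id and reason travel.
cache_info_error = {
    "request_id": "req-0",
    "error_msg": "Not enough resources",
}

# Normal RDMA path, mirroring the fields built in send_cache_infos above.
cache_info_ok = {
    "request_id": "req-0",
    "device_ids": ["0", "1"],
    "ip": "10.0.0.2",
    "rdma_ports": ["7791", "7792"],
    "transfer_protocol": "rdma",
    "dest_block_ids": [3, 17, 42],
}

# The cache_sync branch below treats a missing error_msg as success:
assert cache_info_ok.get("error_msg", "finished") == "finished"
assert cache_info_error.get("error_msg", "finished") != "finished"
```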
self.logger.info(f"cache_sync task: {task}") + current_status = task.get("error_msg", "finished") + self.current_request_ids[task["request_id"]] = current_status + if self.enable_decode_cache_task: + del self.current_request_ids[task["request_id"]] + if current_status == "finished": + self.engine_worker_queue.put_cache_info(payload) except Exception as e: - logger.error(f"Message processing failed: {e}") + self.logger.error(f"Message processing failed: {e}") def _handle_prefill(self, tasks): """ @@ -450,7 +488,7 @@ class SplitwiseConnector: tasks_data = [Request.from_dict(task) for task in tasks] req_ids = [task["request_id"] for task in tasks] self.splitwise_queue.append(("decode", tasks_data)) - logger.debug(f"{req_ids} received prefill data") + self.logger.info(f"{req_ids} received prefill data") def _handle_decode(self, payload): """ @@ -471,4 +509,4 @@ class SplitwiseConnector: ) req_ids = [task["request_id"] for task in payload] self.splitwise_queue.append(("decode", tasks)) - logger.debug(f"{req_ids} received decode data") + self.logger.info(f"{req_ids} received decode data")