diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py
index 25a3b50e6..b70e27620 100644
--- a/fastdeploy/cache_manager/cache_messager.py
+++ b/fastdeploy/cache_manager/cache_messager.py
@@ -96,6 +96,19 @@ def parse_args():
     return args


+def get_decode_ip_idx(task):
+    """For compatibility, get decode ip and idx from task"""
+    if "decode_ip" in task:
+        decode_ip = task["decode_ip"]
+    else:
+        decode_ip = task["ip"]
+    if "decode_rdma_ports" in task:
+        decode_rdma_ports = task["decode_rdma_ports"]
+    else:
+        decode_rdma_ports = task["rdma_ports"]
+    return decode_ip, decode_rdma_ports
+
+
 class CacheMessager:
     """
     CacheMessager is used to send the cache data between the engine worker and the cache server.
@@ -282,6 +295,7 @@ class CacheMessager:
                         self.cache_info[info["request_id"]] = current_info
                     else:
                         self.cache_info[info["request_id"]] = info
+
             prefilled_layer_idx = layer_shm_value.value[0]
             prefilled_step_idx = step_shm_value.value[0]
             if prefilled_layer_idx == self.num_layers - 1:
@@ -316,15 +330,18 @@ class CacheMessager:
                         continue
                     current_transfer_protocol = item["transfer_protocol"]
                     if item["transfer_protocol"] == "rdma":
-                        target_ip = item["ip"]
-                        target_id = int(item["rdma_ports"][self.rank])
-                        status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
+                        decode_ip, decode_rdma_ports = get_decode_ip_idx(item)
+                        decode_idx = int(decode_rdma_ports[self.rank])
+                        status = self.messager[current_transfer_protocol].connect(decode_ip, decode_idx)
                         if not status:
-                            logger.error(f"connect to {target_ip}:{target_id} failed")
+                            logger.error(f"connect to {decode_ip}:{decode_idx} failed")
                             item["status"] = "connect error"
                     elif item["transfer_protocol"] == "ipc":
-                        target_ip = "0.0.0.0"
-                        target_id = int(item["device_ids"][self.rank])
+                        decode_ip = "0.0.0.0"
+                        decode_device_ids = (
+                            item["decode_device_ids"] if "decode_device_ids" in item else item["device_ids"]
+                        )
+                        decode_idx = int(decode_device_ids[self.rank])
                     src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
                     dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
                     if item["current_id"] < prefilled_step_idx:
@@ -335,8 +352,8 @@ class CacheMessager:
                     for layer_idx in range(item["layer_idx"], current_layer_idx):
                         tic = time.time()
                         return_code = self.messager[current_transfer_protocol].write_cache(
-                            target_ip,
-                            target_id,
+                            decode_ip,
+                            decode_idx,
                             src_block_ids,
                             dest_block_ids,
                             layer_idx,
@@ -345,7 +362,7 @@ class CacheMessager:
                             item["status"] = "write cache error"
                             logger.error(
                                 f"write cache failed, layer_idx: {layer_idx}, "
-                                f"req_id: {item['request_id']}, dest_ip: {target_ip}"
+                                f"req_id: {item['request_id']}, dest_ip: {decode_ip}"
                             )
                             break

@@ -365,7 +382,7 @@ class CacheMessager:
                         if "error" not in item["status"]:
                             item["status"] = "finished"
                         if item["transfer_protocol"] == "ipc":
-                            self.messager["ipc"].write_block_by_sync(target_id)
+                            self.messager["ipc"].write_block_by_sync(decode_idx)
                         logger.info(f"finish write cache {item['request_id']}")
                         self.engine_worker_queue.finish_send_cache_barrier.wait()
                         self.engine_worker_queue.put_finished_req([[item["request_id"], item["status"]]])
@@ -387,8 +404,9 @@ class CacheMessager:
             self.engine_worker_queue.connect_task_barrier.wait()
             logger.info(f"_handle_connect_task recv task: {task}")
             task_id = task["task_id"]
-            ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
-            status = self.messager["rdma"].connect(ip, rdma_port)
+            decode_ip, decode_rdma_ports = get_decode_ip_idx(task)
+            rdma_port = decode_rdma_ports[self.rank]
+            status = self.messager["rdma"].connect(decode_ip, rdma_port)
             if not status:
                 response = {"task_id": task_id, "success": False}
             else:
@@ -634,6 +652,7 @@ class CacheMessagerV1:
                     end_layer_idx = prefilled_layer_idx
                     if sended_layer_idx == prefilled_layer_idx:  # computation not in next layer
                         time.sleep(0.01)
+
                 for layer_idx in range(start_layer_idx, end_layer_idx + 1):
                     for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
                         engine_index = batch_engine_signals[i][0]
@@ -650,13 +669,13 @@ class CacheMessagerV1:
                         else:
                             current_transfer_protocol = task["transfer_protocol"]
                             if task["transfer_protocol"] == "rdma":
-                                target_ip = task["ip"]
+                                decode_ip, decode_rdma_ports = get_decode_ip_idx(task)
                                 # Default decode_tp_size to prefill tp_size (self.nranks) if not specified
                                 decode_tp_size = task.get("decode_tp_size", self.nranks)
-                                if len(task["rdma_ports"]) == self.nranks:
-                                    target_id = int(task["rdma_ports"][self.rank])
-                                elif len(task["rdma_ports"]) == 1:
-                                    target_id = task["rdma_ports"][0]
+                                if len(decode_rdma_ports) == self.nranks:
+                                    decode_idx = int(decode_rdma_ports[self.rank])
+                                elif len(decode_rdma_ports) == 1:
+                                    decode_idx = decode_rdma_ports[0]
                                 else:
                                     task["status"] = "the tp_size of prefill and decode is mismatch"
                                     continue
@@ -666,21 +685,26 @@ class CacheMessagerV1:
                                 # TODO: use is connected to check if the connection is still alive
                                 logger.debug(
-                                    f"rdma, start connect decode, {target_ip}:{target_id}, "
+                                    f"rdma, start connect decode, {decode_ip}:{decode_idx}, "
                                     f"prefill_tp_size:{self.nranks}, decode_tp_size:{decode_tp_size}"
                                 )
                                 status = self.messager[current_transfer_protocol].connect(
-                                    target_ip, target_id, decode_tp_size
+                                    decode_ip, decode_idx, decode_tp_size
                                 )
                                 if status:
-                                    logger.debug(f"connect to {target_ip}:{target_id} success")
+                                    logger.debug(f"connect to {decode_ip}:{decode_idx} success")
                                 else:
-                                    logger.error(f"connect to {target_ip}:{target_id} failed")
+                                    logger.error(f"connect to {decode_ip}:{decode_idx} failed")
                                     task["status"] = "connection error"
                                     continue
                             elif task["transfer_protocol"] == "ipc":
-                                target_ip = "0.0.0.0"
-                                target_id = int(task["device_ids"][self.rank])
+                                decode_device_ids = (
+                                    task["decode_device_ids"]
+                                    if "decode_device_ids" in task
+                                    else task["device_ids"]
+                                )
+                                decode_ip = "0.0.0.0"
+                                decode_idx = int(decode_device_ids[self.rank])
                             src_block_ids = task["src_block_ids"][block_id_start:block_id_end]
                             dest_block_ids = task["dest_block_ids"][block_id_start:block_id_end]
@@ -688,12 +712,12 @@ class CacheMessagerV1:
                             dest_block_ids = paddle.to_tensor(dest_block_ids, dtype="int32", place="cpu")

                             logger.info(
-                                f"start write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id}, block_id_start {block_id_start} block_id_end {block_id_end}"
+                                f"start write cache for a layer, {req_id}, {layer_idx}, {decode_ip}, {decode_idx}, block_id_start {block_id_start} block_id_end {block_id_end}"
                             )
                             tic = time.time()
                             return_code = self.messager[current_transfer_protocol].write_cache(
-                                target_ip,
-                                target_id,
+                                decode_ip,
+                                decode_idx,
                                 src_block_ids,
                                 dest_block_ids,
                                 layer_idx,
@@ -701,7 +725,7 @@ class CacheMessagerV1:
                             if return_code != 0:
                                 task["status"] = "write cache error"
                                 logger.error(
-                                    f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, dest_ip: {target_ip}, block_id_start {block_id_start} block_id_end {block_id_end}"
+                                    f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, dest_ip: {decode_ip}, block_id_start {block_id_start} block_id_end {block_id_end}"
                                 )
                             tok = time.time()
                             cost_time = tok - tic
@@ -709,7 +733,7 @@ class CacheMessagerV1:
                             avg_time_per_block = cost_time * 1000 / block_num  # ms
                             send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time  # GB/s
                             logger.debug(
-                                f"finish write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id},"
+                                f"finish write cache for a layer, {req_id}, {layer_idx}, {decode_ip}, {decode_idx},"
                                 f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
                                 f"avg_time per block(ms): {round(avg_time_per_block, 5)} block_id_start {block_id_start} block_id_end {block_id_end}"
                             )
@@ -734,8 +758,13 @@ class CacheMessagerV1:
                 task = self.idx_cache_task_dict[engine_idx]
                 if task["status"] == "finished" or ("error" in task["status"]):
                     if task["transfer_protocol"] == "ipc":
-                        target_id = int(task["device_ids"][self.rank])
-                        self.messager["ipc"].write_block_by_sync(target_id)
+                        decode_device_ids = (
+                            task["decode_device_ids"]
+                            if "decode_device_ids" in task
+                            else task["device_ids"]
+                        )
+                        decode_idx = int(decode_device_ids[self.rank])
+                        self.messager["ipc"].write_block_by_sync(decode_idx)
                     self.engine_worker_queue.finish_send_cache_barrier.wait()
                     self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]])
                     logger.info(
@@ -796,18 +825,17 @@ class CacheMessagerV1:
             self.engine_worker_queue.connect_task_barrier.wait()
             logger.info(f"_handle_connect_task recv task: {task}")
             task_id = task["task_id"]
-            ip = task["ip"]
+            decode_ip, decode_rdma_ports = get_decode_ip_idx(task)
             # Default decode_tp_size to self.nranks (number of ranks) if not specified in the task.
             decode_tp_size = task.get("decode_tp_size", self.nranks)
-            rdma_ports = task["rdma_ports"]
-            rdma_ports_len = len(rdma_ports)
+            rdma_ports_len = len(decode_rdma_ports)
             if not (rdma_ports_len == 1 or rdma_ports_len == self.nranks):
                 # TODO: support other cases
                 logger.error(f"rdma_ports length should be 1 or equal to mp_num, but got {rdma_ports_len}")
                 response = {"task_id": task_id, "success": False}
             else:
-                port = rdma_ports[0] if rdma_ports_len == 1 else rdma_ports[self.rank]
-                status = self.messager["rdma"].connect(ip, port, decode_tp_size)
+                port = decode_rdma_ports[0] if rdma_ports_len == 1 else decode_rdma_ports[self.rank]
+                status = self.messager["rdma"].connect(decode_ip, port, decode_tp_size)
                 if not status:
                     response = {"task_id": task_id, "success": False}
                 else:
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 11867a798..24b26161a 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -1919,42 +1919,23 @@ class FDConfig:
             else None
         )

-        self.disaggregate_info = {}
-        if self.scheduler_config.splitwise_role != "mixed":
-            self.disaggregate_info["role"] = self.scheduler_config.splitwise_role
-            self.disaggregate_info["cache_info"] = dict()
-            current_protocol = self.cache_config.cache_transfer_protocol.split(",")
-            self.disaggregate_info["transfer_protocol"] = current_protocol
-
-            for protocol in current_protocol:
-                if protocol == "ipc":
-                    self.disaggregate_info["cache_info"][protocol] = {
-                        "ip": self.host_ip,
-                        "port": engine_worker_queue_port,
-                        "device_ids": self.local_device_ids,
-                    }
-                elif protocol == "rdma":
-                    self.disaggregate_info["cache_info"][protocol] = {
-                        "ip": self.host_ip,
-                        "port": connector_port,
-                        "rdma_port": self.cache_config.rdma_comm_ports,
-                    }
-            logger.info(f"disaggregate_info: {self.disaggregate_info}")
-
-        if self.router_config:
-            # the information for registering this server to router
-            self.register_info = {
-                "role": self.scheduler_config.splitwise_role,
-                "host_ip": self.host_ip,
"port": self.router_config.api_server_port, - "connector_port": connector_port, - "rdma_ports": self.cache_config.rdma_comm_ports, - "engine_worker_queue_port": engine_worker_queue_port, - "device_ids": self.local_device_ids, - "transfer_protocol": self.cache_config.cache_transfer_protocol.split(","), - "tp_size": self.parallel_config.tensor_parallel_size, - } - logger.info(f"register_info: {self.register_info}") + # the information for registering this server to router or splitwise_scheduler + port = self.router_config.api_server_port if self.router_config else None + transfer_protocol = ( + self.cache_config.cache_transfer_protocol.split(",") if self.cache_config.cache_transfer_protocol else [] + ) + self.register_info = { + "role": self.scheduler_config.splitwise_role, + "host_ip": self.host_ip, + "port": port, + "connector_port": connector_port, + "rdma_ports": self.cache_config.rdma_comm_ports, + "engine_worker_queue_port": engine_worker_queue_port, + "device_ids": self.local_device_ids, + "transfer_protocol": transfer_protocol, + "tp_size": self.parallel_config.tensor_parallel_size, + } + logger.info(f"register_info: {self.register_info}") def read_from_config(self): """ diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index dc9741270..fde2ba204 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -424,7 +424,7 @@ class EngineService: need_delete_tasks = [] for task in tasks: - if self.cfg.scheduler_config.splitwise_role != "mixed": + if self.cfg.scheduler_config.splitwise_role == "prefill": status, msg = self.split_connector.check_decode_allocated(task) if status: task.metrics.ask_decode_resource_finish_time = time.time() @@ -469,7 +469,7 @@ class EngineService: is_prefill = False for i in range(len(tasks)): if tasks[i].disaggregate_info is not None: - if tasks[i].disaggregate_info["role"] == "decode": + if self.cfg.scheduler_config.splitwise_role == "decode": is_decode = True else: is_prefill = True @@ -811,11 +811,10 @@ class EngineService: f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" ) - if self.cfg.scheduler_config.splitwise_role != "mixed": - if self.cfg.scheduler_config.splitwise_role == "prefill": - for task in tasks: - # start async preprocess - self.resource_manager.apply_async_preprocess(task) + if self.cfg.scheduler_config.splitwise_role == "prefill": + for task in tasks: + # start async preprocess + self.resource_manager.apply_async_preprocess(task) need_delete_tasks = [] if envs.FD_OFFLINE_PERF_TEST_FOR_PD: for task in tasks: @@ -873,7 +872,6 @@ class EngineService: # release resource in P self.resource_manager.pre_recycle_resource(tmp_task.request_id) - if self.cfg.scheduler_config.splitwise_role == "prefill": # to send cache info to cache messager if tasks: need_check_req_ids = [task.request_id for task in tasks] @@ -912,6 +910,7 @@ class EngineService: tasks.remove(tmp_task) # release resource in P self.resource_manager.pre_recycle_resource(tmp_task.request_id) + # Fetch requests and add them to the scheduling queue if tasks: for task in tasks: @@ -1765,11 +1764,10 @@ class EngineService: role = self.cfg.scheduler_config.splitwise_role host_ip = self.cfg.host_ip - disaggregate = self.cfg.disaggregate_info request_queues_for_dp_ipc = None result_queue_for_dp_ipc = None if self.cfg.scheduler_config.name == "splitwise": - self.scheduler.start(role, host_ip, disaggregate) + self.scheduler.start(role, host_ip, self.cfg.register_info) 
elif self.cfg.scheduler_config.name == "dp": request_queues_for_dp_ipc = [] result_queue_for_dp_ipc = multiprocessing.Queue() diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 92b91da8a..be26b231d 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -715,11 +715,10 @@ class LLMEngine: role = self.cfg.scheduler_config.splitwise_role host_ip = self.cfg.host_ip - disaggregate = self.cfg.disaggregate_info request_queues_for_dp_ipc = None result_queues_for_dp_ipc = None if self.cfg.scheduler_config.name == "splitwise": - self.engine.scheduler.start(role, host_ip, disaggregate) + self.engine.scheduler.start(role, host_ip, self.cfg.register_info) elif self.cfg.scheduler_config.name == "dp": request_queues_for_dp_ipc = [] result_queues_for_dp_ipc = [] diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 3b8c40cca..b99b27e5c 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -113,8 +113,7 @@ class ExpertService: self.cfg.init_cache_info() role = self.cfg.scheduler_config.splitwise_role host_ip = self.cfg.host_ip - disaggregate = self.cfg.disaggregate_info - self.engine.scheduler.start(role, host_ip, disaggregate) + self.engine.scheduler.start(role, host_ip, self.cfg.register_info) if self.cfg.scheduler_config.splitwise_role != "mixed": self.splitwise_receive_thread = threading.Thread( diff --git a/fastdeploy/router/router.py b/fastdeploy/router/router.py index 4062dedbb..32acf09b5 100644 --- a/fastdeploy/router/router.py +++ b/fastdeploy/router/router.py @@ -188,26 +188,15 @@ class Router: is_same_tp_size = prefill_server.tp_size == decode_server.tp_size use_ipc = is_same_node and is_support_ipc and is_same_tp_size - cache_info = {} - if use_ipc: - cache_info["ipc"] = { - "ip": decode_server.host_ip, - "port": decode_server.engine_worker_queue_port, - "device_ids": decode_server.device_ids, - } - else: - cache_info["rdma"] = { - "ip": decode_server.host_ip, - "port": decode_server.connector_port, - "rdma_port": decode_server.rdma_ports, - } - disaggregate_info = { - "prefill": prefill_server.to_dict(), - "decode": decode_server.to_dict(), - "role": "decode", - "cache_info": cache_info, + "prefill_ip": prefill_server.host_ip, + "decode_ip": decode_server.host_ip, + "prefill_connector_port": prefill_server.connector_port, + "decode_connector_port": decode_server.connector_port, + "decode_device_ids": decode_server.device_ids, + "decode_rdma_ports": decode_server.rdma_ports, "transfer_protocol": "ipc" if use_ipc else "rdma", + "decode_tp_size": decode_server.tp_size, } modified_request = request_data.copy() diff --git a/fastdeploy/scheduler/splitwise_scheduler.py b/fastdeploy/scheduler/splitwise_scheduler.py index 350fbf173..c19b3de54 100644 --- a/fastdeploy/scheduler/splitwise_scheduler.py +++ b/fastdeploy/scheduler/splitwise_scheduler.py @@ -14,7 +14,6 @@ # limitations under the License. 
""" -import copy import hashlib import math import pickle @@ -533,16 +532,26 @@ class APIScheduler: else: dnodes.sort() dnode = self.select_pd(req, dnodes, "decode") - disaggregated = copy.deepcopy(dnode.disaggregated) - transfer_protocol = disaggregated["transfer_protocol"] - if len(transfer_protocol) > 1 and "ipc" in transfer_protocol and "rdma" in transfer_protocol: - if pnode.host == dnode.host: - disaggregated["transfer_protocol"] = "ipc" - else: - disaggregated["transfer_protocol"] = "rdma" - else: - disaggregated["transfer_protocol"] = transfer_protocol[0] - req.disaggregate_info = disaggregated + + is_same_node = pnode.disaggregated["host_ip"] == dnode.disaggregated["host_ip"] + is_support_ipc = ( + "ipc" in pnode.disaggregated["transfer_protocol"] and "ipc" in dnode.disaggregated["transfer_protocol"] + ) + is_same_tp_size = pnode.disaggregated["tp_size"] == dnode.disaggregated["tp_size"] + use_ipc = is_same_node and is_support_ipc and is_same_tp_size + + disaggregate_info = { + "prefill_ip": pnode.disaggregated["host_ip"], + "decode_ip": dnode.disaggregated["host_ip"], + "prefill_connector_port": pnode.disaggregated["connector_port"], + "decode_connector_port": dnode.disaggregated["connector_port"], + "decode_device_ids": dnode.disaggregated["device_ids"], + "decode_rdma_ports": dnode.disaggregated["rdma_ports"], + "transfer_protocol": "ipc" if use_ipc else "rdma", + "decode_tp_size": dnode.disaggregated["tp_size"], + } + + req.disaggregate_info = disaggregate_info pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}" req_dict = req.to_dict() req_dict["group"] = group diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index 3aafe3dbe..d58e00a62 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -24,7 +24,6 @@ import zmq from fastdeploy import envs from fastdeploy.engine.request import Request, RequestOutput -from fastdeploy.inter_communicator import EngineWorkerQueue from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.utils import get_logger @@ -53,7 +52,6 @@ class SplitwiseConnector: self.logger = get_logger("splitwise_connector", "splitwise_connector.log") self.engine_worker_queue = worker_queue self.resource_manager = resource_manager - self.connect_innode_instances = {} self.current_request_ids = dict() self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1" @@ -172,119 +170,59 @@ class SplitwiseConnector: def send_splitwise_tasks(self, tasks: List[Request], current_id): """ - Send splitwise tasks to all connected addresses. + Prefill send splitwise tasks to decode. Parameters: tasks (list): List of tasks. current_id (int): Current ID. 
""" - addr = None - decode_diagg = None for task in tasks: if task.disaggregate_info is None: continue - if task.disaggregate_info["transfer_protocol"] == "ipc": - addr = task.disaggregate_info["cache_info"]["ipc"]["port"] - task.disaggregate_info["cache_info"]["ipc"]["current_id"] = current_id - self.logger.info(f"send_splitwise_tasks: protocol=ipc, addr={addr}, task={task.request_id}") - self.send_splitwise_tasks_innode([task], addr) - else: + self.current_request_ids[task.request_id] = "init" + task.disaggregate_info["role"] = "decode" + addr = f"{task.disaggregate_info['decode_ip']}:{task.disaggregate_info['decode_connector_port']}" + self.logger.info(f"send_splitwise_tasks: protocol=rdma, addr={addr}, task={task.request_id}") + self._send_message(addr, "prefill", [task]) - addr = ( - f"{task.disaggregate_info['cache_info']['rdma']['ip']}:" - + f"{task.disaggregate_info['cache_info']['rdma']['port']}" - ) - self.current_request_ids[task.request_id] = "init" - decode_diagg = task.disaggregate_info["cache_info"] - task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"] - task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id - task.disaggregate_info["role"] = "decode" - self.logger.info(f"send_splitwise_tasks: protocol=rdma, addr={addr}, task={task.request_id}") - self._send_message(addr, "prefill", [task]) - task.disaggregate_info["cache_info"] = decode_diagg task.disaggregate_info["role"] = "prefill" - def send_splitwise_tasks_innode(self, tasks, port): - """ - Send splitwise tasks to specific port. - - Parameters: - tasks (list): List of tasks. - port (int): Port number. - - Returns: - int: Current port number, -1 if tasks are not sent. - """ - current_port = -1 - if port not in self.connect_innode_instances: - self.create_connection(port) - for task in tasks: - task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.parallel_config.engine_worker_queue_port[ - self.local_data_parallel_id - ] - self.logger.info(f"send_splitwise_tasks_innode: port={port}, tasks={[task.request_id for task in tasks]}") - self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks)) - for task in tasks: - task.disaggregate_info["cache_info"]["ipc"]["port"] = port - current_port = port - return current_port - def send_first_token(self, prefill_msg, tasks_list): """ - send first token to specific port + Prefill send first token to specific port """ if not isinstance(tasks_list, list): tasks_list = [tasks_list] - self.logger.info(f"send_first_token: send first token to decode, {[x.request_id for x in tasks_list]}") - if prefill_msg["transfer_protocol"] == "ipc": - port = prefill_msg["cache_info"]["ipc"]["port"] - if port not in self.connect_innode_instances: - self.create_connection(port) - self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list)) - else: - node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}" - self.logger.info(f"send_first_token: send first token to port {node} decode") - self._send_message(node, "decode", tasks_list) - def create_connection(self, port): - """ - Create a connection to specific port. - - Parameters: - port (int): Port number. 
- """ - if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: - address = ("0.0.0.0", int(port)) - else: - address = f"/dev/shm/fd_task_queue_{port}.sock" - - self.connect_innode_instances[port] = EngineWorkerQueue( - address=address, - num_client=self.cfg.parallel_config.tensor_parallel_size, - client_id=0, + addr = f"{prefill_msg['decode_ip']}:{prefill_msg['decode_connector_port']}" + self.logger.info( + f"send_first_token: send first token to decode ({addr}), {[x.request_id for x in tasks_list]}" ) + self._send_message(addr, "decode", tasks_list) def check_decode_allocated(self, task): - self.logger.debug(f"start check decode allocated: {task.request_id}") + """Check whether the requests have been allocated resources in decode.""" + self.logger.debug(f"check_decode_allocated: {task.request_id}") start_time = time.time() if task.disaggregate_info is None: return True, "" if self.enable_decode_cache_task: return True, "" - if task.disaggregate_info["role"] != "prefill": - return True, "" + while self.current_request_ids[task.request_id] == "init": time.sleep(0.001) if time.time() - start_time > envs.FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS: del self.current_request_ids[task.request_id] return False, "timeout" + msg = self.current_request_ids[task.request_id] del self.current_request_ids[task.request_id] if msg == "finished": return True, "" - self.logger.error(f"check_decode_allocated: Receive_decode_allocated error: {msg}") - return False, msg + else: + self.logger.error(f"check_decode_allocated: Receive_decode_allocated error: {msg}") + return False, msg def send_cache_info_to_messager(self, tasks: List[Request], current_id): """ @@ -308,13 +246,12 @@ class SplitwiseConnector: "need_prefill_tokens": tasks[i].need_prefill_tokens, } else: - if current_id == -1: - current_id = dsg_info["cache_info"]["ipc"]["current_id"] info = { "request_id": tasks[i].request_id, "src_block_ids": tasks[i].block_tables, "current_id": current_id, } + info.update(dsg_info) cache_info.append(info) self.logger.debug(f"send_cache_info_to_messager, {cache_info}") @@ -333,56 +270,29 @@ class SplitwiseConnector: if dsg_info is None: self.logger.debug(f"skip send_cache_infos_to_prefill, {tasks[i].request_id}") continue - self.logger.debug(f"send_cache_infos_to_prefill, {dsg_info}") - if dsg_info["transfer_protocol"] == "ipc": + if tasks[i].get("error_msg", None) is not None: + info = { + "request_id": tasks[i].request_id, + "error_msg": tasks[i].get("error_msg"), + } + else: + addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}" info = { "request_id": tasks[i].request_id, - "device_ids": self.cfg.parallel_config.device_ids.split(","), - "transfer_protocol": "ipc", "dest_block_ids": dsg_info["block_tables"], } - if dsg_info["cache_info"]["ipc"]["port"] not in cache_info: - cache_info[dsg_info["cache_info"]["ipc"]["port"]] = [] - cache_info[dsg_info["cache_info"]["ipc"]["port"]].append(info) - else: - if tasks[i].get("error_msg", None) is not None: - info = { - "request_id": tasks[i].request_id, - "error_msg": tasks[i].get("error_msg"), - } - else: - info = { - "request_id": tasks[i].request_id, - "device_ids": [self.cfg.parallel_config.device_ids.split(",")[self.local_data_parallel_id]], - "ip": self.cfg.host_ip, - "rdma_ports": [ - self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"][self.local_data_parallel_id] - ], - "transfer_protocol": "rdma", - "dest_block_ids": dsg_info["block_tables"], - "decode_tp_size": self.cfg.parallel_config.tensor_parallel_size, - } - - addr = 
f"{dsg_info['cache_info']['rdma']['ip']}:" + f"{dsg_info['cache_info']['rdma']['port']}" if addr not in cache_info: cache_info[addr] = [] cache_info[addr].append(info) self.logger.debug(f"send cache info to prefill, {cache_info}") if len(cache_info): - for k, v in cache_info.items(): - self.logger.info(f"{k} {v}") - if ":" in str(k): - self._send_message(k, "cache_sync", v) - else: - if k not in self.connect_innode_instances: - self.create_connection(k) - self.connect_innode_instances[k].put_cache_info(v) + for key, info in cache_info.items(): + self._send_message(key, "cache_sync", info) def _serialize_message(self, msg_type: str, payload) -> bytes: # TODO 压缩 - if msg_type == "decode" or msg_type == "prefill": payload = [output.to_dict() for output in payload] diff --git a/tests/e2e/test_ernie_03b_pd_router_v0.py b/tests/e2e/test_ernie_03b_pd_router_v0_ipc.py similarity index 100% rename from tests/e2e/test_ernie_03b_pd_router_v0.py rename to tests/e2e/test_ernie_03b_pd_router_v0_ipc.py diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py index 82f06ef0e..7638e465b 100644 --- a/tests/utils/test_config.py +++ b/tests/utils/test_config.py @@ -130,7 +130,7 @@ class TestConfig(unittest.TestCase): test_mode=True, ) fd_config.init_cache_info() - assert fd_config.disaggregate_info["role"] == "prefill" + assert fd_config.register_info is not None if __name__ == "__main__":