[PD Disaggregation] support DP via v1 router and decouple DP and EP (#5197)

* [fix] support DP via v1 router and decouple DP and EP

* [fix] fix scripts

* [fix] reset model path

* [fix] DP uses get_output_ep; fix router port type; update scripts

* [merge] merge with latest code

* [chore] remove some debug log

* [fix] fix code style check

* [fix] fix test_multi_api_server for log_dir name

* [chore] reduce logs

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Yonghua Li
2025-12-04 15:38:43 +08:00
committed by GitHub
parent 5cd17fd662
commit f4119d51b4
15 changed files with 394 additions and 146 deletions

View File

@@ -41,7 +41,7 @@ from fastdeploy.inter_communicator import (
)
from fastdeploy.utils import envs, get_logger
logger = get_logger("cache_messager", "cache_messager.log")
# logger = get_logger("cache_messager", "cache_messager.log")
def parse_args():
@@ -552,6 +552,7 @@ class CacheMessagerV1:
cache_info = self.engine_worker_queue.get_cache_info()
finished_add_cache_task_req_ids = []
if cache_info:
logger.debug(f"Get cache info from engine worker queue, {cache_info}")
self.engine_worker_queue.cache_info_barrier.wait()
for info in cache_info:
if info["request_id"] in self.cache_info:
@@ -570,14 +571,15 @@ class CacheMessagerV1:
current_info["sended_layer_id"] = -1
current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
current_info["status"] = "init"
logger.info(f"Get cache info from P: finish add cache task: {current_info}")
logger.info(f"Get cache info from D: finish add cache task: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.idx_cache_task_dict[current_info["current_id"]] = current_info
else:
logger.info(f"Get cache info from D: {info}")
logger.info(f"Get cache info from P: {info}")
self.cache_info[info["request_id"]] = info
if finished_add_cache_task_req_ids:
logger.info(f"Put processed tasks into engine worker queue: {finished_add_cache_task_req_ids}")
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
else:
@@ -671,7 +673,7 @@ class CacheMessagerV1:
target_ip, target_id, decode_tp_size
)
if status:
logger.info(f"connect to {target_ip}:{target_id} success")
logger.debug(f"connect to {target_ip}:{target_id} success")
else:
logger.error(f"connect to {target_ip}:{target_id} failed")
task["status"] = "connection error"
@@ -722,7 +724,7 @@ class CacheMessagerV1:
if "error" not in task["status"]:
task["status"] = "finished"
logger.info(
f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
f"Finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
)
else:
task["sended_layer_id"] = -1
@@ -736,7 +738,9 @@ class CacheMessagerV1:
self.messager["ipc"].write_block_by_sync(target_id)
self.engine_worker_queue.finish_send_cache_barrier.wait()
self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]])
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
logger.info(
f"Put successful cache writing task in engine worker queue, req_id: {task['request_id']}, status: {task['status']}"
)
self.engine_cache_tasks[task["current_id"]] = dict()
del self.cache_info[task["request_id"]]
del self.idx_cache_task_dict[task["current_id"]]
@@ -928,7 +932,8 @@ if __name__ == "__main__":
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
logger = get_logger("cache_messager", f"cache_messager_tprank{args.rank}.log")
logger.info("create cache messager...")
logger.info(f"{args}")
main()

View File

@@ -740,6 +740,6 @@ if __name__ == "__main__":
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log")
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log")
set_device(args.device_id)
main()

View File

@@ -280,7 +280,7 @@ class PrefixCacheManager:
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ (" --create_cache_tensor" if create_cache_tensor else "")
+ f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1"
+ f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
@@ -372,7 +372,7 @@ class PrefixCacheManager:
+ f" --engine_pid {pid_suffix}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
+ f" >{log_dir}/launch_cache_messager_tprank{i}.log 2>&1"
)
logger.info(f"Launch cache messager, command:{launch_cmd}")
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))