[PD Disaggregation] support DP via v1 router and decouple DP and EP (#5197)

* [fix] support DP via v1 router and decouple DP and EP

* [fix] fix scripts

* [fix] reset model path

* [fix] DP uses get_output_ep; fix router port type; update scripts

* [merge] merge with latest code

* [chore] remove some debug log

* [fix] fix code style check

* [fix] fix test_multi_api_server for log_dir name

* [chore] reduce logs

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Yonghua Li
2025-12-04 15:38:43 +08:00
committed by GitHub
parent 5cd17fd662
commit f4119d51b4
15 changed files with 394 additions and 146 deletions

View File

@@ -41,7 +41,7 @@ from fastdeploy.inter_communicator import (
)
from fastdeploy.utils import envs, get_logger
logger = get_logger("cache_messager", "cache_messager.log")
# logger = get_logger("cache_messager", "cache_messager.log")
def parse_args():
@@ -552,6 +552,7 @@ class CacheMessagerV1:
cache_info = self.engine_worker_queue.get_cache_info()
finished_add_cache_task_req_ids = []
if cache_info:
logger.debug(f"Get cache info from engine worker queue, {cache_info}")
self.engine_worker_queue.cache_info_barrier.wait()
for info in cache_info:
if info["request_id"] in self.cache_info:
@@ -570,14 +571,15 @@ class CacheMessagerV1:
current_info["sended_layer_id"] = -1
current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
current_info["status"] = "init"
logger.info(f"Get cache info from P: finish add cache task: {current_info}")
logger.info(f"Get cache info from D: finish add cache task: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.idx_cache_task_dict[current_info["current_id"]] = current_info
else:
logger.info(f"Get cache info from D: {info}")
logger.info(f"Get cache info from P: {info}")
self.cache_info[info["request_id"]] = info
if finished_add_cache_task_req_ids:
logger.info(f"Put processed tasks into engine worker queue: {finished_add_cache_task_req_ids}")
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
else:
@@ -671,7 +673,7 @@ class CacheMessagerV1:
target_ip, target_id, decode_tp_size
)
if status:
logger.info(f"connect to {target_ip}:{target_id} success")
logger.debug(f"connect to {target_ip}:{target_id} success")
else:
logger.error(f"connect to {target_ip}:{target_id} failed")
task["status"] = "connection error"
@@ -722,7 +724,7 @@ class CacheMessagerV1:
if "error" not in task["status"]:
task["status"] = "finished"
logger.info(
f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
f"Finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
)
else:
task["sended_layer_id"] = -1
@@ -736,7 +738,9 @@ class CacheMessagerV1:
self.messager["ipc"].write_block_by_sync(target_id)
self.engine_worker_queue.finish_send_cache_barrier.wait()
self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]])
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
logger.info(
f"Put successful cache writing task in engine worker queue, req_id: {task['request_id']}, status: {task['status']}"
)
self.engine_cache_tasks[task["current_id"]] = dict()
del self.cache_info[task["request_id"]]
del self.idx_cache_task_dict[task["current_id"]]
@@ -928,7 +932,8 @@ if __name__ == "__main__":
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
logger = get_logger("cache_messager", f"cache_messager_tprank{args.rank}.log")
logger.info("create cache messager...")
logger.info(f"{args}")
main()