[Optimize] Support and robust for tpN for PD (#4595)

* [Optimize] Support and robust for tpN for PD

* fix

* fix

* support dpM tpN for cache messager

* fix

* fix token counter

* fix bug for merge develop

* fix bug

* robust cache messager for v0
This commit is contained in:
chenjian
2025-11-03 15:38:31 +08:00
committed by GitHub
parent 7b35488779
commit 25498efcf3
9 changed files with 452 additions and 197 deletions

View File

@@ -246,11 +246,10 @@ class CacheMessager:
engine_recycled_count = 0
while True:
cache_info = self.engine_worker_queue.get_cache_info()
if cache_info:
logger.debug(f"cache info {cache_info}")
self.engine_worker_queue.cache_info_barrier.wait()
for info in cache_info:
if info["request_id"] in self.cache_info:
self.cache_info[info["request_id"]].update(info)
@@ -295,9 +294,6 @@ class CacheMessager:
continue
if "layer_idx" not in item:
item["layer_idx"] = 0
if item["status"] == "error":
del self.cache_info[req_id]
continue
if item["current_id"] > prefilled_step_idx:
continue
current_transfer_protocol = item["transfer_protocol"]
@@ -307,11 +303,7 @@ class CacheMessager:
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
if not status:
logger.error(f"connect to {target_ip}:{target_id} failed")
item["status"] = "error"
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([(item["request_id"], "connect error")])
continue
item["status"] = "connect error"
elif item["transfer_protocol"] == "ipc":
target_ip = "0.0.0.0"
target_id = int(item["device_ids"][self.rank])
@@ -321,48 +313,43 @@ class CacheMessager:
current_layer_idx = self.num_layers
else:
current_layer_idx = prefilled_layer_idx + 1
for layer_idx in range(item["layer_idx"], current_layer_idx):
tic = time.time()
return_code = self.messager[current_transfer_protocol].write_cache(
target_ip,
target_id,
src_block_ids,
dest_block_ids,
layer_idx,
)
if return_code != 0:
item["status"] = "error"
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
logger.error(
f"write cache failed, layer_idx: {layer_idx}, "
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
if "error" not in item["status"]:
for layer_idx in range(item["layer_idx"], current_layer_idx):
tic = time.time()
return_code = self.messager[current_transfer_protocol].write_cache(
target_ip,
target_id,
src_block_ids,
dest_block_ids,
layer_idx,
)
break
if return_code != 0:
item["status"] = "write cache error"
logger.error(
f"write cache failed, layer_idx: {layer_idx}, "
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
)
break
tok = time.time()
cost_time = tok - tic
block_num = len(src_block_ids)
avg_time_per_block = cost_time * 1000 / block_num # ms
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
logger.debug(
f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
f" {current_transfer_protocol}"
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
)
tok = time.time()
cost_time = tok - tic
block_num = len(src_block_ids)
avg_time_per_block = cost_time * 1000 / block_num # ms
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
logger.debug(
f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
f" {current_transfer_protocol}"
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
)
item["layer_idx"] = current_layer_idx
if item["layer_idx"] == self.num_layers:
if item["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
logger.info(f"finish write cache {item['request_id']}")
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
# to do: robust in TP: here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
logger.info(f"put write cache {item['request_id']}")
self.engine_worker_queue.finish_send_cache_barrier.wait()
self.engine_worker_queue.put_finished_req([[item["request_id"], item["status"]]])
logger.info(f"put write cache {item['request_id']}, status {item['status']}")
del self.cache_info[req_id]
self.last_layer_idx = prefilled_layer_idx
@@ -376,14 +363,17 @@ class CacheMessager:
if task is None:
time.sleep(0.001)
continue
else:
self.engine_worker_queue.connect_task_barrier.wait()
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip, rdma_port = task["ip"], task["rdma_port"]
ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
status = self.messager["rdma"].connect(ip, rdma_port)
if not status:
response = {"task_id": task_id, "success": False}
else:
response = {"task_id": task_id, "success": True}
self.engine_worker_queue.connect_task_response_barrier.wait()
self.engine_worker_queue.put_connect_rdma_task_response(response)
except Exception as e:
logger.error(f"handle_connect_task has exception: {e}")
@@ -524,9 +514,9 @@ class CacheMessagerV1:
while True:
try:
cache_info = self.engine_worker_queue.get_cache_info()
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
finished_add_cache_task_req_ids = []
if cache_info:
self.engine_worker_queue.cache_info_barrier.wait()
for info in cache_info:
if info["request_id"] in self.cache_info:
self.cache_info[info["request_id"]].update(info)
@@ -544,13 +534,16 @@ class CacheMessagerV1:
current_info["sended_layer_id"] = -1
current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
current_info["status"] = "init"
logger.info(f"finish add cache task: {current_info}")
logger.info(f"Get cache info from P: finish add cache task: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.idx_cache_task_dict[current_info["current_id"]] = current_info
else:
logger.info(f"Get cache info from D: {info}")
self.cache_info[info["request_id"]] = info
if self.rank == 0 and finished_add_cache_task_req_ids:
if finished_add_cache_task_req_ids:
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
else:
time.sleep(0.001)
except Exception as e:
@@ -563,14 +556,16 @@ class CacheMessagerV1:
"""
while True:
try:
engine_indexes = self.cache_prefilled_engine_ids_queue.get()
self.engine_worker_queue.finish_request_barrier.wait()
batch_engine_signals = self.cache_prefilled_engine_ids_queue.get()
self.engine_worker_queue.begin_send_cache_barrier.wait()
block_start_end_list = []
current_prefilled_token_num_list = []
for engine_index in engine_indexes:
assert engine_index in self.idx_cache_task_dict
for engine_index, current_step_prefilled_token_num in batch_engine_signals:
assert (
engine_index in self.idx_cache_task_dict
), f"engine_index {engine_index} not in self.idx_cache_task_dict {self.idx_cache_task_dict}"
block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
prefilled_token_num = current_step_prefilled_token_num
if (
prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
): # all chunks have been prefilled
@@ -580,17 +575,20 @@ class CacheMessagerV1:
block_start_end_list.append((block_id_start, block_id_end))
current_prefilled_token_num_list.append(prefilled_token_num)
while True: # from layer0 to last layer
sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"]
start_layer_idx = sended_layer_idx + 1
with self.engine_cache_task_thread_lock: # to check end_layer_idx
prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
prefilled_layer_idx = self.engine_cache_tasks[batch_engine_signals[0][0]][
"prefilled_layer_idx"
]
if sended_layer_idx > prefilled_layer_idx: # computation must in next chunk
logger.info(
f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[batch_engine_signals[0][0]]['prefilled_token_num']}"
)
assert (
current_prefilled_token_num_list[0]
< self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
< self.engine_cache_tasks[batch_engine_signals[0][0]]["prefilled_token_num"]
), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
end_layer_idx = self.num_layers - 1 # [start_layer_idx, end_layer_idx)
else:
@@ -599,7 +597,7 @@ class CacheMessagerV1:
time.sleep(0.01)
for layer_idx in range(start_layer_idx, end_layer_idx + 1):
for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
engine_index = engine_indexes[i]
engine_index = batch_engine_signals[i][0]
task = self.idx_cache_task_dict[engine_index]
req_id = task["request_id"]
if (
@@ -615,7 +613,7 @@ class CacheMessagerV1:
if task["transfer_protocol"] == "rdma":
target_ip = task["ip"]
target_id = int(task["rdma_ports"][self.rank])
if task["status"] == "error":
if "error" in task["status"]:
continue
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
if not status:
@@ -665,7 +663,7 @@ class CacheMessagerV1:
block_id_end - block_id_start
)
if current_prefilled_token_num_list[i] == task["need_prefill_tokens"]:
if task["status"] != "error":
if "error" not in task["status"]:
task["status"] = "finished"
logger.info(
f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
@@ -674,18 +672,15 @@ class CacheMessagerV1:
task["sended_layer_id"] = -1
if end_layer_idx == self.num_layers - 1:
with self.engine_cache_task_thread_lock:
for engine_idx in engine_indexes:
for engine_idx, _ in batch_engine_signals:
task = self.idx_cache_task_dict[engine_idx]
if task["status"] == "finished" or ("error" in task["status"]):
target_id = int(task["rdma_ports"][self.rank])
if task["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
if self.rank == 0:
# to do: robust in TP, here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
self.engine_worker_queue.put_finished_req(
[(task["request_id"], task["status"])]
)
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
self.engine_worker_queue.finish_send_cache_barrier.wait()
self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]])
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
self.engine_cache_tasks[task["current_id"]] = dict()
del self.cache_info[task["request_id"]]
del self.idx_cache_task_dict[task["current_id"]]
@@ -709,8 +704,9 @@ class CacheMessagerV1:
continue
layer_id = kv_signal_data[1].numpy().tolist()
if layer_id == self.num_layers - 1:
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
batch_engine_ids = []
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id} self.rank_id {self.rank_id}")
batch_engine_signals = []
# format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)]
with self.engine_cache_task_thread_lock:
for bi in range(tasks_count):
engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
@@ -720,27 +716,33 @@ class CacheMessagerV1:
self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
chuck_token_offset + current_seq_len
)
batch_engine_ids.append(engine_idx)
batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len))
if layer_id == 0:
self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
logger.info(
f"Put batch_engine_signals {batch_engine_signals} into cache_prefilled_engine_ids_queue"
)
self.cache_prefilled_engine_ids_queue.put(batch_engine_signals)
except Exception as e:
logger.error(f"Consume signals get exception: {e}")
def _handle_connect_task(self):
while True:
try:
task = self.engine_worker_queue.get_connect_rdma_task()
task, _ = self.engine_worker_queue.get_connect_rdma_task()
if task is None:
time.sleep(0.001)
continue
else:
self.engine_worker_queue.connect_task_barrier.wait()
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip, rdma_port = task["ip"], task["rdma_port"]
ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
status = self.messager["rdma"].connect(ip, rdma_port)
if not status:
response = {"task_id": task_id, "success": False}
else:
response = {"task_id": task_id, "success": True}
self.engine_worker_queue.connect_task_response_barrier.wait()
self.engine_worker_queue.put_connect_rdma_task_response(response)
except Exception as e:
logger.error(f"handle_connect_task has exception: {e}")

View File

@@ -310,11 +310,7 @@ class EngineService:
num_client=self.cfg.parallel_config.tensor_parallel_size,
client_id=0,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
local_data_parallel_id=min(
self.cfg.worker_num_per_node // self.cfg.parallel_config.tensor_parallel_size * self.cfg.node_rank
+ self.cfg.parallel_config.local_data_parallel_id,
self.cfg.parallel_config.data_parallel_size - 1,
),
local_data_parallel_id=self.cfg.parallel_config.local_data_parallel_id,
)
def insert_tasks(self, tasks, current_id=-1, allocated=False):
@@ -656,39 +652,60 @@ class EngineService:
self.cfg.max_prefill_batch,
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens
else:
max_num_batched_tokens = self.cfg.model_config.max_model_len
tasks = self.scheduler.get_requests(
available_blocks=self.cfg.cache_config.max_block_num_per_seq,
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
max_num_batched_tokens=self.cfg.model_config.max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
batch=num_prefill_batch,
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
self.split_connector.send_splitwise_tasks([task], task.idx)
need_delete_tasks = []
for task in tasks:
if self.cfg.scheduler_config.splitwise_role != "mixed":
# assure fetch block ids from D
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
while True:
self.split_connector.send_splitwise_tasks([task], task.idx)
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} ask D resource failed, try again.")
time.sleep(0.05)
else:
break
else:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
self.split_connector.send_splitwise_tasks([task], task.idx)
for task in tasks:
if self.cfg.scheduler_config.splitwise_role != "mixed":
# assure fetch block ids from D
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
# release resource in P
@@ -930,7 +947,7 @@ class EngineService:
for request_id, contents in results.items():
new_contents = []
for content in contents:
if isinstance(content, RequestOutput):
if isinstance(content, RequestOutput) and content.outputs is not None:
decode_type = content.outputs.decode_type
delta_text = ""
if decode_type == 0:
@@ -1035,6 +1052,7 @@ class EngineService:
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
self.token_processor.tokens_counter[task.request_id] = 1
self.resource_manager.insert_task_for_decoding(task)
else:

View File

@@ -27,6 +27,7 @@ import numpy as np
from fastdeploy.engine.common_engine import EngineService
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.utils import console_logger, envs, llm_logger
@@ -99,6 +100,10 @@ class ExpertService:
self.engine.start_zmq_service(ipc_signal_suffix)
else:
ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.internal_adapter = InternalAdapter(
cfg=self.cfg, engine=self.engine, dp_rank=self.cfg.parallel_config.local_data_parallel_id
)
llm_logger.info(f"start expert service {local_data_parallel_id}")

View File

@@ -151,6 +151,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_MOE_QUANT_TYPE": lambda: os.getenv("FD_MOE_QUANT_TYPE", "w4a8"),
"ENCODE_FEATURE_BOS_AK": lambda: os.getenv("ENCODE_FEATURE_BOS_AK"),
"ENCODE_FEATURE_BOS_SK": lambda: os.getenv("ENCODE_FEATURE_BOS_SK"),
# Enable offline perf test mode for PD disaggregation
"FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
}

View File

@@ -80,24 +80,79 @@ class EngineWorkerQueue:
self.client_read_flag_init: List[List[int]] = [
[1] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.lock_init: List[threading.Lock] = [threading.Lock() for _ in range(self.local_data_parallel_size)]
self.read_finish_flag_init: List[Value] = [Value("i", 0) for _ in range(self.local_data_parallel_size)]
self.connected_client_counter_init: List[Value] = [
Value("i", 0) for _ in range(self.local_data_parallel_size)
]
self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
self.finished_add_cache_task_queue = [Queue() for _ in range(self.local_data_parallel_size)]
self.finished_req_list = [list() for _ in range(self.local_data_parallel_size)]
self.finished_add_cache_task_list = [list() for _ in range(self.local_data_parallel_size)]
self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
self.client_read_info_flag_init: List[List[int]] = [
[1] * self.num_client for _ in range(self.local_data_parallel_size)
[0] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.lock_info_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
]
# PD disaggregation
# Locks
self.connect_task_lock_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
] # connect rdma task
self.connect_task_response_lock_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
] # connect rdma task response
self.finish_add_cache_task_lock_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
] # finish add cache task
self.finish_send_cache_lock_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
] # finish send cache
# sync read status for TPs
self.client_get_connect_task_flag_init: List[List[int]] = [
[0] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.client_get_connect_task_response_flag_init: List[List[int]] = [
[0] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.client_get_finished_add_cache_task_flag_init: List[List[int]] = [
[0] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.client_get_finish_send_cache_flag_init: List[List[int]] = [
[0] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.can_put_next_connect_task_response_flag_init: List[Value] = [
Value("i", 1) for _ in range(self.local_data_parallel_size)
]
self.can_put_next_add_task_finished_flag_init: List[Value] = [
Value("i", 1) for _ in range(self.local_data_parallel_size)
]
self.can_put_next_send_cache_finished_flag_init: List[Value] = [
Value("i", 1) for _ in range(self.local_data_parallel_size)
]
# barrier
self.get_connect_task_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.get_connect_task_response_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.finish_add_cache_task_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.begin_send_cache_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.finish_send_cache_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.get_cache_info_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.finish_request_barrier = [
@@ -107,10 +162,6 @@ class EngineWorkerQueue:
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.finish_add_cache_task_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
# Register shared objects with proxy types
QueueManager.register(
"get_tasks",
@@ -122,6 +173,26 @@ class EngineWorkerQueue:
callable=lambda idx: self.client_read_flag_init[idx],
proxytype=ListProxy,
)
QueueManager.register(
"get_client_get_connect_task_flag",
callable=lambda idx: self.client_get_connect_task_flag_init[idx],
proxytype=ListProxy,
)
QueueManager.register(
"get_client_get_connect_task_response_flag",
callable=lambda idx: self.client_get_connect_task_response_flag_init[idx],
proxytype=ListProxy,
)
QueueManager.register(
"get_client_get_finished_add_cache_task_flag_init",
callable=lambda idx: self.client_get_finished_add_cache_task_flag_init[idx],
proxytype=ListProxy,
)
QueueManager.register(
"get_client_get_finish_send_cache_flag_init",
callable=lambda idx: self.client_get_finish_send_cache_flag_init[idx],
proxytype=ListProxy,
)
QueueManager.register(
"get_lock",
callable=lambda idx: self.lock_init[idx],
@@ -132,11 +203,43 @@ class EngineWorkerQueue:
callable=lambda idx: self.read_finish_flag_init[idx],
proxytype=ValueProxy,
)
QueueManager.register(
"get_can_put_next_connect_task_response_flag",
callable=lambda idx: self.can_put_next_connect_task_response_flag_init[idx],
proxytype=ValueProxy,
)
QueueManager.register(
"get_can_put_next_add_task_finished_flag",
callable=lambda idx: self.can_put_next_add_task_finished_flag_init[idx],
proxytype=ValueProxy,
)
QueueManager.register(
"get_can_put_next_send_cache_finished_flag",
callable=lambda idx: self.can_put_next_send_cache_finished_flag_init[idx],
proxytype=ValueProxy,
)
# PD disaggregation
QueueManager.register(
"get_connect_task_lock",
callable=lambda idx: self.connect_task_lock_init[idx],
proxytype=AcquirerProxy,
)
QueueManager.register(
"get_connect_task_response_lock",
callable=lambda idx: self.connect_task_response_lock_init[idx],
proxytype=AcquirerProxy,
)
QueueManager.register(
"get_finish_add_cache_task_lock",
callable=lambda idx: self.finish_add_cache_task_lock_init[idx],
proxytype=AcquirerProxy,
)
QueueManager.register(
"get_finish_send_cache_lock",
callable=lambda idx: self.finish_send_cache_lock_init[idx],
proxytype=AcquirerProxy,
)
QueueManager.register(
"get_connect_rdma_tasks", callable=lambda idx: self.connect_rdma_tasks_list[idx], proxytype=ListProxy
)
@@ -152,13 +255,13 @@ class EngineWorkerQueue:
)
QueueManager.register(
"get_finish_request_queue",
callable=lambda idx: self.finished_req_queue[idx],
"get_finish_request_queue", callable=lambda idx: self.finished_req_list[idx], proxytype=ListProxy
)
QueueManager.register(
"get_finish_add_cache_task_queue",
callable=lambda idx: self.finished_add_cache_task_queue[idx],
callable=lambda idx: self.finished_add_cache_task_list[idx],
proxytype=ListProxy,
)
QueueManager.register(
@@ -194,6 +297,26 @@ class EngineWorkerQueue:
"get_finish_request_barrier",
callable=lambda idx: self.finish_request_barrier[idx],
)
QueueManager.register(
"get_connect_task_barrier",
callable=lambda idx: self.get_connect_task_barrier[idx],
)
QueueManager.register(
"get_connect_task_response_barrier",
callable=lambda idx: self.get_connect_task_response_barrier[idx],
)
QueueManager.register(
"get_begin_send_cache_barrier",
callable=lambda idx: self.begin_send_cache_barrier[idx],
)
QueueManager.register(
"get_finish_send_cache_barrier",
callable=lambda idx: self.finish_send_cache_barrier[idx],
)
QueueManager.register(
"get_cache_info_barrier",
callable=lambda idx: self.get_cache_info_barrier[idx],
)
QueueManager.register(
"get_finish_add_cache_task_barrier",
@@ -231,10 +354,25 @@ class EngineWorkerQueue:
QueueManager.register("get_available_prefill_instances")
QueueManager.register("get_finish_request_barrier")
QueueManager.register("get_finish_add_cache_task_barrier")
QueueManager.register("get_connect_task_barrier")
QueueManager.register("get_connect_task_response_barrier")
QueueManager.register("get_finish_send_cache_barrier")
QueueManager.register("get_begin_send_cache_barrier")
QueueManager.register("get_cache_info_barrier")
QueueManager.register("get_connect_rdma_tasks")
QueueManager.register("get_client_get_connect_task_flag")
QueueManager.register("get_client_get_connect_task_response_flag")
QueueManager.register("get_client_get_finished_add_cache_task_flag_init")
QueueManager.register("get_client_get_finish_send_cache_flag_init")
QueueManager.register("get_connect_rdma_tasks_responses")
QueueManager.register("get_connect_task_lock")
QueueManager.register("get_connect_task_response_lock")
QueueManager.register("get_finish_add_cache_task_lock")
QueueManager.register("get_finish_send_cache_lock")
QueueManager.register("get_worker_process_tp_barrier")
QueueManager.register("get_can_put_next_connect_task_response_flag")
QueueManager.register("get_can_put_next_add_task_finished_flag")
QueueManager.register("get_can_put_next_send_cache_finished_flag")
self.manager = QueueManager(address=self.address, authkey=self.authkey)
self._connect_with_retry()
@@ -257,17 +395,50 @@ class EngineWorkerQueue:
self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier(
self.local_data_parallel_id
)
self.connect_task_barrier = self.manager.get_connect_task_barrier(self.local_data_parallel_id)
self.connect_task_response_barrier = self.manager.get_connect_task_response_barrier(
self.local_data_parallel_id
)
self.finish_send_cache_barrier = self.manager.get_finish_send_cache_barrier(self.local_data_parallel_id)
self.cache_info_barrier = self.manager.get_cache_info_barrier(self.local_data_parallel_id)
self.begin_send_cache_barrier = self.manager.get_begin_send_cache_barrier(self.local_data_parallel_id)
self.worker_process_tp_barrier = self.manager.get_worker_process_tp_barrier(self.local_data_parallel_id)
self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
self.finished_add_cache_task_queue = self.manager.get_finish_add_cache_task_queue(
self.finished_send_cache_list = self.manager.get_finish_request_queue(self.local_data_parallel_id)
self.finished_add_cache_task_list = self.manager.get_finish_add_cache_task_queue(
self.local_data_parallel_id
)
# p/d互联
self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses(
self.connect_rdma_tasks = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
self.client_get_connect_task_flag = self.manager.get_client_get_connect_task_flag(
self.local_data_parallel_id
)
self.client_get_connect_task_response_flag = self.manager.get_client_get_connect_task_response_flag(
self.local_data_parallel_id
)
self.client_get_finished_add_cache_task_flag = (
self.manager.get_client_get_finished_add_cache_task_flag_init(self.local_data_parallel_id)
)
self.client_get_finish_send_cache_flag = self.manager.get_client_get_finish_send_cache_flag_init(
self.local_data_parallel_id
)
self.connect_rdma_task_responses = self.manager.get_connect_rdma_tasks_responses(
self.local_data_parallel_id
)
self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id)
self.connect_task_response_lock = self.manager.get_connect_task_response_lock(self.local_data_parallel_id)
self.finish_add_cache_task_lock = self.manager.get_finish_add_cache_task_lock(self.local_data_parallel_id)
self.finish_send_cache_lock = self.manager.get_finish_send_cache_lock(self.local_data_parallel_id)
self.can_put_next_add_task_finished_flag = self.manager.get_can_put_next_add_task_finished_flag(
self.local_data_parallel_id
)
self.can_put_next_connect_task_response_flag = self.manager.get_can_put_next_connect_task_response_flag(
self.local_data_parallel_id
)
self.can_put_next_send_cache_finished_flag = self.manager.get_can_put_next_send_cache_finished_flag(
self.local_data_parallel_id
)
assert self.num_client == len(self.client_read_flag)
@@ -411,41 +582,61 @@ class EngineWorkerQueue:
def put_connect_rdma_task(self, connect_rdma_task):
self.connect_task_lock.acquire()
self.connect_rdma_task_queue.append(connect_rdma_task)
while sum(self.client_get_connect_task_flag) < self.num_client:
self.connect_task_lock.release()
time.sleep(0.001)
self.connect_task_lock.acquire()
self.connect_rdma_tasks[:] = list()
self.client_get_connect_task_flag[:] = [0] * self.num_client
self.connect_rdma_tasks.append(connect_rdma_task)
self.connect_task_lock.release()
def get_connect_rdma_task(self):
result = None
connect_rdma_task = None
self.connect_task_lock.acquire()
if len(self.connect_rdma_task_queue) == 0:
self.connect_task_lock.release()
return result
try:
result = self.connect_rdma_task_queue.pop(0)
except Exception as e:
llm_logger.info(f"get_connect_rdma_task got exception: {e}")
finally:
self.connect_task_lock.release()
return result
if len(self.connect_rdma_tasks) > 0:
connect_rdma_task = self.connect_rdma_tasks[0]
self.client_get_connect_task_flag[self.client_id] = 1
all_client_read: bool = np.sum(self.client_get_connect_task_flag) == self.num_client
if all_client_read:
self.connect_rdma_tasks[:] = list()
self.connect_task_lock.release()
return connect_rdma_task, all_client_read
def put_connect_rdma_task_response(self, connect_rdma_task_response):
self.connect_task_lock.acquire()
self.connect_rdma_task_response_queue.append(connect_rdma_task_response)
self.connect_task_lock.release()
self.connect_task_response_lock.acquire()
while not self.can_put_next_connect_task_response_flag.get():
self.connect_task_response_lock.release()
time.sleep(0.001)
self.connect_task_response_lock.acquire()
self.connect_rdma_task_responses.append(connect_rdma_task_response)
self.client_get_connect_task_response_flag[self.client_id] = 1
all_client_put: bool = np.sum(self.client_get_connect_task_response_flag) == self.num_client
if all_client_put:
self.can_put_next_connect_task_response_flag.set(0)
self.connect_task_response_lock.release()
return all_client_put
def get_connect_rdma_task_response(self):
result = None
self.connect_task_lock.acquire()
if len(self.connect_rdma_task_response_queue) == 0:
self.connect_task_lock.release()
return result
try:
result = self.connect_rdma_task_response_queue.pop(0)
except Exception as e:
llm_logger.info(f"get_connect_rdma_task_response got exception: {e}")
finally:
self.connect_task_lock.release()
return result
task_response = None
self.connect_task_response_lock.acquire()
if len(self.connect_rdma_task_responses) == 0:
self.connect_task_response_lock.release()
return task_response
while sum(self.client_get_connect_task_response_flag) < self.num_client:
self.connect_task_response_lock.release()
time.sleep(0.001)
self.connect_task_response_lock.acquire()
if len(self.connect_rdma_task_responses) > 0:
task_response = self.connect_rdma_task_responses[0]
for tmp_task_response in self.connect_rdma_task_responses:
task_response["success"] = task_response["success"] and tmp_task_response["success"]
self.connect_rdma_task_responses[:] = list()
self.client_get_connect_task_response_flag[:] = [0] * self.num_client
self.can_put_next_connect_task_response_flag.set(1)
self.connect_task_response_lock.release()
return task_response
def get_prefill_instances(self):
"""
@@ -508,14 +699,25 @@ class EngineWorkerQueue:
self.lock_info.release()
return total_num
def put_finished_req(self, req_ids) -> None:
def put_finished_req(self, send_cache_result) -> None:
"""
Put finished request ID into the queue.
Args:
req_ids: Request ID to be added to the queue
"""
self.finished_req_queue.put(req_ids)
self.finish_send_cache_lock.acquire()
while not self.can_put_next_send_cache_finished_flag.get():
self.finish_send_cache_lock.release()
time.sleep(0.001)
self.finish_send_cache_lock.acquire()
self.finished_send_cache_list.append(send_cache_result[0])
self.client_get_finish_send_cache_flag[self.client_id] = 1
all_client_put: bool = np.sum(self.client_get_finish_send_cache_flag) == self.num_client
if all_client_put:
self.can_put_next_send_cache_finished_flag.set(0)
self.finish_send_cache_lock.release()
return all_client_put
def get_finished_req(self) -> str:
"""
@@ -524,12 +726,27 @@ class EngineWorkerQueue:
Returns:
str: Finished request ID
"""
ans = []
if self.finished_req_queue.empty():
return ans
ans = self.finished_req_queue.get()
llm_logger.debug(f"get finished req: {ans}")
return ans
response = []
self.finish_send_cache_lock.acquire()
if len(self.finished_send_cache_list) == 0:
self.finish_send_cache_lock.release()
return response
while sum(self.client_get_finish_send_cache_flag) < self.num_client:
self.finish_send_cache_lock.release()
time.sleep(0.001)
self.finish_send_cache_lock.acquire()
if len(self.finished_send_cache_list) > 0:
response = self.finished_send_cache_list[0]
for tmp_response in self.finished_send_cache_list:
if "error" in tmp_response[1]:
response[1] = tmp_response[1]
if response:
response = [response]
self.finished_send_cache_list[:] = list()
self.client_get_finish_send_cache_flag[:] = [0] * self.num_client
self.can_put_next_send_cache_finished_flag.set(1)
self.finish_send_cache_lock.release()
return response
def put_finished_add_cache_task_req(self, req_ids) -> None:
"""
@@ -538,7 +755,18 @@ class EngineWorkerQueue:
Args:
req_ids: Request ID to be added to the queue
"""
self.finished_add_cache_task_queue.put(req_ids)
self.finish_add_cache_task_lock.acquire()
while not self.can_put_next_add_task_finished_flag.get():
self.finish_add_cache_task_lock.release()
time.sleep(0.001)
self.finish_add_cache_task_lock.acquire()
self.finished_add_cache_task_list.append(req_ids)
self.client_get_finished_add_cache_task_flag[self.client_id] = 1
all_client_put: bool = np.sum(self.client_get_finished_add_cache_task_flag) == self.num_client
if all_client_put:
self.can_put_next_add_task_finished_flag.set(0)
self.finish_add_cache_task_lock.release()
return all_client_put
def get_finished_add_cache_task_req(self) -> str:
"""
@@ -547,12 +775,24 @@ class EngineWorkerQueue:
Returns:
str: Finished request ID
"""
ans = []
if self.finished_add_cache_task_queue.empty():
return ans
ans = self.finished_add_cache_task_queue.get()
llm_logger.debug(f"get finished req: {ans}")
return ans
response = []
self.finish_add_cache_task_lock.acquire()
if len(self.finished_add_cache_task_list) == 0:
self.finish_add_cache_task_lock.release()
return response
while sum(self.client_get_finished_add_cache_task_flag) < self.num_client:
self.finish_add_cache_task_lock.release()
time.sleep(0.001)
self.finish_add_cache_task_lock.acquire()
if len(self.finished_add_cache_task_list) > 0:
response = self.finished_add_cache_task_list[0]
for tmp_response in self.finished_add_cache_task_list:
assert tmp_response == response
self.finished_add_cache_task_list[:] = list()
self.client_get_finished_add_cache_task_flag[:] = [0] * self.num_client
self.can_put_next_add_task_finished_flag.set(1)
self.finish_add_cache_task_lock.release()
return response
def disaggregate_queue_empty(self):
"""

View File

@@ -143,38 +143,7 @@ class DPLocalScheduler(LocalScheduler):
requests: List[Request] = []
with self.requests_not_empty:
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
while True:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
)
if batch_ids:
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(
request.prompt_tokens_ids_len, block_size
)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
requests.append(request.raw)
self.ids_read_cursor += 1
start_batch_time = time.time()
if current_prefill_tokens > max_num_batched_tokens:
break
if len(requests) >= batch:
break
if (
(current_prefill_tokens > max_num_batched_tokens)
or (len(requests) >= batch)
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
):
break
else:
required_total_blocks = 0
while True:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
@@ -183,11 +152,24 @@ class DPLocalScheduler(LocalScheduler):
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
requests.append(request.raw)
self.ids_read_cursor += 1
start_batch_time = time.time()
if current_prefill_tokens > max_num_batched_tokens:
break
if len(requests) >= batch:
break
if (
(current_prefill_tokens > max_num_batched_tokens)
or (len(requests) >= batch)
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
):
break
if batch_ids:
if len(batch_ids) > 0 and len(requests) == 0:

View File

@@ -78,7 +78,7 @@ class InternalAdapter:
if task is None:
time.sleep(0.001)
continue
logger.info(f"Recieve control task: {task}")
logger.info(f"dprank {self.dp_rank} Recieve control task: {task}")
task_id_str = task["task_id"]
if task["cmd"] == "get_payload":
payload_info = self._get_current_server_info()

View File

@@ -275,6 +275,7 @@ class SplitwiseConnector:
decode_diagg = task.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
task.disaggregate_info["role"] = "decode"
self._send_message(addr, "prefill", [task])
task.disaggregate_info["cache_info"] = decode_diagg
task.disaggregate_info["role"] = "prefill"

View File

@@ -177,7 +177,7 @@ class PaddleDisWorkerProc:
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
if self.parallel_config.data_parallel_size > 1 and not envs.FD_ENABLE_MULTI_API_SERVER:
launched_expert_service_signal_data = np.zeros(
shape=[min(self.parallel_config.data_parallel_size, self.max_chips_per_node)], dtype=np.int32
shape=[self.parallel_config.data_parallel_size // self.fd_config.nnode], dtype=np.int32
)
self.launched_expert_service_signal = IPCSignal(
name="launched_expert_service_signal",
@@ -186,7 +186,12 @@ class PaddleDisWorkerProc:
suffix=self.parallel_config.engine_pid,
create=False,
)
while self.launched_expert_service_signal.value[self.local_rank % self.max_chips_per_node] == 0:
while (
self.launched_expert_service_signal.value[
self.parallel_config.local_data_parallel_id % self.max_chips_per_node
]
== 0
):
pass
# init worker_ready_signal
@@ -568,7 +573,7 @@ class PaddleDisWorkerProc:
is_server=False,
num_client=self.parallel_config.tensor_parallel_size,
client_id=self.parallel_config.tensor_parallel_rank,
local_data_parallel_id=self.parallel_config.data_parallel_rank,
local_data_parallel_id=self.parallel_config.local_data_parallel_id,
)
def load_model(self) -> None: