Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[Optimize] Support and robust for tpN for PD (#4595)
* [Optimize] Support and robust for tpN for PD
* fix
* fix
* support dpM tpN for cache messager
* fix
* fix token counter
* fix bug for merge develop
* fix bug
* robust cache messager for v0
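The recurring idea in this change is that every tensor-parallel (TP) rank reports its own cache-transfer status, and a single aggregated result is published for the request: it is "finished" only if every rank finished, and any rank's error wins. A minimal, self-contained sketch of that aggregation idea (illustrative only; the function name and signature are not part of the diff below):

# Illustrative sketch, not code from this commit.
def aggregate_tp_status(request_id: str, per_rank_status: list[str]) -> tuple[str, str]:
    final = "finished"
    for status in per_rank_status:
        if "error" in status:
            final = status  # any rank's error overrides "finished"
            break
    return request_id, final

# Example: one rank hit a write-cache error, so the request is reported as failed.
print(aggregate_tp_status("req-0", ["finished", "write cache error"]))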
@@ -246,11 +246,10 @@ class CacheMessager:
engine_recycled_count = 0
while True:
cache_info = self.engine_worker_queue.get_cache_info()
if cache_info:
logger.debug(f"cache info {cache_info}")
self.engine_worker_queue.cache_info_barrier.wait()
for info in cache_info:
if info["request_id"] in self.cache_info:
self.cache_info[info["request_id"]].update(info)
@@ -295,9 +294,6 @@ class CacheMessager:
continue
if "layer_idx" not in item:
item["layer_idx"] = 0
if item["status"] == "error":
del self.cache_info[req_id]
continue
if item["current_id"] > prefilled_step_idx:
continue
current_transfer_protocol = item["transfer_protocol"]
@@ -307,11 +303,7 @@ class CacheMessager:
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
if not status:
logger.error(f"connect to {target_ip}:{target_id} failed")
item["status"] = "error"
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([(item["request_id"], "connect error")])
continue
item["status"] = "connect error"
elif item["transfer_protocol"] == "ipc":
target_ip = "0.0.0.0"
target_id = int(item["device_ids"][self.rank])
@@ -321,48 +313,43 @@ class CacheMessager:
current_layer_idx = self.num_layers
else:
current_layer_idx = prefilled_layer_idx + 1
for layer_idx in range(item["layer_idx"], current_layer_idx):
tic = time.time()
return_code = self.messager[current_transfer_protocol].write_cache(
target_ip,
target_id,
src_block_ids,
dest_block_ids,
layer_idx,
)
if return_code != 0:
item["status"] = "error"
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
logger.error(
f"write cache failed, layer_idx: {layer_idx}, "
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
if "error" not in item["status"]:
for layer_idx in range(item["layer_idx"], current_layer_idx):
tic = time.time()
return_code = self.messager[current_transfer_protocol].write_cache(
target_ip,
target_id,
src_block_ids,
dest_block_ids,
layer_idx,
)
break
if return_code != 0:
item["status"] = "write cache error"
logger.error(
f"write cache failed, layer_idx: {layer_idx}, "
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
)
break
tok = time.time()
cost_time = tok - tic
block_num = len(src_block_ids)
avg_time_per_block = cost_time * 1000 / block_num # ms
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
logger.debug(
f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
f" {current_transfer_protocol}"
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
)
tok = time.time()
cost_time = tok - tic
block_num = len(src_block_ids)
avg_time_per_block = cost_time * 1000 / block_num # ms
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
logger.debug(
f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
f" {current_transfer_protocol}"
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
)
item["layer_idx"] = current_layer_idx
if item["layer_idx"] == self.num_layers:
if item["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
logger.info(f"finish write cache {item['request_id']}")
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
# to do: robust in TP: here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
logger.info(f"put write cache {item['request_id']}")
self.engine_worker_queue.finish_send_cache_barrier.wait()
self.engine_worker_queue.put_finished_req([[item["request_id"], item["status"]]])
logger.info(f"put write cache {item['request_id']}, status {item['status']}")
del self.cache_info[req_id]
self.last_layer_idx = prefilled_layer_idx
@@ -376,14 +363,17 @@ class CacheMessager:
if task is None:
time.sleep(0.001)
continue
else:
self.engine_worker_queue.connect_task_barrier.wait()
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip, rdma_port = task["ip"], task["rdma_port"]
ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
status = self.messager["rdma"].connect(ip, rdma_port)
if not status:
response = {"task_id": task_id, "success": False}
else:
response = {"task_id": task_id, "success": True}
self.engine_worker_queue.connect_task_response_barrier.wait()
self.engine_worker_queue.put_connect_rdma_task_response(response)
except Exception as e:
logger.error(f"handle_connect_task has exception: {e}")
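The `*_barrier.wait()` calls added in the hunks above keep the TP ranks in lock-step: every rank must reach the same barrier before any of them proceeds, so a task taken from the shared queue is handled exactly once per rank. A minimal sketch of that coordination with plain threads (the manager-backed barriers in the diff behave the same way; names and rank count here are illustrative):

# Illustrative sketch, not code from this commit.
import threading

NUM_TP_RANKS = 2
connect_task_barrier = threading.Barrier(NUM_TP_RANKS)

def handle_connect_task(rank: int) -> None:
    # Every rank waits here, so all ranks observe the task before any of them connects.
    connect_task_barrier.wait()
    print(f"rank {rank} connects its own RDMA port")

threads = [threading.Thread(target=handle_connect_task, args=(r,)) for r in range(NUM_TP_RANKS)]
for t in threads:
    t.start()
for t in threads:
    t.join()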
@@ -524,9 +514,9 @@ class CacheMessagerV1:
while True:
try:
cache_info = self.engine_worker_queue.get_cache_info()
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
finished_add_cache_task_req_ids = []
if cache_info:
self.engine_worker_queue.cache_info_barrier.wait()
for info in cache_info:
if info["request_id"] in self.cache_info:
self.cache_info[info["request_id"]].update(info)
@@ -544,13 +534,16 @@ class CacheMessagerV1:
current_info["sended_layer_id"] = -1
current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
current_info["status"] = "init"
logger.info(f"finish add cache task: {current_info}")
logger.info(f"Get cache info from P: finish add cache task: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.idx_cache_task_dict[current_info["current_id"]] = current_info
else:
logger.info(f"Get cache info from D: {info}")
self.cache_info[info["request_id"]] = info
if self.rank == 0 and finished_add_cache_task_req_ids:
if finished_add_cache_task_req_ids:
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
else:
time.sleep(0.001)
except Exception as e:
@@ -563,14 +556,16 @@ class CacheMessagerV1:
"""
while True:
try:
engine_indexes = self.cache_prefilled_engine_ids_queue.get()
self.engine_worker_queue.finish_request_barrier.wait()
batch_engine_signals = self.cache_prefilled_engine_ids_queue.get()
self.engine_worker_queue.begin_send_cache_barrier.wait()
block_start_end_list = []
current_prefilled_token_num_list = []
for engine_index in engine_indexes:
assert engine_index in self.idx_cache_task_dict
for engine_index, current_step_prefilled_token_num in batch_engine_signals:
assert (
engine_index in self.idx_cache_task_dict
), f"engine_index {engine_index} not in self.idx_cache_task_dict {self.idx_cache_task_dict}"
block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
prefilled_token_num = current_step_prefilled_token_num
if (
prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
): # all chunks have been prefilled
@@ -580,17 +575,20 @@ class CacheMessagerV1:
block_start_end_list.append((block_id_start, block_id_end))
current_prefilled_token_num_list.append(prefilled_token_num)
while True: # from layer0 to last layer
sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"]
start_layer_idx = sended_layer_idx + 1
with self.engine_cache_task_thread_lock: # to check end_layer_idx
prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
prefilled_layer_idx = self.engine_cache_tasks[batch_engine_signals[0][0]][
"prefilled_layer_idx"
]
if sended_layer_idx > prefilled_layer_idx: # computation must in next chunk
logger.info(
f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[batch_engine_signals[0][0]]['prefilled_token_num']}"
)
assert (
current_prefilled_token_num_list[0]
< self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
< self.engine_cache_tasks[batch_engine_signals[0][0]]["prefilled_token_num"]
), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
end_layer_idx = self.num_layers - 1 # [start_layer_idx, end_layer_idx)
else:
@@ -599,7 +597,7 @@ class CacheMessagerV1:
time.sleep(0.01)
for layer_idx in range(start_layer_idx, end_layer_idx + 1):
for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
engine_index = engine_indexes[i]
engine_index = batch_engine_signals[i][0]
task = self.idx_cache_task_dict[engine_index]
req_id = task["request_id"]
if (
@@ -615,7 +613,7 @@ class CacheMessagerV1:
if task["transfer_protocol"] == "rdma":
target_ip = task["ip"]
target_id = int(task["rdma_ports"][self.rank])
if task["status"] == "error":
if "error" in task["status"]:
continue
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
if not status:
@@ -665,7 +663,7 @@ class CacheMessagerV1:
block_id_end - block_id_start
)
if current_prefilled_token_num_list[i] == task["need_prefill_tokens"]:
if task["status"] != "error":
if "error" not in task["status"]:
task["status"] = "finished"
logger.info(
f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
@@ -674,18 +672,15 @@ class CacheMessagerV1:
task["sended_layer_id"] = -1
if end_layer_idx == self.num_layers - 1:
with self.engine_cache_task_thread_lock:
for engine_idx in engine_indexes:
for engine_idx, _ in batch_engine_signals:
task = self.idx_cache_task_dict[engine_idx]
if task["status"] == "finished" or ("error" in task["status"]):
target_id = int(task["rdma_ports"][self.rank])
if task["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
if self.rank == 0:
# to do: robust in TP, here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
self.engine_worker_queue.put_finished_req(
[(task["request_id"], task["status"])]
)
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
self.engine_worker_queue.finish_send_cache_barrier.wait()
self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]])
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
self.engine_cache_tasks[task["current_id"]] = dict()
del self.cache_info[task["request_id"]]
del self.idx_cache_task_dict[task["current_id"]]
@@ -709,8 +704,9 @@ class CacheMessagerV1:
continue
layer_id = kv_signal_data[1].numpy().tolist()
if layer_id == self.num_layers - 1:
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
batch_engine_ids = []
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id} self.rank_id {self.rank_id}")
batch_engine_signals = []
# format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)]
with self.engine_cache_task_thread_lock:
for bi in range(tasks_count):
engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
@@ -720,27 +716,33 @@ class CacheMessagerV1:
self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
chuck_token_offset + current_seq_len
)
batch_engine_ids.append(engine_idx)
batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len))
if layer_id == 0:
self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
logger.info(
f"Put batch_engine_signals {batch_engine_signals} into cache_prefilled_engine_ids_queue"
)
self.cache_prefilled_engine_ids_queue.put(batch_engine_signals)
except Exception as e:
logger.error(f"Consume signals get exception: {e}")
def _handle_connect_task(self):
while True:
try:
task = self.engine_worker_queue.get_connect_rdma_task()
task, _ = self.engine_worker_queue.get_connect_rdma_task()
if task is None:
time.sleep(0.001)
continue
else:
self.engine_worker_queue.connect_task_barrier.wait()
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip, rdma_port = task["ip"], task["rdma_port"]
ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
status = self.messager["rdma"].connect(ip, rdma_port)
if not status:
response = {"task_id": task_id, "success": False}
else:
response = {"task_id": task_id, "success": True}
self.engine_worker_queue.connect_task_response_barrier.wait()
self.engine_worker_queue.put_connect_rdma_task_response(response)
except Exception as e:
logger.error(f"handle_connect_task has exception: {e}")
@@ -310,11 +310,7 @@ class EngineService:
num_client=self.cfg.parallel_config.tensor_parallel_size,
client_id=0,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
local_data_parallel_id=min(
self.cfg.worker_num_per_node // self.cfg.parallel_config.tensor_parallel_size * self.cfg.node_rank
+ self.cfg.parallel_config.local_data_parallel_id,
self.cfg.parallel_config.data_parallel_size - 1,
),
local_data_parallel_id=self.cfg.parallel_config.local_data_parallel_id,
)
def insert_tasks(self, tasks, current_id=-1, allocated=False):
@@ -656,39 +652,60 @@ class EngineService:
self.cfg.max_prefill_batch,
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens
else:
max_num_batched_tokens = self.cfg.model_config.max_model_len
tasks = self.scheduler.get_requests(
available_blocks=self.cfg.cache_config.max_block_num_per_seq,
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
max_num_batched_tokens=self.cfg.model_config.max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
batch=num_prefill_batch,
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
self.split_connector.send_splitwise_tasks([task], task.idx)
need_delete_tasks = []
for task in tasks:
if self.cfg.scheduler_config.splitwise_role != "mixed":
# assure fetch block ids from D
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
while True:
self.split_connector.send_splitwise_tasks([task], task.idx)
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} ask D resource failed, try again.")
time.sleep(0.05)
else:
break
else:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
self.split_connector.send_splitwise_tasks([task], task.idx)
for task in tasks:
if self.cfg.scheduler_config.splitwise_role != "mixed":
# assure fetch block ids from D
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
# release resource in P
@@ -930,7 +947,7 @@ class EngineService:
for request_id, contents in results.items():
new_contents = []
for content in contents:
if isinstance(content, RequestOutput):
if isinstance(content, RequestOutput) and content.outputs is not None:
decode_type = content.outputs.decode_type
delta_text = ""
if decode_type == 0:
@@ -1035,6 +1052,7 @@ class EngineService:
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
self.token_processor.tokens_counter[task.request_id] = 1
self.resource_manager.insert_task_for_decoding(task)
else:
@@ -27,6 +27,7 @@ import numpy as np
from fastdeploy.engine.common_engine import EngineService
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.utils import console_logger, envs, llm_logger
@@ -99,6 +100,10 @@ class ExpertService:
self.engine.start_zmq_service(ipc_signal_suffix)
else:
ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.internal_adapter = InternalAdapter(
cfg=self.cfg, engine=self.engine, dp_rank=self.cfg.parallel_config.local_data_parallel_id
)
llm_logger.info(f"start expert service {local_data_parallel_id}")
@@ -151,6 +151,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_MOE_QUANT_TYPE": lambda: os.getenv("FD_MOE_QUANT_TYPE", "w4a8"),
"ENCODE_FEATURE_BOS_AK": lambda: os.getenv("ENCODE_FEATURE_BOS_AK"),
"ENCODE_FEATURE_BOS_SK": lambda: os.getenv("ENCODE_FEATURE_BOS_SK"),
# Enable offline perf test mode for PD disaggregation
"FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
}
@@ -80,24 +80,79 @@ class EngineWorkerQueue:
self.client_read_flag_init: List[List[int]] = [
[1] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.lock_init: List[threading.Lock] = [threading.Lock() for _ in range(self.local_data_parallel_size)]
self.read_finish_flag_init: List[Value] = [Value("i", 0) for _ in range(self.local_data_parallel_size)]
self.connected_client_counter_init: List[Value] = [
Value("i", 0) for _ in range(self.local_data_parallel_size)
]
self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
self.finished_add_cache_task_queue = [Queue() for _ in range(self.local_data_parallel_size)]
self.finished_req_list = [list() for _ in range(self.local_data_parallel_size)]
self.finished_add_cache_task_list = [list() for _ in range(self.local_data_parallel_size)]
self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
self.client_read_info_flag_init: List[List[int]] = [
[1] * self.num_client for _ in range(self.local_data_parallel_size)
[0] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.lock_info_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
]
# PD disaggregation
# Locks
self.connect_task_lock_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
] # connect rdma task
self.connect_task_response_lock_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
] # connect rdma task response
self.finish_add_cache_task_lock_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
] # finish add cache task
self.finish_send_cache_lock_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
] # finish send cache
# sync read status for TPs
self.client_get_connect_task_flag_init: List[List[int]] = [
[0] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.client_get_connect_task_response_flag_init: List[List[int]] = [
[0] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.client_get_finished_add_cache_task_flag_init: List[List[int]] = [
[0] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.client_get_finish_send_cache_flag_init: List[List[int]] = [
[0] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.can_put_next_connect_task_response_flag_init: List[Value] = [
Value("i", 1) for _ in range(self.local_data_parallel_size)
]
self.can_put_next_add_task_finished_flag_init: List[Value] = [
Value("i", 1) for _ in range(self.local_data_parallel_size)
]
self.can_put_next_send_cache_finished_flag_init: List[Value] = [
Value("i", 1) for _ in range(self.local_data_parallel_size)
]
# barrier
self.get_connect_task_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.get_connect_task_response_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.finish_add_cache_task_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.begin_send_cache_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.finish_send_cache_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.get_cache_info_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.finish_request_barrier = [
@@ -107,10 +162,6 @@ class EngineWorkerQueue:
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.finish_add_cache_task_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
# Register shared objects with proxy types
QueueManager.register(
"get_tasks",
@@ -122,6 +173,26 @@ class EngineWorkerQueue:
callable=lambda idx: self.client_read_flag_init[idx],
proxytype=ListProxy,
)
QueueManager.register(
"get_client_get_connect_task_flag",
callable=lambda idx: self.client_get_connect_task_flag_init[idx],
proxytype=ListProxy,
)
QueueManager.register(
"get_client_get_connect_task_response_flag",
callable=lambda idx: self.client_get_connect_task_response_flag_init[idx],
proxytype=ListProxy,
)
QueueManager.register(
"get_client_get_finished_add_cache_task_flag_init",
callable=lambda idx: self.client_get_finished_add_cache_task_flag_init[idx],
proxytype=ListProxy,
)
QueueManager.register(
"get_client_get_finish_send_cache_flag_init",
callable=lambda idx: self.client_get_finish_send_cache_flag_init[idx],
proxytype=ListProxy,
)
QueueManager.register(
"get_lock",
callable=lambda idx: self.lock_init[idx],
@@ -132,11 +203,43 @@ class EngineWorkerQueue:
callable=lambda idx: self.read_finish_flag_init[idx],
proxytype=ValueProxy,
)
QueueManager.register(
"get_can_put_next_connect_task_response_flag",
callable=lambda idx: self.can_put_next_connect_task_response_flag_init[idx],
proxytype=ValueProxy,
)
QueueManager.register(
"get_can_put_next_add_task_finished_flag",
callable=lambda idx: self.can_put_next_add_task_finished_flag_init[idx],
proxytype=ValueProxy,
)
QueueManager.register(
"get_can_put_next_send_cache_finished_flag",
callable=lambda idx: self.can_put_next_send_cache_finished_flag_init[idx],
proxytype=ValueProxy,
)
# PD disaggregation
QueueManager.register(
"get_connect_task_lock",
callable=lambda idx: self.connect_task_lock_init[idx],
proxytype=AcquirerProxy,
)
QueueManager.register(
"get_connect_task_response_lock",
callable=lambda idx: self.connect_task_response_lock_init[idx],
proxytype=AcquirerProxy,
)
QueueManager.register(
"get_finish_add_cache_task_lock",
callable=lambda idx: self.finish_add_cache_task_lock_init[idx],
proxytype=AcquirerProxy,
)
QueueManager.register(
"get_finish_send_cache_lock",
callable=lambda idx: self.finish_send_cache_lock_init[idx],
proxytype=AcquirerProxy,
)
QueueManager.register(
"get_connect_rdma_tasks", callable=lambda idx: self.connect_rdma_tasks_list[idx], proxytype=ListProxy
)
@@ -152,13 +255,13 @@ class EngineWorkerQueue:
)
QueueManager.register(
"get_finish_request_queue",
callable=lambda idx: self.finished_req_queue[idx],
"get_finish_request_queue", callable=lambda idx: self.finished_req_list[idx], proxytype=ListProxy
)
QueueManager.register(
"get_finish_add_cache_task_queue",
callable=lambda idx: self.finished_add_cache_task_queue[idx],
callable=lambda idx: self.finished_add_cache_task_list[idx],
proxytype=ListProxy,
)
QueueManager.register(
@@ -194,6 +297,26 @@ class EngineWorkerQueue:
"get_finish_request_barrier",
callable=lambda idx: self.finish_request_barrier[idx],
)
QueueManager.register(
"get_connect_task_barrier",
callable=lambda idx: self.get_connect_task_barrier[idx],
)
QueueManager.register(
"get_connect_task_response_barrier",
callable=lambda idx: self.get_connect_task_response_barrier[idx],
)
QueueManager.register(
"get_begin_send_cache_barrier",
callable=lambda idx: self.begin_send_cache_barrier[idx],
)
QueueManager.register(
"get_finish_send_cache_barrier",
callable=lambda idx: self.finish_send_cache_barrier[idx],
)
QueueManager.register(
"get_cache_info_barrier",
callable=lambda idx: self.get_cache_info_barrier[idx],
)
QueueManager.register(
"get_finish_add_cache_task_barrier",
@@ -231,10 +354,25 @@ class EngineWorkerQueue:
QueueManager.register("get_available_prefill_instances")
QueueManager.register("get_finish_request_barrier")
QueueManager.register("get_finish_add_cache_task_barrier")
QueueManager.register("get_connect_task_barrier")
QueueManager.register("get_connect_task_response_barrier")
QueueManager.register("get_finish_send_cache_barrier")
QueueManager.register("get_begin_send_cache_barrier")
QueueManager.register("get_cache_info_barrier")
QueueManager.register("get_connect_rdma_tasks")
QueueManager.register("get_client_get_connect_task_flag")
QueueManager.register("get_client_get_connect_task_response_flag")
QueueManager.register("get_client_get_finished_add_cache_task_flag_init")
QueueManager.register("get_client_get_finish_send_cache_flag_init")
QueueManager.register("get_connect_rdma_tasks_responses")
QueueManager.register("get_connect_task_lock")
QueueManager.register("get_connect_task_response_lock")
QueueManager.register("get_finish_add_cache_task_lock")
QueueManager.register("get_finish_send_cache_lock")
QueueManager.register("get_worker_process_tp_barrier")
QueueManager.register("get_can_put_next_connect_task_response_flag")
QueueManager.register("get_can_put_next_add_task_finished_flag")
QueueManager.register("get_can_put_next_send_cache_finished_flag")
self.manager = QueueManager(address=self.address, authkey=self.authkey)
self._connect_with_retry()
@@ -257,17 +395,50 @@ class EngineWorkerQueue:
self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier(
self.local_data_parallel_id
)
self.connect_task_barrier = self.manager.get_connect_task_barrier(self.local_data_parallel_id)
self.connect_task_response_barrier = self.manager.get_connect_task_response_barrier(
self.local_data_parallel_id
)
self.finish_send_cache_barrier = self.manager.get_finish_send_cache_barrier(self.local_data_parallel_id)
self.cache_info_barrier = self.manager.get_cache_info_barrier(self.local_data_parallel_id)
self.begin_send_cache_barrier = self.manager.get_begin_send_cache_barrier(self.local_data_parallel_id)
self.worker_process_tp_barrier = self.manager.get_worker_process_tp_barrier(self.local_data_parallel_id)
self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
self.finished_add_cache_task_queue = self.manager.get_finish_add_cache_task_queue(
self.finished_send_cache_list = self.manager.get_finish_request_queue(self.local_data_parallel_id)
self.finished_add_cache_task_list = self.manager.get_finish_add_cache_task_queue(
self.local_data_parallel_id
)
# P/D interconnect
self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses(
self.connect_rdma_tasks = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
self.client_get_connect_task_flag = self.manager.get_client_get_connect_task_flag(
self.local_data_parallel_id
)
self.client_get_connect_task_response_flag = self.manager.get_client_get_connect_task_response_flag(
self.local_data_parallel_id
)
self.client_get_finished_add_cache_task_flag = (
self.manager.get_client_get_finished_add_cache_task_flag_init(self.local_data_parallel_id)
)
self.client_get_finish_send_cache_flag = self.manager.get_client_get_finish_send_cache_flag_init(
self.local_data_parallel_id
)
self.connect_rdma_task_responses = self.manager.get_connect_rdma_tasks_responses(
self.local_data_parallel_id
)
self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id)
self.connect_task_response_lock = self.manager.get_connect_task_response_lock(self.local_data_parallel_id)
self.finish_add_cache_task_lock = self.manager.get_finish_add_cache_task_lock(self.local_data_parallel_id)
self.finish_send_cache_lock = self.manager.get_finish_send_cache_lock(self.local_data_parallel_id)
self.can_put_next_add_task_finished_flag = self.manager.get_can_put_next_add_task_finished_flag(
self.local_data_parallel_id
)
self.can_put_next_connect_task_response_flag = self.manager.get_can_put_next_connect_task_response_flag(
self.local_data_parallel_id
)
self.can_put_next_send_cache_finished_flag = self.manager.get_can_put_next_send_cache_finished_flag(
self.local_data_parallel_id
)
assert self.num_client == len(self.client_read_flag)
@@ -411,41 +582,61 @@ class EngineWorkerQueue:
def put_connect_rdma_task(self, connect_rdma_task):
self.connect_task_lock.acquire()
self.connect_rdma_task_queue.append(connect_rdma_task)
while sum(self.client_get_connect_task_flag) < self.num_client:
self.connect_task_lock.release()
time.sleep(0.001)
self.connect_task_lock.acquire()
self.connect_rdma_tasks[:] = list()
self.client_get_connect_task_flag[:] = [0] * self.num_client
self.connect_rdma_tasks.append(connect_rdma_task)
self.connect_task_lock.release()
def get_connect_rdma_task(self):
result = None
connect_rdma_task = None
self.connect_task_lock.acquire()
if len(self.connect_rdma_task_queue) == 0:
self.connect_task_lock.release()
return result
try:
result = self.connect_rdma_task_queue.pop(0)
except Exception as e:
llm_logger.info(f"get_connect_rdma_task got exception: {e}")
finally:
self.connect_task_lock.release()
return result
if len(self.connect_rdma_tasks) > 0:
connect_rdma_task = self.connect_rdma_tasks[0]
self.client_get_connect_task_flag[self.client_id] = 1
all_client_read: bool = np.sum(self.client_get_connect_task_flag) == self.num_client
if all_client_read:
self.connect_rdma_tasks[:] = list()
self.connect_task_lock.release()
return connect_rdma_task, all_client_read
def put_connect_rdma_task_response(self, connect_rdma_task_response):
self.connect_task_lock.acquire()
self.connect_rdma_task_response_queue.append(connect_rdma_task_response)
self.connect_task_lock.release()
self.connect_task_response_lock.acquire()
while not self.can_put_next_connect_task_response_flag.get():
self.connect_task_response_lock.release()
time.sleep(0.001)
self.connect_task_response_lock.acquire()
self.connect_rdma_task_responses.append(connect_rdma_task_response)
self.client_get_connect_task_response_flag[self.client_id] = 1
all_client_put: bool = np.sum(self.client_get_connect_task_response_flag) == self.num_client
if all_client_put:
self.can_put_next_connect_task_response_flag.set(0)
self.connect_task_response_lock.release()
return all_client_put
def get_connect_rdma_task_response(self):
result = None
self.connect_task_lock.acquire()
if len(self.connect_rdma_task_response_queue) == 0:
self.connect_task_lock.release()
return result
try:
result = self.connect_rdma_task_response_queue.pop(0)
except Exception as e:
llm_logger.info(f"get_connect_rdma_task_response got exception: {e}")
finally:
self.connect_task_lock.release()
return result
task_response = None
self.connect_task_response_lock.acquire()
if len(self.connect_rdma_task_responses) == 0:
self.connect_task_response_lock.release()
return task_response
while sum(self.client_get_connect_task_response_flag) < self.num_client:
self.connect_task_response_lock.release()
time.sleep(0.001)
self.connect_task_response_lock.acquire()
if len(self.connect_rdma_task_responses) > 0:
task_response = self.connect_rdma_task_responses[0]
for tmp_task_response in self.connect_rdma_task_responses:
task_response["success"] = task_response["success"] and tmp_task_response["success"]
self.connect_rdma_task_responses[:] = list()
self.client_get_connect_task_response_flag[:] = [0] * self.num_client
self.can_put_next_connect_task_response_flag.set(1)
self.connect_task_response_lock.release()
return task_response
def get_prefill_instances(self):
"""
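The put/get helpers above replace per-rank Queues with a shared list plus per-client read flags, so that every TP client observes the same item before the slot is reused. A simplified, single-process sketch of the "every client reads once, the last reader clears the slot" idea (hypothetical names, no manager proxies, and deliberately smaller than the real protocol):

# Illustrative sketch, not code from this commit.
import threading

NUM_CLIENTS = 2
lock = threading.Lock()
shared_tasks: list = []
client_read_flag = [0] * NUM_CLIENTS  # 1 means this client has consumed the current item

def put_task(task: dict) -> None:
    with lock:
        shared_tasks.append(task)

def get_task(client_id: int):
    with lock:
        if not shared_tasks:
            return None
        task = shared_tasks[0]
        client_read_flag[client_id] = 1
        if sum(client_read_flag) == NUM_CLIENTS:  # last reader clears the slot
            shared_tasks.clear()
            client_read_flag[:] = [0] * NUM_CLIENTS
        return task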
@@ -508,14 +699,25 @@ class EngineWorkerQueue:
self.lock_info.release()
return total_num
def put_finished_req(self, req_ids) -> None:
def put_finished_req(self, send_cache_result) -> None:
"""
Put finished request ID into the queue.
Args:
req_ids: Request ID to be added to the queue
"""
self.finished_req_queue.put(req_ids)
self.finish_send_cache_lock.acquire()
while not self.can_put_next_send_cache_finished_flag.get():
self.finish_send_cache_lock.release()
time.sleep(0.001)
self.finish_send_cache_lock.acquire()
self.finished_send_cache_list.append(send_cache_result[0])
self.client_get_finish_send_cache_flag[self.client_id] = 1
all_client_put: bool = np.sum(self.client_get_finish_send_cache_flag) == self.num_client
if all_client_put:
self.can_put_next_send_cache_finished_flag.set(0)
self.finish_send_cache_lock.release()
return all_client_put
def get_finished_req(self) -> str:
"""
@@ -524,12 +726,27 @@ class EngineWorkerQueue:
Returns:
str: Finished request ID
"""
ans = []
if self.finished_req_queue.empty():
return ans
ans = self.finished_req_queue.get()
llm_logger.debug(f"get finished req: {ans}")
return ans
response = []
self.finish_send_cache_lock.acquire()
if len(self.finished_send_cache_list) == 0:
self.finish_send_cache_lock.release()
return response
while sum(self.client_get_finish_send_cache_flag) < self.num_client:
self.finish_send_cache_lock.release()
time.sleep(0.001)
self.finish_send_cache_lock.acquire()
if len(self.finished_send_cache_list) > 0:
response = self.finished_send_cache_list[0]
for tmp_response in self.finished_send_cache_list:
if "error" in tmp_response[1]:
response[1] = tmp_response[1]
if response:
response = [response]
self.finished_send_cache_list[:] = list()
self.client_get_finish_send_cache_flag[:] = [0] * self.num_client
self.can_put_next_send_cache_finished_flag.set(1)
self.finish_send_cache_lock.release()
return response
def put_finished_add_cache_task_req(self, req_ids) -> None:
"""
@@ -538,7 +755,18 @@ class EngineWorkerQueue:
Args:
req_ids: Request ID to be added to the queue
"""
self.finished_add_cache_task_queue.put(req_ids)
self.finish_add_cache_task_lock.acquire()
while not self.can_put_next_add_task_finished_flag.get():
self.finish_add_cache_task_lock.release()
time.sleep(0.001)
self.finish_add_cache_task_lock.acquire()
self.finished_add_cache_task_list.append(req_ids)
self.client_get_finished_add_cache_task_flag[self.client_id] = 1
all_client_put: bool = np.sum(self.client_get_finished_add_cache_task_flag) == self.num_client
if all_client_put:
self.can_put_next_add_task_finished_flag.set(0)
self.finish_add_cache_task_lock.release()
return all_client_put
def get_finished_add_cache_task_req(self) -> str:
"""
@@ -547,12 +775,24 @@ class EngineWorkerQueue:
Returns:
str: Finished request ID
"""
ans = []
if self.finished_add_cache_task_queue.empty():
return ans
ans = self.finished_add_cache_task_queue.get()
llm_logger.debug(f"get finished req: {ans}")
return ans
response = []
self.finish_add_cache_task_lock.acquire()
if len(self.finished_add_cache_task_list) == 0:
self.finish_add_cache_task_lock.release()
return response
while sum(self.client_get_finished_add_cache_task_flag) < self.num_client:
self.finish_add_cache_task_lock.release()
time.sleep(0.001)
self.finish_add_cache_task_lock.acquire()
if len(self.finished_add_cache_task_list) > 0:
response = self.finished_add_cache_task_list[0]
for tmp_response in self.finished_add_cache_task_list:
assert tmp_response == response
self.finished_add_cache_task_list[:] = list()
self.client_get_finished_add_cache_task_flag[:] = [0] * self.num_client
self.can_put_next_add_task_finished_flag.set(1)
self.finish_add_cache_task_lock.release()
return response
def disaggregate_queue_empty(self):
"""
@@ -143,38 +143,7 @@ class DPLocalScheduler(LocalScheduler):
requests: List[Request] = []
with self.requests_not_empty:
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
while True:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
)
if batch_ids:
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(
request.prompt_tokens_ids_len, block_size
)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
requests.append(request.raw)
self.ids_read_cursor += 1
start_batch_time = time.time()
if current_prefill_tokens > max_num_batched_tokens:
break
if len(requests) >= batch:
break
if (
(current_prefill_tokens > max_num_batched_tokens)
or (len(requests) >= batch)
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
):
break
else:
required_total_blocks = 0
while True:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
@@ -183,11 +152,24 @@ class DPLocalScheduler(LocalScheduler):
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
requests.append(request.raw)
self.ids_read_cursor += 1
start_batch_time = time.time()
if current_prefill_tokens > max_num_batched_tokens:
break
if len(requests) >= batch:
break
if (
(current_prefill_tokens > max_num_batched_tokens)
or (len(requests) >= batch)
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
):
break
if batch_ids:
if len(batch_ids) > 0 and len(requests) == 0:
@@ -78,7 +78,7 @@ class InternalAdapter:
if task is None:
time.sleep(0.001)
continue
logger.info(f"Recieve control task: {task}")
logger.info(f"dprank {self.dp_rank} Recieve control task: {task}")
task_id_str = task["task_id"]
if task["cmd"] == "get_payload":
payload_info = self._get_current_server_info()
@@ -275,6 +275,7 @@ class SplitwiseConnector:
decode_diagg = task.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
task.disaggregate_info["role"] = "decode"
self._send_message(addr, "prefill", [task])
task.disaggregate_info["cache_info"] = decode_diagg
task.disaggregate_info["role"] = "prefill"
@@ -177,7 +177,7 @@ class PaddleDisWorkerProc:
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
if self.parallel_config.data_parallel_size > 1 and not envs.FD_ENABLE_MULTI_API_SERVER:
launched_expert_service_signal_data = np.zeros(
shape=[min(self.parallel_config.data_parallel_size, self.max_chips_per_node)], dtype=np.int32
shape=[self.parallel_config.data_parallel_size // self.fd_config.nnode], dtype=np.int32
)
self.launched_expert_service_signal = IPCSignal(
name="launched_expert_service_signal",
@@ -186,7 +186,12 @@ class PaddleDisWorkerProc:
suffix=self.parallel_config.engine_pid,
create=False,
)
while self.launched_expert_service_signal.value[self.local_rank % self.max_chips_per_node] == 0:
while (
self.launched_expert_service_signal.value[
self.parallel_config.local_data_parallel_id % self.max_chips_per_node
]
== 0
):
pass
# init worker_ready_signal
@@ -568,7 +573,7 @@ class PaddleDisWorkerProc:
is_server=False,
num_client=self.parallel_config.tensor_parallel_size,
client_id=self.parallel_config.tensor_parallel_rank,
local_data_parallel_id=self.parallel_config.data_parallel_rank,
local_data_parallel_id=self.parallel_config.local_data_parallel_id,
)
def load_model(self) -> None: