mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[PD Disaggregation] Unify the disaggregation info and the pd communication (#5438)
* Unify the disaggregation info and the pd communication * up * up * fix * fix conflict * fix unittest
This commit is contained in:
@@ -96,6 +96,19 @@ def parse_args():
|
||||
return args
|
||||
|
||||
|
||||
def get_decode_ip_idx(task):
|
||||
"""For compatibility, get decode ip and idx from task"""
|
||||
if "decode_ip" in task:
|
||||
decode_ip = task["decode_ip"]
|
||||
else:
|
||||
decode_ip = task["ip"]
|
||||
if "decode_rdma_ports" in task:
|
||||
decode_rdma_ports = task["decode_rdma_ports"]
|
||||
else:
|
||||
decode_rdma_ports = task["rdma_ports"]
|
||||
return decode_ip, decode_rdma_ports
|
||||
|
||||
|
||||
class CacheMessager:
|
||||
"""
|
||||
CacheMessager is used to send the cache data between the engine worker and the cache server.
|
||||
@@ -282,6 +295,7 @@ class CacheMessager:
|
||||
self.cache_info[info["request_id"]] = current_info
|
||||
else:
|
||||
self.cache_info[info["request_id"]] = info
|
||||
|
||||
prefilled_layer_idx = layer_shm_value.value[0]
|
||||
prefilled_step_idx = step_shm_value.value[0]
|
||||
if prefilled_layer_idx == self.num_layers - 1:
|
||||
@@ -316,15 +330,18 @@ class CacheMessager:
|
||||
continue
|
||||
current_transfer_protocol = item["transfer_protocol"]
|
||||
if item["transfer_protocol"] == "rdma":
|
||||
target_ip = item["ip"]
|
||||
target_id = int(item["rdma_ports"][self.rank])
|
||||
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
|
||||
decode_ip, decode_rdma_ports = get_decode_ip_idx(item)
|
||||
decode_idx = int(decode_rdma_ports[self.rank])
|
||||
status = self.messager[current_transfer_protocol].connect(decode_ip, decode_idx)
|
||||
if not status:
|
||||
logger.error(f"connect to {target_ip}:{target_id} failed")
|
||||
logger.error(f"connect to {decode_ip}:{decode_idx} failed")
|
||||
item["status"] = "connect error"
|
||||
elif item["transfer_protocol"] == "ipc":
|
||||
target_ip = "0.0.0.0"
|
||||
target_id = int(item["device_ids"][self.rank])
|
||||
decode_ip = "0.0.0.0"
|
||||
decode_device_ids = (
|
||||
item["decode_device_ids"] if "decode_device_ids" in item else item["device_ids"]
|
||||
)
|
||||
decode_idx = int(decode_device_ids[self.rank])
|
||||
src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
|
||||
dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
|
||||
if item["current_id"] < prefilled_step_idx:
|
||||
@@ -335,8 +352,8 @@ class CacheMessager:
|
||||
for layer_idx in range(item["layer_idx"], current_layer_idx):
|
||||
tic = time.time()
|
||||
return_code = self.messager[current_transfer_protocol].write_cache(
|
||||
target_ip,
|
||||
target_id,
|
||||
decode_ip,
|
||||
decode_idx,
|
||||
src_block_ids,
|
||||
dest_block_ids,
|
||||
layer_idx,
|
||||
@@ -345,7 +362,7 @@ class CacheMessager:
|
||||
item["status"] = "write cache error"
|
||||
logger.error(
|
||||
f"write cache failed, layer_idx: {layer_idx}, "
|
||||
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
|
||||
f"req_id: {item['request_id']}, dest_ip: {decode_ip}"
|
||||
)
|
||||
break
|
||||
|
||||
@@ -365,7 +382,7 @@ class CacheMessager:
|
||||
if "error" not in item["status"]:
|
||||
item["status"] = "finished"
|
||||
if item["transfer_protocol"] == "ipc":
|
||||
self.messager["ipc"].write_block_by_sync(target_id)
|
||||
self.messager["ipc"].write_block_by_sync(decode_idx)
|
||||
logger.info(f"finish write cache {item['request_id']}")
|
||||
self.engine_worker_queue.finish_send_cache_barrier.wait()
|
||||
self.engine_worker_queue.put_finished_req([[item["request_id"], item["status"]]])
|
||||
@@ -387,8 +404,9 @@ class CacheMessager:
|
||||
self.engine_worker_queue.connect_task_barrier.wait()
|
||||
logger.info(f"_handle_connect_task recv task: {task}")
|
||||
task_id = task["task_id"]
|
||||
ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
|
||||
status = self.messager["rdma"].connect(ip, rdma_port)
|
||||
decode_ip, decode_rdma_ports = get_decode_ip_idx(task)
|
||||
rdma_port = decode_rdma_ports[self.rank]
|
||||
status = self.messager["rdma"].connect(decode_ip, rdma_port)
|
||||
if not status:
|
||||
response = {"task_id": task_id, "success": False}
|
||||
else:
|
||||
@@ -634,6 +652,7 @@ class CacheMessagerV1:
|
||||
end_layer_idx = prefilled_layer_idx
|
||||
if sended_layer_idx == prefilled_layer_idx: # computation not in next layer
|
||||
time.sleep(0.01)
|
||||
|
||||
for layer_idx in range(start_layer_idx, end_layer_idx + 1):
|
||||
for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
|
||||
engine_index = batch_engine_signals[i][0]
|
||||
@@ -650,13 +669,13 @@ class CacheMessagerV1:
|
||||
else:
|
||||
current_transfer_protocol = task["transfer_protocol"]
|
||||
if task["transfer_protocol"] == "rdma":
|
||||
target_ip = task["ip"]
|
||||
decode_ip, decode_rdma_ports = get_decode_ip_idx(task)
|
||||
# Default decode_tp_size to prefill tp_size (self.nranks) if not specified
|
||||
decode_tp_size = task.get("decode_tp_size", self.nranks)
|
||||
if len(task["rdma_ports"]) == self.nranks:
|
||||
target_id = int(task["rdma_ports"][self.rank])
|
||||
elif len(task["rdma_ports"]) == 1:
|
||||
target_id = task["rdma_ports"][0]
|
||||
if len(decode_rdma_ports) == self.nranks:
|
||||
decode_idx = int(decode_rdma_ports[self.rank])
|
||||
elif len(decode_rdma_ports) == 1:
|
||||
decode_idx = decode_rdma_ports[0]
|
||||
else:
|
||||
task["status"] = "the tp_size of prefill and decode is mismatch"
|
||||
continue
|
||||
@@ -666,21 +685,26 @@ class CacheMessagerV1:
|
||||
|
||||
# TODO: use is connected to check if the connection is still alive
|
||||
logger.debug(
|
||||
f"rdma, start connect decode, {target_ip}:{target_id}, "
|
||||
f"rdma, start connect decode, {decode_ip}:{decode_idx}, "
|
||||
f"prefill_tp_size:{self.nranks}, decode_tp_size:{decode_tp_size}"
|
||||
)
|
||||
status = self.messager[current_transfer_protocol].connect(
|
||||
target_ip, target_id, decode_tp_size
|
||||
decode_ip, decode_idx, decode_tp_size
|
||||
)
|
||||
if status:
|
||||
logger.debug(f"connect to {target_ip}:{target_id} success")
|
||||
logger.debug(f"connect to {decode_ip}:{decode_idx} success")
|
||||
else:
|
||||
logger.error(f"connect to {target_ip}:{target_id} failed")
|
||||
logger.error(f"connect to {decode_ip}:{decode_idx} failed")
|
||||
task["status"] = "connection error"
|
||||
continue
|
||||
elif task["transfer_protocol"] == "ipc":
|
||||
target_ip = "0.0.0.0"
|
||||
target_id = int(task["device_ids"][self.rank])
|
||||
decode_device_ids = (
|
||||
task["decode_device_ids"]
|
||||
if "decode_device_ids" in task
|
||||
else task["device_ids"]
|
||||
)
|
||||
decode_ip = "0.0.0.0"
|
||||
decode_idx = int(decode_device_ids[self.rank])
|
||||
|
||||
src_block_ids = task["src_block_ids"][block_id_start:block_id_end]
|
||||
dest_block_ids = task["dest_block_ids"][block_id_start:block_id_end]
|
||||
@@ -688,12 +712,12 @@ class CacheMessagerV1:
|
||||
dest_block_ids = paddle.to_tensor(dest_block_ids, dtype="int32", place="cpu")
|
||||
|
||||
logger.info(
|
||||
f"start write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id}, block_id_start {block_id_start} block_id_end {block_id_end}"
|
||||
f"start write cache for a layer, {req_id}, {layer_idx}, {decode_ip}, {decode_idx}, block_id_start {block_id_start} block_id_end {block_id_end}"
|
||||
)
|
||||
tic = time.time()
|
||||
return_code = self.messager[current_transfer_protocol].write_cache(
|
||||
target_ip,
|
||||
target_id,
|
||||
decode_ip,
|
||||
decode_idx,
|
||||
src_block_ids,
|
||||
dest_block_ids,
|
||||
layer_idx,
|
||||
@@ -701,7 +725,7 @@ class CacheMessagerV1:
|
||||
if return_code != 0:
|
||||
task["status"] = "write cache error"
|
||||
logger.error(
|
||||
f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, dest_ip: {target_ip}, block_id_start {block_id_start} block_id_end {block_id_end}"
|
||||
f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, dest_ip: {decode_ip}, block_id_start {block_id_start} block_id_end {block_id_end}"
|
||||
)
|
||||
tok = time.time()
|
||||
cost_time = tok - tic
|
||||
@@ -709,7 +733,7 @@ class CacheMessagerV1:
|
||||
avg_time_per_block = cost_time * 1000 / block_num # ms
|
||||
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
|
||||
logger.debug(
|
||||
f"finish write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id},"
|
||||
f"finish write cache for a layer, {req_id}, {layer_idx}, {decode_ip}, {decode_idx},"
|
||||
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
|
||||
f"avg_time per block(ms): {round(avg_time_per_block, 5)} block_id_start {block_id_start} block_id_end {block_id_end}"
|
||||
)
|
||||
@@ -734,8 +758,13 @@ class CacheMessagerV1:
|
||||
task = self.idx_cache_task_dict[engine_idx]
|
||||
if task["status"] == "finished" or ("error" in task["status"]):
|
||||
if task["transfer_protocol"] == "ipc":
|
||||
target_id = int(task["device_ids"][self.rank])
|
||||
self.messager["ipc"].write_block_by_sync(target_id)
|
||||
decode_device_ids = (
|
||||
task["decode_device_ids"]
|
||||
if "decode_device_ids" in task
|
||||
else task["device_ids"]
|
||||
)
|
||||
decode_idx = int(decode_device_ids[self.rank])
|
||||
self.messager["ipc"].write_block_by_sync(decode_idx)
|
||||
self.engine_worker_queue.finish_send_cache_barrier.wait()
|
||||
self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]])
|
||||
logger.info(
|
||||
@@ -796,18 +825,17 @@ class CacheMessagerV1:
|
||||
self.engine_worker_queue.connect_task_barrier.wait()
|
||||
logger.info(f"_handle_connect_task recv task: {task}")
|
||||
task_id = task["task_id"]
|
||||
ip = task["ip"]
|
||||
decode_ip, decode_rdma_ports = get_decode_ip_idx(task)
|
||||
# Default decode_tp_size to self.nranks (number of ranks) if not specified in the task.
|
||||
decode_tp_size = task.get("decode_tp_size", self.nranks)
|
||||
rdma_ports = task["rdma_ports"]
|
||||
rdma_ports_len = len(rdma_ports)
|
||||
rdma_ports_len = len(decode_rdma_ports)
|
||||
if not (rdma_ports_len == 1 or rdma_ports_len == self.nranks):
|
||||
# TODO: support other cases
|
||||
logger.error(f"rdma_ports length should be 1 or equal to mp_num, but got {rdma_ports_len}")
|
||||
response = {"task_id": task_id, "success": False}
|
||||
else:
|
||||
port = rdma_ports[0] if rdma_ports_len == 1 else rdma_ports[self.rank]
|
||||
status = self.messager["rdma"].connect(ip, port, decode_tp_size)
|
||||
port = decode_rdma_ports[0] if rdma_ports_len == 1 else decode_rdma_ports[self.rank]
|
||||
status = self.messager["rdma"].connect(decode_ip, port, decode_tp_size)
|
||||
if not status:
|
||||
response = {"task_id": task_id, "success": False}
|
||||
else:
|
||||
|
||||
@@ -1919,42 +1919,23 @@ class FDConfig:
|
||||
else None
|
||||
)
|
||||
|
||||
self.disaggregate_info = {}
|
||||
if self.scheduler_config.splitwise_role != "mixed":
|
||||
self.disaggregate_info["role"] = self.scheduler_config.splitwise_role
|
||||
self.disaggregate_info["cache_info"] = dict()
|
||||
current_protocol = self.cache_config.cache_transfer_protocol.split(",")
|
||||
self.disaggregate_info["transfer_protocol"] = current_protocol
|
||||
|
||||
for protocol in current_protocol:
|
||||
if protocol == "ipc":
|
||||
self.disaggregate_info["cache_info"][protocol] = {
|
||||
"ip": self.host_ip,
|
||||
"port": engine_worker_queue_port,
|
||||
"device_ids": self.local_device_ids,
|
||||
}
|
||||
elif protocol == "rdma":
|
||||
self.disaggregate_info["cache_info"][protocol] = {
|
||||
"ip": self.host_ip,
|
||||
"port": connector_port,
|
||||
"rdma_port": self.cache_config.rdma_comm_ports,
|
||||
}
|
||||
logger.info(f"disaggregate_info: {self.disaggregate_info}")
|
||||
|
||||
if self.router_config:
|
||||
# the information for registering this server to router
|
||||
self.register_info = {
|
||||
"role": self.scheduler_config.splitwise_role,
|
||||
"host_ip": self.host_ip,
|
||||
"port": self.router_config.api_server_port,
|
||||
"connector_port": connector_port,
|
||||
"rdma_ports": self.cache_config.rdma_comm_ports,
|
||||
"engine_worker_queue_port": engine_worker_queue_port,
|
||||
"device_ids": self.local_device_ids,
|
||||
"transfer_protocol": self.cache_config.cache_transfer_protocol.split(","),
|
||||
"tp_size": self.parallel_config.tensor_parallel_size,
|
||||
}
|
||||
logger.info(f"register_info: {self.register_info}")
|
||||
# the information for registering this server to router or splitwise_scheduler
|
||||
port = self.router_config.api_server_port if self.router_config else None
|
||||
transfer_protocol = (
|
||||
self.cache_config.cache_transfer_protocol.split(",") if self.cache_config.cache_transfer_protocol else []
|
||||
)
|
||||
self.register_info = {
|
||||
"role": self.scheduler_config.splitwise_role,
|
||||
"host_ip": self.host_ip,
|
||||
"port": port,
|
||||
"connector_port": connector_port,
|
||||
"rdma_ports": self.cache_config.rdma_comm_ports,
|
||||
"engine_worker_queue_port": engine_worker_queue_port,
|
||||
"device_ids": self.local_device_ids,
|
||||
"transfer_protocol": transfer_protocol,
|
||||
"tp_size": self.parallel_config.tensor_parallel_size,
|
||||
}
|
||||
logger.info(f"register_info: {self.register_info}")
|
||||
|
||||
def read_from_config(self):
|
||||
"""
|
||||
|
||||
@@ -424,7 +424,7 @@ class EngineService:
|
||||
|
||||
need_delete_tasks = []
|
||||
for task in tasks:
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
if self.cfg.scheduler_config.splitwise_role == "prefill":
|
||||
status, msg = self.split_connector.check_decode_allocated(task)
|
||||
if status:
|
||||
task.metrics.ask_decode_resource_finish_time = time.time()
|
||||
@@ -469,7 +469,7 @@ class EngineService:
|
||||
is_prefill = False
|
||||
for i in range(len(tasks)):
|
||||
if tasks[i].disaggregate_info is not None:
|
||||
if tasks[i].disaggregate_info["role"] == "decode":
|
||||
if self.cfg.scheduler_config.splitwise_role == "decode":
|
||||
is_decode = True
|
||||
else:
|
||||
is_prefill = True
|
||||
@@ -811,11 +811,10 @@ class EngineService:
|
||||
f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}"
|
||||
)
|
||||
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
if self.cfg.scheduler_config.splitwise_role == "prefill":
|
||||
for task in tasks:
|
||||
# start async preprocess
|
||||
self.resource_manager.apply_async_preprocess(task)
|
||||
if self.cfg.scheduler_config.splitwise_role == "prefill":
|
||||
for task in tasks:
|
||||
# start async preprocess
|
||||
self.resource_manager.apply_async_preprocess(task)
|
||||
need_delete_tasks = []
|
||||
if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
|
||||
for task in tasks:
|
||||
@@ -873,7 +872,6 @@ class EngineService:
|
||||
# release resource in P
|
||||
self.resource_manager.pre_recycle_resource(tmp_task.request_id)
|
||||
|
||||
if self.cfg.scheduler_config.splitwise_role == "prefill":
|
||||
# to send cache info to cache messager
|
||||
if tasks:
|
||||
need_check_req_ids = [task.request_id for task in tasks]
|
||||
@@ -912,6 +910,7 @@ class EngineService:
|
||||
tasks.remove(tmp_task)
|
||||
# release resource in P
|
||||
self.resource_manager.pre_recycle_resource(tmp_task.request_id)
|
||||
|
||||
# Fetch requests and add them to the scheduling queue
|
||||
if tasks:
|
||||
for task in tasks:
|
||||
@@ -1765,11 +1764,10 @@ class EngineService:
|
||||
|
||||
role = self.cfg.scheduler_config.splitwise_role
|
||||
host_ip = self.cfg.host_ip
|
||||
disaggregate = self.cfg.disaggregate_info
|
||||
request_queues_for_dp_ipc = None
|
||||
result_queue_for_dp_ipc = None
|
||||
if self.cfg.scheduler_config.name == "splitwise":
|
||||
self.scheduler.start(role, host_ip, disaggregate)
|
||||
self.scheduler.start(role, host_ip, self.cfg.register_info)
|
||||
elif self.cfg.scheduler_config.name == "dp":
|
||||
request_queues_for_dp_ipc = []
|
||||
result_queue_for_dp_ipc = multiprocessing.Queue()
|
||||
|
||||
@@ -715,11 +715,10 @@ class LLMEngine:
|
||||
|
||||
role = self.cfg.scheduler_config.splitwise_role
|
||||
host_ip = self.cfg.host_ip
|
||||
disaggregate = self.cfg.disaggregate_info
|
||||
request_queues_for_dp_ipc = None
|
||||
result_queues_for_dp_ipc = None
|
||||
if self.cfg.scheduler_config.name == "splitwise":
|
||||
self.engine.scheduler.start(role, host_ip, disaggregate)
|
||||
self.engine.scheduler.start(role, host_ip, self.cfg.register_info)
|
||||
elif self.cfg.scheduler_config.name == "dp":
|
||||
request_queues_for_dp_ipc = []
|
||||
result_queues_for_dp_ipc = []
|
||||
|
||||
@@ -113,8 +113,7 @@ class ExpertService:
|
||||
self.cfg.init_cache_info()
|
||||
role = self.cfg.scheduler_config.splitwise_role
|
||||
host_ip = self.cfg.host_ip
|
||||
disaggregate = self.cfg.disaggregate_info
|
||||
self.engine.scheduler.start(role, host_ip, disaggregate)
|
||||
self.engine.scheduler.start(role, host_ip, self.cfg.register_info)
|
||||
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
self.splitwise_receive_thread = threading.Thread(
|
||||
|
||||
@@ -188,26 +188,15 @@ class Router:
|
||||
is_same_tp_size = prefill_server.tp_size == decode_server.tp_size
|
||||
use_ipc = is_same_node and is_support_ipc and is_same_tp_size
|
||||
|
||||
cache_info = {}
|
||||
if use_ipc:
|
||||
cache_info["ipc"] = {
|
||||
"ip": decode_server.host_ip,
|
||||
"port": decode_server.engine_worker_queue_port,
|
||||
"device_ids": decode_server.device_ids,
|
||||
}
|
||||
else:
|
||||
cache_info["rdma"] = {
|
||||
"ip": decode_server.host_ip,
|
||||
"port": decode_server.connector_port,
|
||||
"rdma_port": decode_server.rdma_ports,
|
||||
}
|
||||
|
||||
disaggregate_info = {
|
||||
"prefill": prefill_server.to_dict(),
|
||||
"decode": decode_server.to_dict(),
|
||||
"role": "decode",
|
||||
"cache_info": cache_info,
|
||||
"prefill_ip": prefill_server.host_ip,
|
||||
"decode_ip": decode_server.host_ip,
|
||||
"prefill_connector_port": prefill_server.connector_port,
|
||||
"decode_connector_port": decode_server.connector_port,
|
||||
"decode_device_ids": decode_server.device_ids,
|
||||
"decode_rdma_ports": decode_server.rdma_ports,
|
||||
"transfer_protocol": "ipc" if use_ipc else "rdma",
|
||||
"decode_tp_size": decode_server.tp_size,
|
||||
}
|
||||
|
||||
modified_request = request_data.copy()
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import math
|
||||
import pickle
|
||||
@@ -533,16 +532,26 @@ class APIScheduler:
|
||||
else:
|
||||
dnodes.sort()
|
||||
dnode = self.select_pd(req, dnodes, "decode")
|
||||
disaggregated = copy.deepcopy(dnode.disaggregated)
|
||||
transfer_protocol = disaggregated["transfer_protocol"]
|
||||
if len(transfer_protocol) > 1 and "ipc" in transfer_protocol and "rdma" in transfer_protocol:
|
||||
if pnode.host == dnode.host:
|
||||
disaggregated["transfer_protocol"] = "ipc"
|
||||
else:
|
||||
disaggregated["transfer_protocol"] = "rdma"
|
||||
else:
|
||||
disaggregated["transfer_protocol"] = transfer_protocol[0]
|
||||
req.disaggregate_info = disaggregated
|
||||
|
||||
is_same_node = pnode.disaggregated["host_ip"] == dnode.disaggregated["host_ip"]
|
||||
is_support_ipc = (
|
||||
"ipc" in pnode.disaggregated["transfer_protocol"] and "ipc" in dnode.disaggregated["transfer_protocol"]
|
||||
)
|
||||
is_same_tp_size = pnode.disaggregated["tp_size"] == dnode.disaggregated["tp_size"]
|
||||
use_ipc = is_same_node and is_support_ipc and is_same_tp_size
|
||||
|
||||
disaggregate_info = {
|
||||
"prefill_ip": pnode.disaggregated["host_ip"],
|
||||
"decode_ip": dnode.disaggregated["host_ip"],
|
||||
"prefill_connector_port": pnode.disaggregated["connector_port"],
|
||||
"decode_connector_port": dnode.disaggregated["connector_port"],
|
||||
"decode_device_ids": dnode.disaggregated["device_ids"],
|
||||
"decode_rdma_ports": dnode.disaggregated["rdma_ports"],
|
||||
"transfer_protocol": "ipc" if use_ipc else "rdma",
|
||||
"decode_tp_size": dnode.disaggregated["tp_size"],
|
||||
}
|
||||
|
||||
req.disaggregate_info = disaggregate_info
|
||||
pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
|
||||
req_dict = req.to_dict()
|
||||
req_dict["group"] = group
|
||||
|
||||
@@ -24,7 +24,6 @@ import zmq
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
from fastdeploy.inter_communicator import EngineWorkerQueue
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
@@ -53,7 +52,6 @@ class SplitwiseConnector:
|
||||
self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
|
||||
self.engine_worker_queue = worker_queue
|
||||
self.resource_manager = resource_manager
|
||||
self.connect_innode_instances = {}
|
||||
self.current_request_ids = dict()
|
||||
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
|
||||
|
||||
@@ -172,119 +170,59 @@ class SplitwiseConnector:
|
||||
|
||||
def send_splitwise_tasks(self, tasks: List[Request], current_id):
|
||||
"""
|
||||
Send splitwise tasks to all connected addresses.
|
||||
Prefill send splitwise tasks to decode.
|
||||
|
||||
Parameters:
|
||||
tasks (list): List of tasks.
|
||||
current_id (int): Current ID.
|
||||
"""
|
||||
addr = None
|
||||
decode_diagg = None
|
||||
for task in tasks:
|
||||
if task.disaggregate_info is None:
|
||||
continue
|
||||
|
||||
if task.disaggregate_info["transfer_protocol"] == "ipc":
|
||||
addr = task.disaggregate_info["cache_info"]["ipc"]["port"]
|
||||
task.disaggregate_info["cache_info"]["ipc"]["current_id"] = current_id
|
||||
self.logger.info(f"send_splitwise_tasks: protocol=ipc, addr={addr}, task={task.request_id}")
|
||||
self.send_splitwise_tasks_innode([task], addr)
|
||||
else:
|
||||
self.current_request_ids[task.request_id] = "init"
|
||||
task.disaggregate_info["role"] = "decode"
|
||||
addr = f"{task.disaggregate_info['decode_ip']}:{task.disaggregate_info['decode_connector_port']}"
|
||||
self.logger.info(f"send_splitwise_tasks: protocol=rdma, addr={addr}, task={task.request_id}")
|
||||
self._send_message(addr, "prefill", [task])
|
||||
|
||||
addr = (
|
||||
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
|
||||
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}"
|
||||
)
|
||||
self.current_request_ids[task.request_id] = "init"
|
||||
decode_diagg = task.disaggregate_info["cache_info"]
|
||||
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
|
||||
task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
|
||||
task.disaggregate_info["role"] = "decode"
|
||||
self.logger.info(f"send_splitwise_tasks: protocol=rdma, addr={addr}, task={task.request_id}")
|
||||
self._send_message(addr, "prefill", [task])
|
||||
task.disaggregate_info["cache_info"] = decode_diagg
|
||||
task.disaggregate_info["role"] = "prefill"
|
||||
|
||||
def send_splitwise_tasks_innode(self, tasks, port):
|
||||
"""
|
||||
Send splitwise tasks to specific port.
|
||||
|
||||
Parameters:
|
||||
tasks (list): List of tasks.
|
||||
port (int): Port number.
|
||||
|
||||
Returns:
|
||||
int: Current port number, -1 if tasks are not sent.
|
||||
"""
|
||||
current_port = -1
|
||||
if port not in self.connect_innode_instances:
|
||||
self.create_connection(port)
|
||||
for task in tasks:
|
||||
task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.parallel_config.engine_worker_queue_port[
|
||||
self.local_data_parallel_id
|
||||
]
|
||||
self.logger.info(f"send_splitwise_tasks_innode: port={port}, tasks={[task.request_id for task in tasks]}")
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
|
||||
for task in tasks:
|
||||
task.disaggregate_info["cache_info"]["ipc"]["port"] = port
|
||||
current_port = port
|
||||
return current_port
|
||||
|
||||
def send_first_token(self, prefill_msg, tasks_list):
|
||||
"""
|
||||
send first token to specific port
|
||||
Prefill send first token to specific port
|
||||
"""
|
||||
if not isinstance(tasks_list, list):
|
||||
tasks_list = [tasks_list]
|
||||
self.logger.info(f"send_first_token: send first token to decode, {[x.request_id for x in tasks_list]}")
|
||||
if prefill_msg["transfer_protocol"] == "ipc":
|
||||
port = prefill_msg["cache_info"]["ipc"]["port"]
|
||||
if port not in self.connect_innode_instances:
|
||||
self.create_connection(port)
|
||||
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
|
||||
else:
|
||||
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
|
||||
self.logger.info(f"send_first_token: send first token to port {node} decode")
|
||||
self._send_message(node, "decode", tasks_list)
|
||||
|
||||
def create_connection(self, port):
|
||||
"""
|
||||
Create a connection to specific port.
|
||||
|
||||
Parameters:
|
||||
port (int): Port number.
|
||||
"""
|
||||
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
|
||||
address = ("0.0.0.0", int(port))
|
||||
else:
|
||||
address = f"/dev/shm/fd_task_queue_{port}.sock"
|
||||
|
||||
self.connect_innode_instances[port] = EngineWorkerQueue(
|
||||
address=address,
|
||||
num_client=self.cfg.parallel_config.tensor_parallel_size,
|
||||
client_id=0,
|
||||
addr = f"{prefill_msg['decode_ip']}:{prefill_msg['decode_connector_port']}"
|
||||
self.logger.info(
|
||||
f"send_first_token: send first token to decode ({addr}), {[x.request_id for x in tasks_list]}"
|
||||
)
|
||||
self._send_message(addr, "decode", tasks_list)
|
||||
|
||||
def check_decode_allocated(self, task):
|
||||
self.logger.debug(f"start check decode allocated: {task.request_id}")
|
||||
"""Check whether the requests have been allocated resources in decode."""
|
||||
self.logger.debug(f"check_decode_allocated: {task.request_id}")
|
||||
start_time = time.time()
|
||||
if task.disaggregate_info is None:
|
||||
return True, ""
|
||||
if self.enable_decode_cache_task:
|
||||
return True, ""
|
||||
if task.disaggregate_info["role"] != "prefill":
|
||||
return True, ""
|
||||
|
||||
while self.current_request_ids[task.request_id] == "init":
|
||||
time.sleep(0.001)
|
||||
if time.time() - start_time > envs.FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS:
|
||||
del self.current_request_ids[task.request_id]
|
||||
return False, "timeout"
|
||||
|
||||
msg = self.current_request_ids[task.request_id]
|
||||
del self.current_request_ids[task.request_id]
|
||||
if msg == "finished":
|
||||
return True, ""
|
||||
self.logger.error(f"check_decode_allocated: Receive_decode_allocated error: {msg}")
|
||||
return False, msg
|
||||
else:
|
||||
self.logger.error(f"check_decode_allocated: Receive_decode_allocated error: {msg}")
|
||||
return False, msg
|
||||
|
||||
def send_cache_info_to_messager(self, tasks: List[Request], current_id):
|
||||
"""
|
||||
@@ -308,13 +246,12 @@ class SplitwiseConnector:
|
||||
"need_prefill_tokens": tasks[i].need_prefill_tokens,
|
||||
}
|
||||
else:
|
||||
if current_id == -1:
|
||||
current_id = dsg_info["cache_info"]["ipc"]["current_id"]
|
||||
info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"src_block_ids": tasks[i].block_tables,
|
||||
"current_id": current_id,
|
||||
}
|
||||
info.update(dsg_info)
|
||||
cache_info.append(info)
|
||||
|
||||
self.logger.debug(f"send_cache_info_to_messager, {cache_info}")
|
||||
@@ -333,56 +270,29 @@ class SplitwiseConnector:
|
||||
if dsg_info is None:
|
||||
self.logger.debug(f"skip send_cache_infos_to_prefill, {tasks[i].request_id}")
|
||||
continue
|
||||
self.logger.debug(f"send_cache_infos_to_prefill, {dsg_info}")
|
||||
|
||||
if dsg_info["transfer_protocol"] == "ipc":
|
||||
if tasks[i].get("error_msg", None) is not None:
|
||||
info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"error_msg": tasks[i].get("error_msg"),
|
||||
}
|
||||
else:
|
||||
addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}"
|
||||
info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"device_ids": self.cfg.parallel_config.device_ids.split(","),
|
||||
"transfer_protocol": "ipc",
|
||||
"dest_block_ids": dsg_info["block_tables"],
|
||||
}
|
||||
if dsg_info["cache_info"]["ipc"]["port"] not in cache_info:
|
||||
cache_info[dsg_info["cache_info"]["ipc"]["port"]] = []
|
||||
cache_info[dsg_info["cache_info"]["ipc"]["port"]].append(info)
|
||||
else:
|
||||
if tasks[i].get("error_msg", None) is not None:
|
||||
info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"error_msg": tasks[i].get("error_msg"),
|
||||
}
|
||||
else:
|
||||
info = {
|
||||
"request_id": tasks[i].request_id,
|
||||
"device_ids": [self.cfg.parallel_config.device_ids.split(",")[self.local_data_parallel_id]],
|
||||
"ip": self.cfg.host_ip,
|
||||
"rdma_ports": [
|
||||
self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"][self.local_data_parallel_id]
|
||||
],
|
||||
"transfer_protocol": "rdma",
|
||||
"dest_block_ids": dsg_info["block_tables"],
|
||||
"decode_tp_size": self.cfg.parallel_config.tensor_parallel_size,
|
||||
}
|
||||
|
||||
addr = f"{dsg_info['cache_info']['rdma']['ip']}:" + f"{dsg_info['cache_info']['rdma']['port']}"
|
||||
if addr not in cache_info:
|
||||
cache_info[addr] = []
|
||||
cache_info[addr].append(info)
|
||||
|
||||
self.logger.debug(f"send cache info to prefill, {cache_info}")
|
||||
if len(cache_info):
|
||||
for k, v in cache_info.items():
|
||||
self.logger.info(f"{k} {v}")
|
||||
if ":" in str(k):
|
||||
self._send_message(k, "cache_sync", v)
|
||||
else:
|
||||
if k not in self.connect_innode_instances:
|
||||
self.create_connection(k)
|
||||
self.connect_innode_instances[k].put_cache_info(v)
|
||||
for key, info in cache_info.items():
|
||||
self._send_message(key, "cache_sync", info)
|
||||
|
||||
def _serialize_message(self, msg_type: str, payload) -> bytes:
|
||||
# TODO 压缩
|
||||
|
||||
if msg_type == "decode" or msg_type == "prefill":
|
||||
payload = [output.to_dict() for output in payload]
|
||||
|
||||
|
||||
@@ -130,7 +130,7 @@ class TestConfig(unittest.TestCase):
|
||||
test_mode=True,
|
||||
)
|
||||
fd_config.init_cache_info()
|
||||
assert fd_config.disaggregate_info["role"] == "prefill"
|
||||
assert fd_config.register_info is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user