[PD Disaggregation] Unify the disaggregation info and the pd communication (#5438)

* Unify the disaggregation info and the pd communication

* up

* up

* fix

* fix conflict

* fix unittest
This commit is contained in:
Juncai
2025-12-09 14:44:59 +08:00
committed by GitHub
parent 8178e3fc6a
commit 83ea9646f9
10 changed files with 146 additions and 233 deletions

View File

@@ -96,6 +96,19 @@ def parse_args():
return args
def get_decode_ip_idx(task):
    """For compatibility, get decode ip and idx from task"""
    # Prefer the new unified keys ("decode_ip"/"decode_rdma_ports"); fall back
    # to the legacy keys ("ip"/"rdma_ports") so older senders keep working.
    # NOTE: the fallback lookup must stay lazy — a task carrying only the new
    # keys may lack "ip" entirely, so dict.get(new_key, task["ip"]) would raise.
    decode_ip = task["decode_ip"] if "decode_ip" in task else task["ip"]
    decode_rdma_ports = (
        task["decode_rdma_ports"] if "decode_rdma_ports" in task else task["rdma_ports"]
    )
    return decode_ip, decode_rdma_ports
class CacheMessager:
"""
CacheMessager is used to send the cache data between the engine worker and the cache server.
@@ -282,6 +295,7 @@ class CacheMessager:
self.cache_info[info["request_id"]] = current_info
else:
self.cache_info[info["request_id"]] = info
prefilled_layer_idx = layer_shm_value.value[0]
prefilled_step_idx = step_shm_value.value[0]
if prefilled_layer_idx == self.num_layers - 1:
@@ -316,15 +330,18 @@ class CacheMessager:
continue
current_transfer_protocol = item["transfer_protocol"]
if item["transfer_protocol"] == "rdma":
target_ip = item["ip"]
target_id = int(item["rdma_ports"][self.rank])
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
decode_ip, decode_rdma_ports = get_decode_ip_idx(item)
decode_idx = int(decode_rdma_ports[self.rank])
status = self.messager[current_transfer_protocol].connect(decode_ip, decode_idx)
if not status:
logger.error(f"connect to {target_ip}:{target_id} failed")
logger.error(f"connect to {decode_ip}:{decode_idx} failed")
item["status"] = "connect error"
elif item["transfer_protocol"] == "ipc":
target_ip = "0.0.0.0"
target_id = int(item["device_ids"][self.rank])
decode_ip = "0.0.0.0"
decode_device_ids = (
item["decode_device_ids"] if "decode_device_ids" in item else item["device_ids"]
)
decode_idx = int(decode_device_ids[self.rank])
src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
if item["current_id"] < prefilled_step_idx:
@@ -335,8 +352,8 @@ class CacheMessager:
for layer_idx in range(item["layer_idx"], current_layer_idx):
tic = time.time()
return_code = self.messager[current_transfer_protocol].write_cache(
target_ip,
target_id,
decode_ip,
decode_idx,
src_block_ids,
dest_block_ids,
layer_idx,
@@ -345,7 +362,7 @@ class CacheMessager:
item["status"] = "write cache error"
logger.error(
f"write cache failed, layer_idx: {layer_idx}, "
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
f"req_id: {item['request_id']}, dest_ip: {decode_ip}"
)
break
@@ -365,7 +382,7 @@ class CacheMessager:
if "error" not in item["status"]:
item["status"] = "finished"
if item["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
self.messager["ipc"].write_block_by_sync(decode_idx)
logger.info(f"finish write cache {item['request_id']}")
self.engine_worker_queue.finish_send_cache_barrier.wait()
self.engine_worker_queue.put_finished_req([[item["request_id"], item["status"]]])
@@ -387,8 +404,9 @@ class CacheMessager:
self.engine_worker_queue.connect_task_barrier.wait()
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
status = self.messager["rdma"].connect(ip, rdma_port)
decode_ip, decode_rdma_ports = get_decode_ip_idx(task)
rdma_port = decode_rdma_ports[self.rank]
status = self.messager["rdma"].connect(decode_ip, rdma_port)
if not status:
response = {"task_id": task_id, "success": False}
else:
@@ -634,6 +652,7 @@ class CacheMessagerV1:
end_layer_idx = prefilled_layer_idx
if sended_layer_idx == prefilled_layer_idx: # computation not in next layer
time.sleep(0.01)
for layer_idx in range(start_layer_idx, end_layer_idx + 1):
for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
engine_index = batch_engine_signals[i][0]
@@ -650,13 +669,13 @@ class CacheMessagerV1:
else:
current_transfer_protocol = task["transfer_protocol"]
if task["transfer_protocol"] == "rdma":
target_ip = task["ip"]
decode_ip, decode_rdma_ports = get_decode_ip_idx(task)
# Default decode_tp_size to prefill tp_size (self.nranks) if not specified
decode_tp_size = task.get("decode_tp_size", self.nranks)
if len(task["rdma_ports"]) == self.nranks:
target_id = int(task["rdma_ports"][self.rank])
elif len(task["rdma_ports"]) == 1:
target_id = task["rdma_ports"][0]
if len(decode_rdma_ports) == self.nranks:
decode_idx = int(decode_rdma_ports[self.rank])
elif len(decode_rdma_ports) == 1:
decode_idx = decode_rdma_ports[0]
else:
task["status"] = "the tp_size of prefill and decode is mismatch"
continue
@@ -666,21 +685,26 @@ class CacheMessagerV1:
# TODO: use is connected to check if the connection is still alive
logger.debug(
f"rdma, start connect decode, {target_ip}:{target_id}, "
f"rdma, start connect decode, {decode_ip}:{decode_idx}, "
f"prefill_tp_size:{self.nranks}, decode_tp_size:{decode_tp_size}"
)
status = self.messager[current_transfer_protocol].connect(
target_ip, target_id, decode_tp_size
decode_ip, decode_idx, decode_tp_size
)
if status:
logger.debug(f"connect to {target_ip}:{target_id} success")
logger.debug(f"connect to {decode_ip}:{decode_idx} success")
else:
logger.error(f"connect to {target_ip}:{target_id} failed")
logger.error(f"connect to {decode_ip}:{decode_idx} failed")
task["status"] = "connection error"
continue
elif task["transfer_protocol"] == "ipc":
target_ip = "0.0.0.0"
target_id = int(task["device_ids"][self.rank])
decode_device_ids = (
task["decode_device_ids"]
if "decode_device_ids" in task
else task["device_ids"]
)
decode_ip = "0.0.0.0"
decode_idx = int(decode_device_ids[self.rank])
src_block_ids = task["src_block_ids"][block_id_start:block_id_end]
dest_block_ids = task["dest_block_ids"][block_id_start:block_id_end]
@@ -688,12 +712,12 @@ class CacheMessagerV1:
dest_block_ids = paddle.to_tensor(dest_block_ids, dtype="int32", place="cpu")
logger.info(
f"start write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id}, block_id_start {block_id_start} block_id_end {block_id_end}"
f"start write cache for a layer, {req_id}, {layer_idx}, {decode_ip}, {decode_idx}, block_id_start {block_id_start} block_id_end {block_id_end}"
)
tic = time.time()
return_code = self.messager[current_transfer_protocol].write_cache(
target_ip,
target_id,
decode_ip,
decode_idx,
src_block_ids,
dest_block_ids,
layer_idx,
@@ -701,7 +725,7 @@ class CacheMessagerV1:
if return_code != 0:
task["status"] = "write cache error"
logger.error(
f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, dest_ip: {target_ip}, block_id_start {block_id_start} block_id_end {block_id_end}"
f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, dest_ip: {decode_ip}, block_id_start {block_id_start} block_id_end {block_id_end}"
)
tok = time.time()
cost_time = tok - tic
@@ -709,7 +733,7 @@ class CacheMessagerV1:
avg_time_per_block = cost_time * 1000 / block_num # ms
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
logger.debug(
f"finish write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id},"
f"finish write cache for a layer, {req_id}, {layer_idx}, {decode_ip}, {decode_idx},"
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
f"avg_time per block(ms): {round(avg_time_per_block, 5)} block_id_start {block_id_start} block_id_end {block_id_end}"
)
@@ -734,8 +758,13 @@ class CacheMessagerV1:
task = self.idx_cache_task_dict[engine_idx]
if task["status"] == "finished" or ("error" in task["status"]):
if task["transfer_protocol"] == "ipc":
target_id = int(task["device_ids"][self.rank])
self.messager["ipc"].write_block_by_sync(target_id)
decode_device_ids = (
task["decode_device_ids"]
if "decode_device_ids" in task
else task["device_ids"]
)
decode_idx = int(decode_device_ids[self.rank])
self.messager["ipc"].write_block_by_sync(decode_idx)
self.engine_worker_queue.finish_send_cache_barrier.wait()
self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]])
logger.info(
@@ -796,18 +825,17 @@ class CacheMessagerV1:
self.engine_worker_queue.connect_task_barrier.wait()
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip = task["ip"]
decode_ip, decode_rdma_ports = get_decode_ip_idx(task)
# Default decode_tp_size to self.nranks (number of ranks) if not specified in the task.
decode_tp_size = task.get("decode_tp_size", self.nranks)
rdma_ports = task["rdma_ports"]
rdma_ports_len = len(rdma_ports)
rdma_ports_len = len(decode_rdma_ports)
if not (rdma_ports_len == 1 or rdma_ports_len == self.nranks):
# TODO: support other cases
logger.error(f"rdma_ports length should be 1 or equal to mp_num, but got {rdma_ports_len}")
response = {"task_id": task_id, "success": False}
else:
port = rdma_ports[0] if rdma_ports_len == 1 else rdma_ports[self.rank]
status = self.messager["rdma"].connect(ip, port, decode_tp_size)
port = decode_rdma_ports[0] if rdma_ports_len == 1 else decode_rdma_ports[self.rank]
status = self.messager["rdma"].connect(decode_ip, port, decode_tp_size)
if not status:
response = {"task_id": task_id, "success": False}
else:

View File

@@ -1919,42 +1919,23 @@ class FDConfig:
else None
)
self.disaggregate_info = {}
if self.scheduler_config.splitwise_role != "mixed":
self.disaggregate_info["role"] = self.scheduler_config.splitwise_role
self.disaggregate_info["cache_info"] = dict()
current_protocol = self.cache_config.cache_transfer_protocol.split(",")
self.disaggregate_info["transfer_protocol"] = current_protocol
for protocol in current_protocol:
if protocol == "ipc":
self.disaggregate_info["cache_info"][protocol] = {
"ip": self.host_ip,
"port": engine_worker_queue_port,
"device_ids": self.local_device_ids,
}
elif protocol == "rdma":
self.disaggregate_info["cache_info"][protocol] = {
"ip": self.host_ip,
"port": connector_port,
"rdma_port": self.cache_config.rdma_comm_ports,
}
logger.info(f"disaggregate_info: {self.disaggregate_info}")
if self.router_config:
# the information for registering this server to router
self.register_info = {
"role": self.scheduler_config.splitwise_role,
"host_ip": self.host_ip,
"port": self.router_config.api_server_port,
"connector_port": connector_port,
"rdma_ports": self.cache_config.rdma_comm_ports,
"engine_worker_queue_port": engine_worker_queue_port,
"device_ids": self.local_device_ids,
"transfer_protocol": self.cache_config.cache_transfer_protocol.split(","),
"tp_size": self.parallel_config.tensor_parallel_size,
}
logger.info(f"register_info: {self.register_info}")
# the information for registering this server to router or splitwise_scheduler
port = self.router_config.api_server_port if self.router_config else None
transfer_protocol = (
self.cache_config.cache_transfer_protocol.split(",") if self.cache_config.cache_transfer_protocol else []
)
self.register_info = {
"role": self.scheduler_config.splitwise_role,
"host_ip": self.host_ip,
"port": port,
"connector_port": connector_port,
"rdma_ports": self.cache_config.rdma_comm_ports,
"engine_worker_queue_port": engine_worker_queue_port,
"device_ids": self.local_device_ids,
"transfer_protocol": transfer_protocol,
"tp_size": self.parallel_config.tensor_parallel_size,
}
logger.info(f"register_info: {self.register_info}")
def read_from_config(self):
"""

View File

@@ -424,7 +424,7 @@ class EngineService:
need_delete_tasks = []
for task in tasks:
if self.cfg.scheduler_config.splitwise_role != "mixed":
if self.cfg.scheduler_config.splitwise_role == "prefill":
status, msg = self.split_connector.check_decode_allocated(task)
if status:
task.metrics.ask_decode_resource_finish_time = time.time()
@@ -469,7 +469,7 @@ class EngineService:
is_prefill = False
for i in range(len(tasks)):
if tasks[i].disaggregate_info is not None:
if tasks[i].disaggregate_info["role"] == "decode":
if self.cfg.scheduler_config.splitwise_role == "decode":
is_decode = True
else:
is_prefill = True
@@ -811,11 +811,10 @@ class EngineService:
f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}"
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
if self.cfg.scheduler_config.splitwise_role == "prefill":
for task in tasks:
# start async preprocess
self.resource_manager.apply_async_preprocess(task)
if self.cfg.scheduler_config.splitwise_role == "prefill":
for task in tasks:
# start async preprocess
self.resource_manager.apply_async_preprocess(task)
need_delete_tasks = []
if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
for task in tasks:
@@ -873,7 +872,6 @@ class EngineService:
# release resource in P
self.resource_manager.pre_recycle_resource(tmp_task.request_id)
if self.cfg.scheduler_config.splitwise_role == "prefill":
# to send cache info to cache messager
if tasks:
need_check_req_ids = [task.request_id for task in tasks]
@@ -912,6 +910,7 @@ class EngineService:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.pre_recycle_resource(tmp_task.request_id)
# Fetch requests and add them to the scheduling queue
if tasks:
for task in tasks:
@@ -1765,11 +1764,10 @@ class EngineService:
role = self.cfg.scheduler_config.splitwise_role
host_ip = self.cfg.host_ip
disaggregate = self.cfg.disaggregate_info
request_queues_for_dp_ipc = None
result_queue_for_dp_ipc = None
if self.cfg.scheduler_config.name == "splitwise":
self.scheduler.start(role, host_ip, disaggregate)
self.scheduler.start(role, host_ip, self.cfg.register_info)
elif self.cfg.scheduler_config.name == "dp":
request_queues_for_dp_ipc = []
result_queue_for_dp_ipc = multiprocessing.Queue()

View File

@@ -715,11 +715,10 @@ class LLMEngine:
role = self.cfg.scheduler_config.splitwise_role
host_ip = self.cfg.host_ip
disaggregate = self.cfg.disaggregate_info
request_queues_for_dp_ipc = None
result_queues_for_dp_ipc = None
if self.cfg.scheduler_config.name == "splitwise":
self.engine.scheduler.start(role, host_ip, disaggregate)
self.engine.scheduler.start(role, host_ip, self.cfg.register_info)
elif self.cfg.scheduler_config.name == "dp":
request_queues_for_dp_ipc = []
result_queues_for_dp_ipc = []

View File

@@ -113,8 +113,7 @@ class ExpertService:
self.cfg.init_cache_info()
role = self.cfg.scheduler_config.splitwise_role
host_ip = self.cfg.host_ip
disaggregate = self.cfg.disaggregate_info
self.engine.scheduler.start(role, host_ip, disaggregate)
self.engine.scheduler.start(role, host_ip, self.cfg.register_info)
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.splitwise_receive_thread = threading.Thread(

View File

@@ -188,26 +188,15 @@ class Router:
is_same_tp_size = prefill_server.tp_size == decode_server.tp_size
use_ipc = is_same_node and is_support_ipc and is_same_tp_size
cache_info = {}
if use_ipc:
cache_info["ipc"] = {
"ip": decode_server.host_ip,
"port": decode_server.engine_worker_queue_port,
"device_ids": decode_server.device_ids,
}
else:
cache_info["rdma"] = {
"ip": decode_server.host_ip,
"port": decode_server.connector_port,
"rdma_port": decode_server.rdma_ports,
}
disaggregate_info = {
"prefill": prefill_server.to_dict(),
"decode": decode_server.to_dict(),
"role": "decode",
"cache_info": cache_info,
"prefill_ip": prefill_server.host_ip,
"decode_ip": decode_server.host_ip,
"prefill_connector_port": prefill_server.connector_port,
"decode_connector_port": decode_server.connector_port,
"decode_device_ids": decode_server.device_ids,
"decode_rdma_ports": decode_server.rdma_ports,
"transfer_protocol": "ipc" if use_ipc else "rdma",
"decode_tp_size": decode_server.tp_size,
}
modified_request = request_data.copy()

View File

@@ -14,7 +14,6 @@
# limitations under the License.
"""
import copy
import hashlib
import math
import pickle
@@ -533,16 +532,26 @@ class APIScheduler:
else:
dnodes.sort()
dnode = self.select_pd(req, dnodes, "decode")
disaggregated = copy.deepcopy(dnode.disaggregated)
transfer_protocol = disaggregated["transfer_protocol"]
if len(transfer_protocol) > 1 and "ipc" in transfer_protocol and "rdma" in transfer_protocol:
if pnode.host == dnode.host:
disaggregated["transfer_protocol"] = "ipc"
else:
disaggregated["transfer_protocol"] = "rdma"
else:
disaggregated["transfer_protocol"] = transfer_protocol[0]
req.disaggregate_info = disaggregated
is_same_node = pnode.disaggregated["host_ip"] == dnode.disaggregated["host_ip"]
is_support_ipc = (
"ipc" in pnode.disaggregated["transfer_protocol"] and "ipc" in dnode.disaggregated["transfer_protocol"]
)
is_same_tp_size = pnode.disaggregated["tp_size"] == dnode.disaggregated["tp_size"]
use_ipc = is_same_node and is_support_ipc and is_same_tp_size
disaggregate_info = {
"prefill_ip": pnode.disaggregated["host_ip"],
"decode_ip": dnode.disaggregated["host_ip"],
"prefill_connector_port": pnode.disaggregated["connector_port"],
"decode_connector_port": dnode.disaggregated["connector_port"],
"decode_device_ids": dnode.disaggregated["device_ids"],
"decode_rdma_ports": dnode.disaggregated["rdma_ports"],
"transfer_protocol": "ipc" if use_ipc else "rdma",
"decode_tp_size": dnode.disaggregated["tp_size"],
}
req.disaggregate_info = disaggregate_info
pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
req_dict = req.to_dict()
req_dict["group"] = group

View File

@@ -24,7 +24,6 @@ import zmq
from fastdeploy import envs
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
@@ -53,7 +52,6 @@ class SplitwiseConnector:
self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
self.engine_worker_queue = worker_queue
self.resource_manager = resource_manager
self.connect_innode_instances = {}
self.current_request_ids = dict()
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
@@ -172,119 +170,59 @@ class SplitwiseConnector:
def send_splitwise_tasks(self, tasks: List[Request], current_id):
"""
Send splitwise tasks to all connected addresses.
Prefill send splitwise tasks to decode.
Parameters:
tasks (list): List of tasks.
current_id (int): Current ID.
"""
addr = None
decode_diagg = None
for task in tasks:
if task.disaggregate_info is None:
continue
if task.disaggregate_info["transfer_protocol"] == "ipc":
addr = task.disaggregate_info["cache_info"]["ipc"]["port"]
task.disaggregate_info["cache_info"]["ipc"]["current_id"] = current_id
self.logger.info(f"send_splitwise_tasks: protocol=ipc, addr={addr}, task={task.request_id}")
self.send_splitwise_tasks_innode([task], addr)
else:
self.current_request_ids[task.request_id] = "init"
task.disaggregate_info["role"] = "decode"
addr = f"{task.disaggregate_info['decode_ip']}:{task.disaggregate_info['decode_connector_port']}"
self.logger.info(f"send_splitwise_tasks: protocol=rdma, addr={addr}, task={task.request_id}")
self._send_message(addr, "prefill", [task])
addr = (
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}"
)
self.current_request_ids[task.request_id] = "init"
decode_diagg = task.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
task.disaggregate_info["role"] = "decode"
self.logger.info(f"send_splitwise_tasks: protocol=rdma, addr={addr}, task={task.request_id}")
self._send_message(addr, "prefill", [task])
task.disaggregate_info["cache_info"] = decode_diagg
task.disaggregate_info["role"] = "prefill"
def send_splitwise_tasks_innode(self, tasks, port):
"""
Send splitwise tasks to specific port.
Parameters:
tasks (list): List of tasks.
port (int): Port number.
Returns:
int: Current port number, -1 if tasks are not sent.
"""
current_port = -1
if port not in self.connect_innode_instances:
self.create_connection(port)
for task in tasks:
task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.parallel_config.engine_worker_queue_port[
self.local_data_parallel_id
]
self.logger.info(f"send_splitwise_tasks_innode: port={port}, tasks={[task.request_id for task in tasks]}")
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
for task in tasks:
task.disaggregate_info["cache_info"]["ipc"]["port"] = port
current_port = port
return current_port
def send_first_token(self, prefill_msg, tasks_list):
"""
send first token to specific port
Prefill send first token to specific port
"""
if not isinstance(tasks_list, list):
tasks_list = [tasks_list]
self.logger.info(f"send_first_token: send first token to decode, {[x.request_id for x in tasks_list]}")
if prefill_msg["transfer_protocol"] == "ipc":
port = prefill_msg["cache_info"]["ipc"]["port"]
if port not in self.connect_innode_instances:
self.create_connection(port)
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
else:
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
self.logger.info(f"send_first_token: send first token to port {node} decode")
self._send_message(node, "decode", tasks_list)
def create_connection(self, port):
"""
Create a connection to specific port.
Parameters:
port (int): Port number.
"""
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
address = ("0.0.0.0", int(port))
else:
address = f"/dev/shm/fd_task_queue_{port}.sock"
self.connect_innode_instances[port] = EngineWorkerQueue(
address=address,
num_client=self.cfg.parallel_config.tensor_parallel_size,
client_id=0,
addr = f"{prefill_msg['decode_ip']}:{prefill_msg['decode_connector_port']}"
self.logger.info(
f"send_first_token: send first token to decode ({addr}), {[x.request_id for x in tasks_list]}"
)
self._send_message(addr, "decode", tasks_list)
def check_decode_allocated(self, task):
self.logger.debug(f"start check decode allocated: {task.request_id}")
"""Check whether the requests have been allocated resources in decode."""
self.logger.debug(f"check_decode_allocated: {task.request_id}")
start_time = time.time()
if task.disaggregate_info is None:
return True, ""
if self.enable_decode_cache_task:
return True, ""
if task.disaggregate_info["role"] != "prefill":
return True, ""
while self.current_request_ids[task.request_id] == "init":
time.sleep(0.001)
if time.time() - start_time > envs.FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS:
del self.current_request_ids[task.request_id]
return False, "timeout"
msg = self.current_request_ids[task.request_id]
del self.current_request_ids[task.request_id]
if msg == "finished":
return True, ""
self.logger.error(f"check_decode_allocated: Receive_decode_allocated error: {msg}")
return False, msg
else:
self.logger.error(f"check_decode_allocated: Receive_decode_allocated error: {msg}")
return False, msg
def send_cache_info_to_messager(self, tasks: List[Request], current_id):
"""
@@ -308,13 +246,12 @@ class SplitwiseConnector:
"need_prefill_tokens": tasks[i].need_prefill_tokens,
}
else:
if current_id == -1:
current_id = dsg_info["cache_info"]["ipc"]["current_id"]
info = {
"request_id": tasks[i].request_id,
"src_block_ids": tasks[i].block_tables,
"current_id": current_id,
}
info.update(dsg_info)
cache_info.append(info)
self.logger.debug(f"send_cache_info_to_messager, {cache_info}")
@@ -333,56 +270,29 @@ class SplitwiseConnector:
if dsg_info is None:
self.logger.debug(f"skip send_cache_infos_to_prefill, {tasks[i].request_id}")
continue
self.logger.debug(f"send_cache_infos_to_prefill, {dsg_info}")
if dsg_info["transfer_protocol"] == "ipc":
if tasks[i].get("error_msg", None) is not None:
info = {
"request_id": tasks[i].request_id,
"error_msg": tasks[i].get("error_msg"),
}
else:
addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}"
info = {
"request_id": tasks[i].request_id,
"device_ids": self.cfg.parallel_config.device_ids.split(","),
"transfer_protocol": "ipc",
"dest_block_ids": dsg_info["block_tables"],
}
if dsg_info["cache_info"]["ipc"]["port"] not in cache_info:
cache_info[dsg_info["cache_info"]["ipc"]["port"]] = []
cache_info[dsg_info["cache_info"]["ipc"]["port"]].append(info)
else:
if tasks[i].get("error_msg", None) is not None:
info = {
"request_id": tasks[i].request_id,
"error_msg": tasks[i].get("error_msg"),
}
else:
info = {
"request_id": tasks[i].request_id,
"device_ids": [self.cfg.parallel_config.device_ids.split(",")[self.local_data_parallel_id]],
"ip": self.cfg.host_ip,
"rdma_ports": [
self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"][self.local_data_parallel_id]
],
"transfer_protocol": "rdma",
"dest_block_ids": dsg_info["block_tables"],
"decode_tp_size": self.cfg.parallel_config.tensor_parallel_size,
}
addr = f"{dsg_info['cache_info']['rdma']['ip']}:" + f"{dsg_info['cache_info']['rdma']['port']}"
if addr not in cache_info:
cache_info[addr] = []
cache_info[addr].append(info)
self.logger.debug(f"send cache info to prefill, {cache_info}")
if len(cache_info):
for k, v in cache_info.items():
self.logger.info(f"{k} {v}")
if ":" in str(k):
self._send_message(k, "cache_sync", v)
else:
if k not in self.connect_innode_instances:
self.create_connection(k)
self.connect_innode_instances[k].put_cache_info(v)
for key, info in cache_info.items():
self._send_message(key, "cache_sync", info)
def _serialize_message(self, msg_type: str, payload) -> bytes:
# TODO 压缩
if msg_type == "decode" or msg_type == "prefill":
payload = [output.to_dict() for output in payload]

View File

@@ -130,7 +130,7 @@ class TestConfig(unittest.TestCase):
test_mode=True,
)
fd_config.init_cache_info()
assert fd_config.disaggregate_info["role"] == "prefill"
assert fd_config.register_info is not None
if __name__ == "__main__":