[Feature] [PD] add simple router and refine splitwise deployment (#4709)

* add simple router and refine splitwise deployment

* fix
This commit is contained in:
Juncai
2025-11-06 14:56:02 +08:00
committed by GitHub
parent 831266da7a
commit 08ca0f6aea
39 changed files with 2397 additions and 171 deletions

View File

@@ -1310,6 +1310,24 @@ class CacheConfig:
logger.info("=============================================================")
class RouterConfig:
"""
Configuration for router
Attributes:
router: the url of router, such as http://127.0.0.1:8000
api_server_host: the host ip of model server
api_server_port: the http port of model server
"""
def __init__(self, args: dict):
self.router = args["router"]
if self.router is not None and not self.router.startswith(("http://", "https://")):
self.router = f"http://{self.router}"
self.api_server_host = get_host_ip()
self.api_server_port = args["port"]
class CommitConfig:
"""
Configuration for tracking version information from version.txt
@@ -1411,6 +1429,7 @@ class FDConfig:
speculative_config: SpeculativeConfig = None,
eplb_config: EPLBConfig = None,
structured_outputs_config: StructuredOutputsConfig = None,
router_config: RouterConfig = None,
tokenizer: str = None,
ips: str = None,
use_warmup: bool = False,
@@ -1438,6 +1457,7 @@ class FDConfig:
self.cache_config: CacheConfig = cache_config # type: ignore
self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
self.router_config: RouterConfig = router_config
# Initialize cuda graph capture list
max_capture_shape = self.scheduler_config.max_num_seqs
@@ -1517,6 +1537,7 @@ class FDConfig:
self.read_from_config()
self.postprocess()
self.init_cache_info()
if test_mode:
return
self.check()
@@ -1734,29 +1755,66 @@ class FDConfig:
"""
initialize cache info
"""
disaggregate_info = {}
# TODO: group the splitiwse params, remove code of v0
# v0 requires prefill and decode in one node and it uses local scheduler
# v1 supports prefill and decode in multi node and it uses splitwise or dp scheduler
# v2 supports prefill and decode in multi node and it uses router and local scheduler
self.splitwise_version = None
if self.scheduler_config.name == "local" and (self.router_config is None or self.router_config.router is None):
self.splitwise_version = "v0"
elif self.scheduler_config.name in ("splitwise", "dp"):
self.splitwise_version = "v1"
elif self.scheduler_config.name == "local" and self.router_config and self.router_config.router:
self.splitwise_version = "v2"
else:
raise ValueError(
f"Unsupported scheduler mode, scheduler_name: {self.scheduler_config.name}, "
f"router_config: {self.router_config}"
)
logger.info(f"splitwise_version: {self.splitwise_version}")
if isinstance(self.parallel_config.engine_worker_queue_port, (int, str)):
engine_worker_queue_port = self.parallel_config.engine_worker_queue_port
else:
engine_worker_queue_port = self.parallel_config.engine_worker_queue_port[
self.parallel_config.local_data_parallel_id
]
connector_port = self.cache_config.pd_comm_port[0] if self.cache_config.pd_comm_port else None
self.disaggregate_info = {}
if self.scheduler_config.splitwise_role != "mixed":
disaggregate_info["role"] = self.scheduler_config.splitwise_role
disaggregate_info["cache_info"] = dict()
self.disaggregate_info["role"] = self.scheduler_config.splitwise_role
self.disaggregate_info["cache_info"] = dict()
current_protocol = self.cache_config.cache_transfer_protocol.split(",")
disaggregate_info["transfer_protocol"] = current_protocol
self.disaggregate_info["transfer_protocol"] = current_protocol
for protocol in current_protocol:
if protocol == "ipc":
disaggregate_info["cache_info"][protocol] = {
self.disaggregate_info["cache_info"][protocol] = {
"ip": self.host_ip,
"port": self.parallel_config.engine_worker_queue_port[
self.parallel_config.local_data_parallel_id
],
"port": engine_worker_queue_port,
"device_ids": self.local_device_ids,
}
elif protocol == "rdma":
disaggregate_info["cache_info"][protocol] = {
self.disaggregate_info["cache_info"][protocol] = {
"ip": self.host_ip,
"port": self.cache_config.pd_comm_port[0],
"port": connector_port,
"rdma_port": self.cache_config.rdma_comm_ports,
}
self.disaggregate_info = disaggregate_info
logger.info(f"disaggregate_info: {self.disaggregate_info}")
logger.info(f"disaggregate_info: {self.disaggregate_info}")
if self.router_config:
self.register_info = {
"role": self.scheduler_config.splitwise_role,
"host_ip": self.host_ip,
"port": self.router_config.api_server_port,
"connector_port": connector_port,
"rdma_ports": self.cache_config.rdma_comm_ports,
"engine_worker_queue_port": engine_worker_queue_port,
"device_ids": self.local_device_ids,
"transfer_protocol": self.cache_config.cache_transfer_protocol.split(","),
}
logger.info(f"register_info: {self.register_info}")
def read_from_config(self):
"""