mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] [PD] add simple router and refine splitwise deployment (#4709)
* add simple router and refine splitwise deployment * fix
This commit is contained in:
@@ -1310,6 +1310,24 @@ class CacheConfig:
|
||||
logger.info("=============================================================")
|
||||
|
||||
|
||||
class RouterConfig:
|
||||
"""
|
||||
Configuration for router
|
||||
Attributes:
|
||||
router: the url of router, such as http://127.0.0.1:8000
|
||||
api_server_host: the host ip of model server
|
||||
api_server_port: the http port of model server
|
||||
"""
|
||||
|
||||
def __init__(self, args: dict):
|
||||
self.router = args["router"]
|
||||
if self.router is not None and not self.router.startswith(("http://", "https://")):
|
||||
self.router = f"http://{self.router}"
|
||||
|
||||
self.api_server_host = get_host_ip()
|
||||
self.api_server_port = args["port"]
|
||||
|
||||
|
||||
class CommitConfig:
|
||||
"""
|
||||
Configuration for tracking version information from version.txt
|
||||
@@ -1411,6 +1429,7 @@ class FDConfig:
|
||||
speculative_config: SpeculativeConfig = None,
|
||||
eplb_config: EPLBConfig = None,
|
||||
structured_outputs_config: StructuredOutputsConfig = None,
|
||||
router_config: RouterConfig = None,
|
||||
tokenizer: str = None,
|
||||
ips: str = None,
|
||||
use_warmup: bool = False,
|
||||
@@ -1438,6 +1457,7 @@ class FDConfig:
|
||||
self.cache_config: CacheConfig = cache_config # type: ignore
|
||||
self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
|
||||
self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
|
||||
self.router_config: RouterConfig = router_config
|
||||
|
||||
# Initialize cuda graph capture list
|
||||
max_capture_shape = self.scheduler_config.max_num_seqs
|
||||
@@ -1517,6 +1537,7 @@ class FDConfig:
|
||||
|
||||
self.read_from_config()
|
||||
self.postprocess()
|
||||
self.init_cache_info()
|
||||
if test_mode:
|
||||
return
|
||||
self.check()
|
||||
@@ -1734,29 +1755,66 @@ class FDConfig:
|
||||
"""
|
||||
initialize cache info
|
||||
"""
|
||||
disaggregate_info = {}
|
||||
# TODO: group the splitiwse params, remove code of v0
|
||||
# v0 requires prefill and decode in one node and it uses local scheduler
|
||||
# v1 supports prefill and decode in multi node and it uses splitwise or dp scheduler
|
||||
# v2 supports prefill and decode in multi node and it uses router and local scheduler
|
||||
self.splitwise_version = None
|
||||
if self.scheduler_config.name == "local" and (self.router_config is None or self.router_config.router is None):
|
||||
self.splitwise_version = "v0"
|
||||
elif self.scheduler_config.name in ("splitwise", "dp"):
|
||||
self.splitwise_version = "v1"
|
||||
elif self.scheduler_config.name == "local" and self.router_config and self.router_config.router:
|
||||
self.splitwise_version = "v2"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported scheduler mode, scheduler_name: {self.scheduler_config.name}, "
|
||||
f"router_config: {self.router_config}"
|
||||
)
|
||||
logger.info(f"splitwise_version: {self.splitwise_version}")
|
||||
|
||||
if isinstance(self.parallel_config.engine_worker_queue_port, (int, str)):
|
||||
engine_worker_queue_port = self.parallel_config.engine_worker_queue_port
|
||||
else:
|
||||
engine_worker_queue_port = self.parallel_config.engine_worker_queue_port[
|
||||
self.parallel_config.local_data_parallel_id
|
||||
]
|
||||
connector_port = self.cache_config.pd_comm_port[0] if self.cache_config.pd_comm_port else None
|
||||
|
||||
self.disaggregate_info = {}
|
||||
if self.scheduler_config.splitwise_role != "mixed":
|
||||
disaggregate_info["role"] = self.scheduler_config.splitwise_role
|
||||
disaggregate_info["cache_info"] = dict()
|
||||
self.disaggregate_info["role"] = self.scheduler_config.splitwise_role
|
||||
self.disaggregate_info["cache_info"] = dict()
|
||||
current_protocol = self.cache_config.cache_transfer_protocol.split(",")
|
||||
disaggregate_info["transfer_protocol"] = current_protocol
|
||||
self.disaggregate_info["transfer_protocol"] = current_protocol
|
||||
|
||||
for protocol in current_protocol:
|
||||
if protocol == "ipc":
|
||||
disaggregate_info["cache_info"][protocol] = {
|
||||
self.disaggregate_info["cache_info"][protocol] = {
|
||||
"ip": self.host_ip,
|
||||
"port": self.parallel_config.engine_worker_queue_port[
|
||||
self.parallel_config.local_data_parallel_id
|
||||
],
|
||||
"port": engine_worker_queue_port,
|
||||
"device_ids": self.local_device_ids,
|
||||
}
|
||||
elif protocol == "rdma":
|
||||
disaggregate_info["cache_info"][protocol] = {
|
||||
self.disaggregate_info["cache_info"][protocol] = {
|
||||
"ip": self.host_ip,
|
||||
"port": self.cache_config.pd_comm_port[0],
|
||||
"port": connector_port,
|
||||
"rdma_port": self.cache_config.rdma_comm_ports,
|
||||
}
|
||||
self.disaggregate_info = disaggregate_info
|
||||
logger.info(f"disaggregate_info: {self.disaggregate_info}")
|
||||
logger.info(f"disaggregate_info: {self.disaggregate_info}")
|
||||
|
||||
if self.router_config:
|
||||
self.register_info = {
|
||||
"role": self.scheduler_config.splitwise_role,
|
||||
"host_ip": self.host_ip,
|
||||
"port": self.router_config.api_server_port,
|
||||
"connector_port": connector_port,
|
||||
"rdma_ports": self.cache_config.rdma_comm_ports,
|
||||
"engine_worker_queue_port": engine_worker_queue_port,
|
||||
"device_ids": self.local_device_ids,
|
||||
"transfer_protocol": self.cache_config.cache_transfer_protocol.split(","),
|
||||
}
|
||||
logger.info(f"register_info: {self.register_info}")
|
||||
|
||||
def read_from_config(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user