"""
Async Router server for FastDeploy.
Handles client requests and manages prefill/decode/mixed instances.
This module references the router implementation of sglang and vllm.
"""

import asyncio
import random
from dataclasses import dataclass
from itertools import chain
from typing import Optional
from uuid import uuid4

import aiohttp
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse

from fastdeploy.router.utils import (
    InstanceInfo,
    InstanceRole,
    check_service_health_async,
)
from fastdeploy.utils import FlexibleArgumentParser
from fastdeploy.utils import router_logger as logger

app = FastAPI()


@dataclass
class RouterArgs:
    host: str = "0.0.0.0"
    """
    Host address to bind the router server
    """
    port: int = 9000
    """
    Port to bind the router server.
    """
    splitwise: bool = False
    """
    Router uses splitwise deployment
    """
    request_timeout_secs: int = 1800
    """
    Request timeout in seconds
    """

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Register the router's command-line options on ``parser`` and return it."""
        parser.add_argument(
            "--host",
            type=str,
            default=RouterArgs.host,
            help="Host address to bind the router server.",
        )
        parser.add_argument(
            "--port",
            type=int,
            default=RouterArgs.port,
            help="Port number to bind the router server",
        )
        parser.add_argument(
            "--splitwise",
            action="store_true",
            default=RouterArgs.splitwise,
            help="Router uses splitwise deployment",
        )
        parser.add_argument(
            "--request-timeout-secs",
            type=int,
            default=RouterArgs.request_timeout_secs,
            help="Request timeout in seconds",
        )
        return parser


class Router:
    """
    Router class that handles requests from client and collects prefill/decode instance information
    """

    def __init__(self, args):
        self.args = args
        self.host = args.host
        self.port = args.port
        self.splitwise = args.splitwise
        self.timeout = args.request_timeout_secs

        # Registered instances, grouped by role. Mutations happen under ``self.lock``.
        self.mixed_servers = []
        self.prefill_servers = []
        self.decode_servers = []
        self.lock = asyncio.Lock()  # async-safe lock
        logger.info("Router started at http://{}:{}".format(self.host, self.port))

    async def register_instance(self, instance_info_dict: dict):
        """Validate, health-check and register an instance.

        Raises:
            ValueError: if the instance role does not match the deployment mode
                (mixed instances only without splitwise, non-mixed only with it).
            RuntimeError: if the instance fails ``check_service_health_async``.
            Exception: whatever ``InstanceInfo.from_dict`` raises on a bad payload.
        Re-registering an instance that is already present is a silent no-op.
        """
        try:
            inst_info = InstanceInfo.from_dict(instance_info_dict)
        except Exception as e:
            logger.error(f"register instance failed: {e}")
            raise

        if (self.splitwise and inst_info.role == InstanceRole.MIXED) or (
            not self.splitwise and inst_info.role != InstanceRole.MIXED
        ):
            raise ValueError(f"Invalid instance role: {inst_info.role}, splitwise: {self.splitwise}")

        if not await check_service_health_async(inst_info.url()):
            raise RuntimeError(f"Instance {inst_info} is not healthy")

        async with self.lock:
            if inst_info.role == InstanceRole.MIXED and inst_info not in self.mixed_servers:
                self.mixed_servers.append(inst_info)
                logger.info(
                    f"Register mixed instance success: {inst_info}, " f"total mixed: {len(self.mixed_servers)}"
                )
            elif inst_info.role == InstanceRole.PREFILL and inst_info not in self.prefill_servers:
                self.prefill_servers.append(inst_info)
                logger.info(
                    f"Register prefill instance success: {inst_info}, "
                    f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
                )
            elif inst_info.role == InstanceRole.DECODE and inst_info not in self.decode_servers:
                self.decode_servers.append(inst_info)
                logger.info(
                    f"Register decode instance success: {inst_info}, "
                    f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
                )

    async def registered_number(self):
        """Return the number of registered instances per role."""
        return {
            "mixed": len(self.mixed_servers),
            "prefill": len(self.prefill_servers),
            "decode": len(self.decode_servers),
        }

    async def select_pd(self):
        """Pick one prefill and one decode server uniformly at random.

        Raises:
            RuntimeError: if either pool is empty.
        """
        async with self.lock:
            if not self.prefill_servers:
                raise RuntimeError("No prefill servers available")
            if not self.decode_servers:
                raise RuntimeError("No decode servers available")
            return random.choice(self.prefill_servers), random.choice(self.decode_servers)

    async def select_mixed(self):
        """Pick one mixed server uniformly at random.

        Raises:
            RuntimeError: if no mixed server is registered.
        """
        async with self.lock:
            if not self.mixed_servers:
                raise RuntimeError("No mixed servers available")
            return random.choice(self.mixed_servers)

    async def handle_request(self, request_data: dict, endpoint_name: str):
        """Dispatch a client request according to the deployment mode."""
        if self.splitwise:
            return await self.handle_splitwise_request(request_data, endpoint_name)
        else:
            return await self.handle_mixed_request(request_data, endpoint_name)

    async def handle_mixed_request(self, request_data: dict, endpoint_name: str):
        """Forward the request unchanged to one randomly selected mixed server."""
        logger.debug(f"Received request: {request_data}")
        mixed_server = await self.select_mixed()

        if request_data.get("stream", False):
            return await self._generate_stream(request_data, [mixed_server.url()], endpoint=endpoint_name)
        else:
            return await self._generate(request_data, [mixed_server.url()], endpoint=endpoint_name)

    async def handle_splitwise_request(self, request_data: dict, endpoint_name: str):
        """Send the request to a prefill/decode pair with ``disaggregate_info`` attached.

        The same (augmented) request is posted to both servers; the decode
        server's response is what the client receives.

        Raises:
            HTTPException(400): if the pair's tp sizes are incompatible.
        """
        logger.debug(f"Received request: {request_data}")
        prefill_server, decode_server = await self.select_pd()
        logger.debug(f"Selected prefill server: {prefill_server}")
        logger.debug(f"Selected decode server: {decode_server}")

        if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1:
            raise HTTPException(
                status_code=400,
                detail="The tp_size of prefill and decode should be equal or the tp_size of decode is 1",
            )

        # TODO: unify the disaggregate_info in server and remove redundancy params
        # IPC transfer is only possible on one host, when both sides advertise it,
        # and with matching tp sizes; otherwise fall back to RDMA.
        is_same_node = prefill_server.host_ip == decode_server.host_ip
        is_support_ipc = "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol
        is_same_tp_size = prefill_server.tp_size == decode_server.tp_size
        use_ipc = is_same_node and is_support_ipc and is_same_tp_size

        disaggregate_info = {
            "prefill_ip": prefill_server.host_ip,
            "decode_ip": decode_server.host_ip,
            "prefill_connector_port": prefill_server.connector_port,
            "decode_connector_port": decode_server.connector_port,
            "decode_device_ids": decode_server.device_ids,
            "decode_rdma_ports": decode_server.rdma_ports,
            "transfer_protocol": "ipc" if use_ipc else "rdma",
            "decode_tp_size": decode_server.tp_size,
        }

        # Shallow copy so the caller's dict is not mutated.
        modified_request = request_data.copy()
        modified_request["disaggregate_info"] = disaggregate_info
        if "request_id" not in modified_request:
            # Both servers must see the same id to pair the prefill and decode halves.
            modified_request["request_id"] = str(uuid4())

        logger.debug(f"Modified request: {modified_request}")

        if request_data.get("stream", False):
            return await self._generate_stream(
                modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name
            )
        else:
            return await self._generate(
                modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name
            )

    async def _generate(
        self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions"
    ) -> ORJSONResponse:
        """POST ``modified_request`` to every url concurrently and relay one JSON response.

        ``return_result_url_index`` selects which upstream response (body and
        status) is returned to the client; by default the last one.
        """
        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
            tasks = [session.post(f"{url}/{endpoint}", json=modified_request) for url in urls]
            results = await asyncio.gather(*tasks)
            chosen = results[return_result_url_index]
            ret_json = await chosen.json()
            return ORJSONResponse(content=ret_json, status_code=chosen.status)

    async def _generate_stream(
        self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions"
    ):
        """POST ``modified_request`` to every url and stream one upstream body as SSE.

        NOTE(review): the client always sees HTTP 200 here; the upstream status
        code is not propagated.
        """
        # prevent aiohttp's "Chunk too big" error
        AIOHTTP_STREAM_READ_CHUNK_SIZE = 1024 * 64

        async def stream_results():
            # The session lives inside the generator so it stays open while streaming.
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
                tasks = [session.post(f"{url}/{endpoint}", json=modified_request) for url in urls]
                results = await asyncio.gather(*tasks)

                async for chunk in results[return_result_url_index].content.iter_chunked(
                    AIOHTTP_STREAM_READ_CHUNK_SIZE
                ):
                    logger.debug(f"receive response chunk: {chunk}")
                    yield chunk

        return StreamingResponse(stream_results(), media_type="text/event-stream")

    async def _check_instance_health(self, session: aiohttp.ClientSession, inst: InstanceInfo) -> bool:
        """Return True if ``inst`` answers ``GET /health`` with HTTP 200, else log and return False."""
        url = inst.url()
        try:
            async with session.get(f"{url}/health") as resp:
                if resp.status != 200:
                    logger.warning(f"Instance {url} unhealthy: {resp.status}")
                    return False
                return True
        except Exception as e:
            logger.warning(f"Instance {url} check failed: {e}")
            return False

    async def _find_unhealthy(self, session: aiohttp.ClientSession, instances: list) -> list:
        """Health-check ``instances`` concurrently and return those that failed."""
        if not instances:
            return []
        results = await asyncio.gather(*(self._check_instance_health(session, inst) for inst in instances))
        return [inst for inst, healthy in zip(instances, results) if not healthy]

    async def monitor_instance_health(self, interval_secs: float = 5.0, check_timeout_secs: Optional[float] = None):
        """
        Continuously check the health of prefill, decode, and mixed instances and remove unhealthy ones.

        Args:
            interval_secs: pause between two rounds of checks.
            check_timeout_secs: total timeout for the health-check requests of one
                round; ``None`` keeps aiohttp's default session timeout.
        """
        session_kwargs = {}
        if check_timeout_secs is not None:
            session_kwargs["timeout"] = aiohttp.ClientTimeout(total=check_timeout_secs)

        while True:
            try:
                # References to the live pools; the checks below run on snapshots so
                # concurrent registrations do not disturb an in-flight round.
                pools = (
                    ("prefill", self.prefill_servers),
                    ("decode", self.decode_servers),
                    ("mixed", self.mixed_servers),
                )
                async with aiohttp.ClientSession(**session_kwargs) as session:
                    # All instances are probed concurrently, so one hung instance
                    # does not delay the detection of the others.
                    unhealthy_per_pool = await asyncio.gather(
                        *(self._find_unhealthy(session, list(servers)) for _, servers in pools)
                    )

                # remove unhealthy instances under lock
                async with self.lock:
                    for (role, servers), to_remove in zip(pools, unhealthy_per_pool):
                        for inst in to_remove:
                            if inst in servers:
                                servers.remove(inst)
                                logger.info(f"Removed unhealthy {role} instance: {inst.url()}")

                prefill_instances = [inst.url() for inst in self.prefill_servers]
                decode_instances = [inst.url() for inst in self.decode_servers]
                mixed_instance = [inst.url() for inst in self.mixed_servers]
                logger.debug(
                    f"Healthy prefill instances: {prefill_instances}, "
                    f"Healthy decode instances: {decode_instances}, "
                    f"Healthy mixed instance: {mixed_instance}"
                )
            except Exception as e:
                logger.exception(f"Failed to monitor instance health: {e}")

            # Sleep outside the try block: an error above must not turn this loop
            # into a busy spin that floods the log.
            await asyncio.sleep(interval_secs)


@app.post("/register")
async def register(instance_info_dict: dict):
    """Register prefill/decode/mixed servers"""
    try:
        await app.state.router.register_instance(instance_info_dict)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"status": "success"}


@app.get("/registered_number")
async def registered_number():
    """Get the number of registered prefill/decode/mixed servers"""
    return await app.state.router.registered_number()


@app.post("/v1/chat/completions")
async def create_chat_completion(request_data: dict):
    """Route an OpenAI-compatible chat completion request."""
    return await app.state.router.handle_request(request_data, "v1/chat/completions")


@app.post("/v1/completions")
async def create_completion(request_data: dict):
    """Route an OpenAI-compatible completion request."""
    return await app.state.router.handle_request(request_data, "v1/completions")


@app.get("/health")
async def health_check():
    """Basic health check"""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate():
    """Probe ``/health`` on every registered prefill, decode and mixed server.

    Non-200 answers are logged as warnings; the endpoint still returns 200 once
    every server has answered. A connection error to a server is not caught and
    therefore surfaces as a server error from this endpoint.
    """
    router = app.state.router
    servers = list(chain(router.prefill_servers, router.decode_servers, router.mixed_servers))
    async with aiohttp.ClientSession() as session:
        tasks = [session.get(f"{s.url()}/health") for s in servers]
        for coro in asyncio.as_completed(tasks):
            resp = await coro
            if resp.status != 200:
                logger.warning(f"Server {resp.url} not healthy: {resp.status}")

    return Response(status_code=200)


def launch_router(router_args: RouterArgs):
    """Create the router on application startup and serve it with uvicorn (blocking)."""
    app.state.router_args = router_args
    print(f"Starting router with args: {router_args}")

    @app.on_event("startup")
    async def startup_event():
        app.state.router = Router(app.state.router_args)
        # Keep a strong reference to the background task: the event loop only holds
        # weak references, so an unreferenced task may be garbage-collected mid-run.
        app.state.health_monitor_task = asyncio.create_task(
            app.state.router.monitor_instance_health(interval_secs=5)
        )

    uvicorn.run(app, host=router_args.host, port=int(router_args.port))