diff --git a/fastdeploy/entrypoints/openai/multi_api_server.py b/fastdeploy/entrypoints/openai/multi_api_server.py index 358f0f8f0..16a61de94 100644 --- a/fastdeploy/entrypoints/openai/multi_api_server.py +++ b/fastdeploy/entrypoints/openai/multi_api_server.py @@ -25,20 +25,30 @@ from fastdeploy.utils import get_logger, is_port_available logger = get_logger("multi_api_server", "multi_api_server.log") -def start_servers(server_count, server_args, ports, metrics_ports): +def start_servers(server_count, server_args, ports, metrics_ports, controller_ports): processes = [] logger.info(f"Starting servers on ports: {ports} with args: {server_args} and metrics ports: {metrics_ports}") for i in range(len(server_args)): if server_args[i] == "--engine-worker-queue-port": engine_worker_queue_port = server_args[i + 1].split(",") break - check_param(ports, server_count) - check_param(metrics_ports, server_count) - check_param(engine_worker_queue_port, server_count) + if not check_param(ports, server_count): + return + if not check_param(metrics_ports, server_count): + return + if not check_param(engine_worker_queue_port, server_count): + return + if controller_ports != "-1": + controller_ports = controller_ports.split(",") + if not check_param(controller_ports, server_count): + return + else: + controller_ports = [-1] * server_count # check_param(server_args, server_count) for i in range(server_count): port = int(ports[i]) metrics_port = int(metrics_ports[i]) + controller_port = int(controller_ports[i]) env = os.environ.copy() env["FD_LOG_DIR"] = f"log_{i}" @@ -51,6 +61,8 @@ def start_servers(server_count, server_args, ports, metrics_ports): str(port), "--metrics-port", str(metrics_port), + "--controller-port", + str(controller_port), "--local-data-parallel-id", str(i), ] @@ -69,7 +81,8 @@ def check_param(ports, num_servers): for port in ports: logger.info(f"check port {port}") if not is_port_available("0.0.0.0", int(port)): - raise ValueError(f"Port {port} is already in use.") + return False + return True def main(): @@ -77,6 +90,7 @@ def main(): parser.add_argument("--ports", default="8000,8002", type=str, help="ports to the http server") parser.add_argument("--num-servers", default=2, type=int, help="number of workers") parser.add_argument("--metrics-ports", default="8800,8802", type=str, help="ports for metrics server") + parser.add_argument("--controller-ports", default="-1", type=str, help="ports for controller server port") parser.add_argument("--args", nargs=argparse.REMAINDER, help="remaining arguments are passed to api_server.py") args = parser.parse_args() @@ -90,6 +104,7 @@ def main(): server_args=args.args, ports=args.ports.split(","), metrics_ports=args.metrics_ports.split(","), + controller_ports=args.controller_ports, ) try: