[Feature] [PD] add simple router and refine splitwise deployment (#4709)

* add simple router and refine splitwise deployment

* fix
This commit is contained in:
Juncai
2025-11-06 14:56:02 +08:00
committed by GitHub
parent 831266da7a
commit 08ca0f6aea
39 changed files with 2397 additions and 171 deletions

View File

@@ -23,10 +23,11 @@ import time
import traceback
import weakref
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import paddle
import requests
import zmq
from opentelemetry import trace
@@ -45,6 +46,7 @@ from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.metrics.trace_util import start_span, start_span_request
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.plugins.token_processor import load_token_processor_plugins
from fastdeploy.router.utils import check_service_health
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import (
@@ -95,6 +97,7 @@ class EngineService:
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.llm_logger.info("Use V1 KVCache Scheduler")
self.resource_manager = ResourceManagerV1(
cfg.scheduler_config.max_num_seqs,
cfg,
@@ -103,6 +106,7 @@ class EngineService:
cfg.parallel_config.local_data_parallel_id,
)
else:
self.llm_logger.info("Use V0 KVCache Scheduler")
self.resource_manager = ResourceManager(
cfg.scheduler_config.max_num_seqs,
cfg,
@@ -118,7 +122,6 @@ class EngineService:
]
self.split_connector = SplitwiseConnector(cfg, self.engine_worker_queue, self.resource_manager)
self.waiting_requests = []
self.token_processor = TokenProcessor(
cfg=cfg,
cached_generated_tokens=self.scheduler,
@@ -149,14 +152,18 @@ class EngineService:
def start(self):
self.running = True
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True)
self.insert_task_to_worker_thread = threading.Thread(
target=self._schedule_request_to_worker_v1, daemon=True
)
else:
self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
self.insert_task_to_worker_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True)
self.insert_task_to_worker_thread.start()
self.token_processor.tasks_queue = self.engine_worker_queue
self.token_processor.run()
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.split_mode_get_tasks()
self._process_splitwise_task()
self._register_to_router()
def create_data_processor(self):
self.input_processor = InputPreprocessor(
@@ -313,7 +320,7 @@ class EngineService:
local_data_parallel_id=self.cfg.parallel_config.local_data_parallel_id,
)
def insert_tasks(self, tasks, current_id=-1, allocated=False):
def insert_tasks(self, tasks: Union[List[Request], List[RequestOutput]], current_id=-1, allocated=False):
"""
Insert tasks to engine.
"""
@@ -358,6 +365,7 @@ class EngineService:
current_tasks.append(cur_task)
if current_tasks:
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
self.llm_logger.debug(f"put task to engine worker queue, task:{current_tasks}")
return True
self.resource_manager.check_and_free_block_tables()
@@ -574,7 +582,7 @@ class EngineService:
patch_st += chunk_patch_num
request.set("prefill_chunk_info", chunks_info)
def _insert_task_to_worker(self):
def _schedule_request_to_worker(self):
"""
Insert task to engine thread, monitor scheduler request queue.
if the engine has resource, insert task to engine
@@ -619,9 +627,12 @@ class EngineService:
if len(tasks) == 0:
time.sleep(0.001)
continue
if self.cfg.splitwise_version == "v2" and self.cfg.scheduler_config.splitwise_role == "decode":
# the task in the decode instance will be processed in the _process_splitwise_task thread
continue
llm_logger.debug(f"get tasks from scheduler: {tasks}")
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(tasks, current_id)
insert_successful = self.insert_tasks(tasks, current_id)
@@ -636,7 +647,7 @@ class EngineService:
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
self.llm_logger.error(err_msg)
def _scheduler_task_to_worker_v1(self):
def _schedule_request_to_worker_v1(self):
"""
Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1).
"""
@@ -664,6 +675,7 @@ class EngineService:
max_num_batched_tokens=max_num_batched_tokens,
batch=num_prefill_batch,
)
self.llm_logger.debug(f"get tasks from scheduler: {tasks}")
if self.cfg.scheduler_config.splitwise_role != "mixed":
need_delete_tasks = []
if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
@@ -822,6 +834,7 @@ class EngineService:
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if self.cfg.scheduler_config.splitwise_role == "decode":
return
while self.running:
try:
block = True if len(added_requests) == 0 else False
@@ -975,17 +988,38 @@ class EngineService:
except Exception as e:
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
def split_mode_get_tasks(self):
def _process_splitwise_task(self):
"""
Split mode get tasks
Processing tasks from engine worker queue in splitwise deployment.
For v0 version, prefill instance gets tasks from engine worker queue.
For v1 and v2 versions, the decode instance gets raw tasks from the engine worker queue to preallocate resources,
and gets prefilled tasks from the engine worker queue to generate tokens.
TODO: unify the communication between decode and prefill instances.
"""
def receiver_loop():
waiting_resource_requests = []
waiting_ready_tasks = []
# Waiting for the api_server and scheduler in decode to
# receive the request sent by the client
def _decode_process_prefilled_task_v0_scheduler(input_tasks):
ready_tasks = []
waiting_tasks = []
for task in input_tasks:
if not hasattr(self.scheduler, "has_request") or self.scheduler.has_request(task.request_id):
ready_tasks.append(task)
else:
waiting_tasks.append(task)
self.insert_tasks(ready_tasks, allocated=True)
if self.cfg.splitwise_version in ("v0", "v2"):
self.scheduler.put_results(ready_tasks)
return waiting_tasks
while self.running:
try:
processed_indices = []
for idx, task in enumerate(self.waiting_requests):
for idx, task in enumerate(waiting_resource_requests):
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if self.resource_manager.preallocate_resource_in_d(task):
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
@@ -1004,21 +1038,27 @@ class EngineService:
break
for idx in sorted(processed_indices, reverse=True):
self.waiting_requests.pop(idx)
waiting_resource_requests.pop(idx)
if not self.engine_worker_queue.disaggregate_queue_empty():
waiting_ready_tasks = _decode_process_prefilled_task_v0_scheduler(waiting_ready_tasks)
if self.engine_worker_queue.disaggregate_queue_empty():
time.sleep(0.001)
else:
items = self.engine_worker_queue.get_disaggregated_tasks()
for item in items:
role = item[0]
tasks = item[1]
# prefill instance gets tasks from engine worker queue
if role == "prefill":
for task in tasks:
task.max_tokens = task.min_tokens = 2
self.insert_tasks(tasks)
# decode instance gets tasks from engine worker queue
elif role == "decode":
if hasattr(tasks[0], "finished"):
if isinstance(tasks[0], RequestOutput):
self.llm_logger.debug(f"receive prefilled tasks, {tasks}")
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
@@ -1057,13 +1097,12 @@ class EngineService:
self.resource_manager.insert_task_for_decoding(task)
else:
self.insert_tasks(tasks, allocated=True)
if self.cfg.innode_prefill_ports is not None:
self.scheduler.put_results(tasks)
else:
if len(self.waiting_requests):
waiting_ready_tasks.extend(_decode_process_prefilled_task_v0_scheduler(tasks))
elif isinstance(tasks[0], Request):
self.llm_logger.debug(f"receive tasks to preallocate resource, {tasks}")
if len(waiting_resource_requests):
self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
self.waiting_requests.extend(tasks)
waiting_resource_requests.extend(tasks)
else:
new_waiting = []
for task in tasks:
@@ -1087,13 +1126,12 @@ class EngineService:
if not self.enable_decode_cache_task:
self.split_connector.send_cache_infos(new_waiting, -1)
else:
self.waiting_requests.extend(new_waiting)
waiting_resource_requests.extend(new_waiting)
self.llm_logger.info(
f"Added {len(new_waiting)} tasks to waiting queue"
)
else:
time.sleep(0.001)
else:
raise ValueError(f"Unsupported task type: {type(tasks[0])}")
except Exception as e:
self.llm_logger.error(f"Error in main loop: {e}")
@@ -1130,6 +1168,42 @@ class EngineService:
llm_logger.error(f"Clear data error: {e}")
return False
def _register_to_router(self):
    """If use router, register this server to router"""
    request_timeout = 5
    retry_interval_seconds = 10

    def _registration_loop():
        # Daemon loop: periodically (re-)announce this server to the router.
        # Runs forever; each iteration sleeps first so registration only
        # happens once the local api_server has had time to come up.
        while True:
            try:
                time.sleep(retry_interval_seconds)

                # Skip this round unless the local api_server answers its
                # health check.
                host = self.cfg.router_config.api_server_host
                port = self.cfg.router_config.api_server_port
                if not check_service_health(f"http://{host}:{port}"):
                    continue

                router_url = self.cfg.router_config.router
                response = requests.post(
                    f"{router_url}/register",
                    json=self.cfg.register_info,
                    timeout=request_timeout,
                )
                if not response.ok:
                    llm_logger.error(
                        f"Router registration failed: {response.status_code}, "
                        f"{response.text}, {self.cfg.register_info}"
                    )
            except requests.exceptions.RequestException as e:
                # Network-level failures are expected transients; log and retry.
                llm_logger.error(f"Register to router request error: {e}")
            except Exception as e:
                llm_logger.exception(f"Unexpected error during router registration: {e}")

    # Only spawn the background registration thread when a router is configured.
    if self.cfg.router_config.router is not None:
        threading.Thread(target=_registration_loop, daemon=True).start()
def _exit_sub_services(self):
"""
exit sub services