[Feature] optimize expert parallel (#3196)

* optimize

* Update expert_service.py

* Update worker_process.py

* optimize
This commit is contained in:
ltd0924
2025-08-05 17:34:24 +08:00
committed by GitHub
parent dcf9c2daff
commit b20ffe3697
7 changed files with 174 additions and 134 deletions

View File

@@ -22,6 +22,7 @@ import threading
import time
import traceback
import weakref
from collections import deque
import numpy as np
@@ -31,8 +32,7 @@ from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.output.token_processor import TokenProcessor
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger
class ExpertService:
"""
@@ -53,6 +53,10 @@ class ExpertService:
self.cfg = cfg
start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % cfg.worker_num_per_node
end_pos = start_pos + self.cfg.tensor_parallel_size
self.waiting_requests = []
self.disaggregate_queue = deque()
self.llm_logger = get_logger("expert_service", f"expert_service_{local_data_parallel_id}.log")
if cfg.splitwise_role != "mixed":
self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
@@ -66,7 +70,7 @@ class ExpertService:
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
address = (cfg.master_ip, cfg.engine_worker_queue_port)
address = (cfg.master_ip, cfg.engine_worker_queue_port + local_data_parallel_id)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
@@ -90,10 +94,7 @@ class ExpertService:
self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]
self.split_connector = SplitwiseConnector(
self.cfg,
self.scheduler,
self.engine_worker_queue,
self.resource_manager,
self.cfg, self.scheduler, self.engine_worker_queue, self.resource_manager, self.disaggregate_queue
)
self.token_processor = TokenProcessor(
@@ -127,7 +128,7 @@ class ExpertService:
# assert not self.is_started, "The engine is already started."
start_time = time.time()
llm_logger.info(f"start expert service {local_data_parallel_id}")
self.llm_logger.info(f"start expert service {local_data_parallel_id}")
if self.cfg.splitwise_role != "mixed":
self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
cache_config=self.cfg.cache_config,
@@ -177,9 +178,6 @@ class ExpertService:
if self.engine_worker_queue.num_tasks() > 0:
time.sleep(0.001)
continue
if len(self.split_connector.current_request_ids) > 0:
time.sleep(0.001)
continue
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
@@ -200,7 +198,7 @@ class ExpertService:
continue
if self.cfg.splitwise_role != "mixed":
llm_logger.info("Inserting splitwise tasks")
self.llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(tasks, current_id)
current_id = (current_id + 1) % 100003
@@ -210,63 +208,72 @@ class ExpertService:
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
except Exception as e:
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
llm_logger.error(err_msg)
err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
self.llm_logger.error(err_msg)
def split_mode_get_tasks(self):
"""
Split mode get tasks
"""
waiting_requests = []
def receiver_loop():
while True:
try:
if len(waiting_requests) > 0:
for task in waiting_requests:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
waiting_requests.remove(task)
else:
break
if not self.engine_worker_queue.disaggregate_queue_empty():
items = self.engine_worker_queue.get_disaggregated_tasks()
for item in items:
role = item[0]
tasks = item[1]
if role == "prefill":
llm_logger.info("get prefill tasks")
for task in tasks:
task.max_tokens = task.min_tokens = 2
self.insert_tasks(tasks)
elif role == "decode":
llm_logger.info(f"get decode tasks {tasks}")
if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
task.finished = False
# self.scheduler.put_results(tasks)
self.insert_tasks(tasks, allocated=True)
processed_indices = []
for idx, task in enumerate(self.waiting_requests):
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
processed_indices.append(idx)
else:
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
break
for idx in sorted(processed_indices, reverse=True):
self.waiting_requests.pop(idx)
if len(self.disaggregate_queue) > 0:
items = self.disaggregate_queue.pop()
role = items[0]
tasks = items[1]
if role == "prefill":
for task in tasks:
task.max_tokens = task.min_tokens = 2
self.insert_tasks(tasks)
elif role == "decode":
if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
task.finished = False
self.insert_tasks(tasks, allocated=True)
if self.cfg.innode_prefill_ports is not None:
self.scheduler.put_results(tasks)
else:
if len(self.waiting_requests):
self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
self.waiting_requests.extend(tasks)
else:
if len(waiting_requests):
for task in tasks:
waiting_requests.append(task)
else:
for task in tasks:
if not self.resource_manager.is_resource_sufficient(
task.prompt_token_ids_len
):
waiting_requests.append(task)
else:
self.insert_tasks([task])
new_waiting = []
for task in tasks:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
else:
new_waiting.append(task)
if new_waiting:
self.waiting_requests.extend(new_waiting)
self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
else:
time.sleep(0.001)
continue
except Exception as e:
llm_logger.error(f"get decode tasks error: {e}")
self.llm_logger.error(f"Error in main loop: {e} {str(traceback.format_exc())}")
time.sleep(0.1)
threading.Thread(target=receiver_loop, daemon=True).start()
@@ -287,11 +294,11 @@ class ExpertService:
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
llm_logger.warning(
self.llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
llm_logger.info(f"{cur_task_idx} {task.request_id}")
self.llm_logger.info(f"{cur_task_idx} {task.request_id}")
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task)
@@ -308,8 +315,10 @@ class ExpertService:
available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch:
llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
llm_logger.error("The exceeded part will be ignored!")
self.llm_logger.error(
"Inserting batch:{} exceeds the available batch:{}.".format(len(tasks), available_batch)
)
self.llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch]
req_ids = [t.request_id for t in tasks]
@@ -318,7 +327,7 @@ class ExpertService:
if not tasks:
error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
llm_logger.error(error_msg)
self.llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=500)
return False
@@ -338,7 +347,7 @@ class ExpertService:
for task in tasks:
task.infer_start_time = time.time()
if not is_decode:
llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
if not is_prefill and self.cfg.cache_config.enable_chunked_prefill:
if not self.cfg.enable_mm:
self.update_requests_chunk_size(tasks)
@@ -356,7 +365,7 @@ class ExpertService:
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")
self.llm_logger.info(f"Killing cache manager process {p.pid}")
try:
os.killpg(p.pid, signal.SIGTERM)
except: