[Feature] optimize expert parallel (#3196)

* optimize

* Update expert_service.py

* Update worker_process.py

* optimize
ltd0924 authored on 2025-08-05 17:34:24 +08:00, committed by GitHub
parent dcf9c2daff
commit b20ffe3697
7 changed files with 174 additions and 134 deletions


@@ -28,6 +28,7 @@ import time
 import traceback
 import uuid
 import weakref
+from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Tuple
@@ -125,9 +126,17 @@ class LLMEngine:
             cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
         )
-        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port)
-        self.split_connector = SplitwiseConnector(cfg, self.scheduler, self.engine_worker_queue, self.resource_manager)
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
+            self.cfg.engine_worker_queue_port + self.cfg.parallel_config.local_data_parallel_id
+        )
+        self.splitwise_queue = deque()
+        self.split_connector = SplitwiseConnector(
+            cfg,
+            self.scheduler,
+            self.engine_worker_queue,
+            self.resource_manager,
+            self.splitwise_queue,
+        )
         self.token_processor = TokenProcessor(
             cfg=self.cfg,
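
To picture the change above, here is a small stand-alone sketch (names like `ToySplitwiseConnector`, `base_port`, and `local_dp_rank` are illustrative, not the FastDeploy API): each data-parallel rank derives its own message-queue id by offsetting the base `engine_worker_queue_port` with its `local_data_parallel_id`, and the connector shares a `deque` with the engine so received splitwise tasks can be buffered in-process.

```python
from collections import deque

base_port = 8002                      # assumed base engine_worker_queue_port
local_dp_rank = 1                     # assumed local_data_parallel_id
queue_id = base_port + local_dp_rank  # one message-queue id per DP rank

splitwise_queue = deque()


class ToySplitwiseConnector:
    """Stand-in for SplitwiseConnector: it only needs a handle to the shared deque."""

    def __init__(self, task_queue):
        self.task_queue = task_queue

    def on_tasks_received(self, role, tasks):
        # Producer side: buffer (role, tasks) on the left so pop() on the
        # right yields the oldest item first.
        self.task_queue.appendleft((role, tasks))


connector = ToySplitwiseConnector(splitwise_queue)
connector.on_tasks_received("prefill", ["req-0", "req-1"])
print(queue_id, list(splitwise_queue))
```
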
@@ -343,12 +352,6 @@ class LLMEngine:
                 if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
                     time.sleep(0.005)
                     continue
-                if self.engine_worker_queue.num_cache_infos() > 0:
-                    time.sleep(0.001)
-                    continue
-                if len(self.split_connector.current_request_ids) > 0:
-                    time.sleep(0.001)
-                    continue
 
                 num_prefill_batch = min(
                     int(self.resource_manager.available_batch()),
@@ -596,43 +599,42 @@ class LLMEngine:
                     for idx in sorted(processed_indices, reverse=True):
                         self.waiting_requests.pop(idx)
 
-                    if not self.engine_worker_queue.disaggregate_queue_empty():
-                        items = self.engine_worker_queue.get_disaggregated_tasks()
-                        for item in items:
-                            role = item[0]
-                            tasks = item[1]
-                            if role == "prefill":
-                                for task in tasks:
-                                    task.max_tokens = task.min_tokens = 2
-                                self.insert_tasks(tasks)
-                            elif role == "decode":
-                                if hasattr(tasks[0], "finished"):
-                                    if not isinstance(tasks, list):
-                                        tasks = [tasks]
-                                    for task in tasks:
-                                        task.finished = False
-                                    self.insert_tasks(tasks, allocated=True)
-                                    if self.cfg.innode_prefill_ports is not None:
-                                        self.scheduler.put_results(tasks)
-                                else:
-                                    if len(self.waiting_requests):
-                                        llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
-                                        self.waiting_requests.extend(tasks)
-                                    else:
-                                        new_waiting = []
-                                        for task in tasks:
-                                            if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
-                                                self.insert_tasks([task])
-                                            else:
-                                                new_waiting.append(task)
-                                        if new_waiting:
-                                            self.waiting_requests.extend(new_waiting)
-                                            llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                    if len(self.splitwise_queue) > 0:
+                        items = self.splitwise_queue.pop()
+                        role = items[0]
+                        tasks = items[1]
+                        if role == "prefill":
+                            for task in tasks:
+                                task.max_tokens = task.min_tokens = 2
+                            self.insert_tasks(tasks)
+                        elif role == "decode":
+                            if hasattr(tasks[0], "finished"):
+                                if not isinstance(tasks, list):
+                                    tasks = [tasks]
+                                for task in tasks:
+                                    task.finished = False
+                                self.insert_tasks(tasks, allocated=True)
+                                if self.cfg.innode_prefill_ports is not None:
+                                    self.scheduler.put_results(tasks)
+                            else:
+                                if len(self.waiting_requests):
+                                    llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
+                                    self.waiting_requests.extend(tasks)
+                                else:
+                                    new_waiting = []
+                                    for task in tasks:
+                                        if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
+                                            self.insert_tasks([task])
+                                        else:
+                                            new_waiting.append(task)
+                                    if new_waiting:
+                                        self.waiting_requests.extend(new_waiting)
+                                        llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
 
                     else:
                         time.sleep(0.001)
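
The hunk above swaps IPC polling (`disaggregate_queue_empty` / `get_disaggregated_tasks`) for a pop from the in-process deque. A hedged sketch of that consumer side (not the engine's actual `receiver_loop`): one `(role, tasks)` item is popped per iteration and dispatched on its role, with a short sleep when nothing is buffered.

```python
import time
from collections import deque

# Pre-filled for the example; in the engine the SplitwiseConnector would feed this.
splitwise_queue = deque([("prefill", ["task-a"]), ("decode", ["task-b"])])


def drain_once(queue):
    if len(queue) > 0:
        role, tasks = queue.pop()  # mirrors the deque.pop() in the diff above
        if role == "prefill":
            print(f"insert prefill tasks: {tasks}")
        elif role == "decode":
            print(f"insert decode tasks: {tasks}")
    else:
        time.sleep(0.001)  # nothing buffered; back off briefly


drain_once(splitwise_queue)
drain_once(splitwise_queue)
```
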
@@ -842,7 +844,6 @@ class LLMEngine:
                     is_prefill = True
                     self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
-        self.split_connector.send_cache_infos(tasks, current_id)
         if not is_decode:
             llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
             for task in tasks:
@@ -854,6 +855,8 @@ class LLMEngine:
         self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
         if is_prefill and self.cfg.scheduler_config.name != "splitwise":
             self.engine_worker_queue.available_prefill_instances.put(1)
+
+        self.split_connector.send_cache_infos(tasks, current_id)
         return True
 
     def task_is_finished(self, index):
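
The two hunks at -842 and -854 move `send_cache_infos` from before the `is_decode` branch to just before `return True`, so cache info is sent only after the tasks have been handed to the worker queue. A compressed order-of-operations sketch with stub objects (illustrative only, not the full `insert_tasks` body):

```python
class _StubQueue:
    def put_tasks(self, payload):
        print("put_tasks:", payload)


class _StubConnector:
    def send_cache_infos(self, tasks, current_id):
        print("send_cache_infos:", tasks, current_id)


def insert_tasks_sketch(tasks, current_id, worker_queue, split_connector):
    # ... resource allocation and logging elided ...
    worker_queue.put_tasks((tasks, len(tasks)))          # hand tasks to workers first
    split_connector.send_cache_infos(tasks, current_id)  # then announce cache info
    return True


insert_tasks_sketch(["req-0"], 1, _StubQueue(), _StubConnector())
```
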
@@ -1017,13 +1020,16 @@ class LLMEngine:
             except Exception as e:
                 print(f"Error extracting sub services: {e}")
 
-        self.engine_worker_queue.cleanup()
+        for worker_queue in self.engine_worker_queue_server:
+            worker_queue.cleanup()
         if hasattr(self, "send_response_server") and self.send_response_server is not None:
             self.send_response_server.close()
         if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
             self.recv_request_server.close()
         if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
             self.recv_control_cmd_server.close()
         if hasattr(self, "dp_processed"):
             for p in self.dp_processed:
                 p.join()
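
With one `EngineWorkerQueue` server per local data-parallel group, shutdown now walks a list instead of cleaning up a single queue. A hedged sketch of that shutdown path (the `getattr` guards are illustrative; attribute names mirror the diff):

```python
def shutdown_sketch(engine):
    # Close every per-DP-group worker queue server (a list after this change).
    for worker_queue in getattr(engine, "engine_worker_queue_server", []):
        worker_queue.cleanup()
    # Optional response/request/control servers are closed only if they were created.
    for name in ("send_response_server", "recv_request_server", "recv_control_cmd_server"):
        server = getattr(engine, name, None)
        if server is not None:
            server.close()
    # Wait for any spawned data-parallel engine processes.
    for p in getattr(engine, "dp_processed", []):
        p.join()
```
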
@@ -1325,15 +1331,20 @@ class LLMEngine:
"""
start queue service for engine worker communication
"""
address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
self.engine_worker_queue_server = list()
if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
llm_logger.info(f"Starting engine worker queue server service at {address}")
self.engine_worker_queue_server = EngineWorkerQueue(
address=address,
is_server=True,
num_client=self.cfg.tensor_parallel_size,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
for i in range(self.cfg.parallel_config.data_parallel_size // self.cfg.nnode):
address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port + i)
llm_logger.info(f"Starting engine worker queue service at {address}")
self.engine_worker_queue_server.append(
EngineWorkerQueue(
address=address,
is_server=True,
num_client=self.cfg.tensor_parallel_size,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
)
if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
self.cache_task_queue = EngineCacheQueue(
@@ -1348,6 +1359,7 @@ class LLMEngine:
                     local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
                 )
 
+        address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
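
Taken together, the last two hunks define a simple port layout: the master node hosts one `EngineWorkerQueue` server per local data-parallel group at `engine_worker_queue_port + i`, and this engine process then connects as a client at the base port. A small sketch with assumed values:

```python
# Assumed example values; the real ones come from cfg.parallel_config and cfg.nnode.
master_ip = "0.0.0.0"
base_port = 8002
data_parallel_size = 4
nnode = 2

# One server address per local data-parallel group on this node.
server_addresses = [(master_ip, base_port + i) for i in range(data_parallel_size // nnode)]
print(server_addresses)  # [('0.0.0.0', 8002), ('0.0.0.0', 8003)]
```
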