Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Feature] optimize expert parallel (#3196)

* optimize
* Update expert_service.py
* Update worker_process.py
* optimize
@@ -28,6 +28,7 @@ import time
 import traceback
 import uuid
 import weakref
+from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Tuple

@@ -125,9 +126,17 @@ class LLMEngine:
             cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
         )

-        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port)
-
-        self.split_connector = SplitwiseConnector(cfg, self.scheduler, self.engine_worker_queue, self.resource_manager)
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
+            self.cfg.engine_worker_queue_port + self.cfg.parallel_config.local_data_parallel_id
+        )
+        self.splitwise_queue = deque()
+        self.split_connector = SplitwiseConnector(
+            cfg,
+            self.scheduler,
+            self.engine_worker_queue,
+            self.resource_manager,
+            self.splitwise_queue,
+        )

         self.token_processor = TokenProcessor(
             cfg=self.cfg,
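Note: the constructor change above does two things: it derives the inference message-queue ID from the shared engine_worker_queue_port plus the local data-parallel rank, so each expert/data-parallel engine process gets its own ID, and it creates an in-process deque that is handed to SplitwiseConnector so connector and engine share one buffer instead of relaying through the worker queue. A minimal sketch of that pattern follows; it is not FastDeploy code, and the port and rank values are hypothetical.

# Minimal sketch, assuming a base port of 6677 and that this process is
# local data-parallel rank 1 (both values hypothetical).
import os
from collections import deque

ENGINE_WORKER_QUEUE_PORT = 6677   # stand-in for cfg.engine_worker_queue_port
LOCAL_DATA_PARALLEL_ID = 1        # stand-in for cfg.parallel_config.local_data_parallel_id

# Each data-parallel rank derives a distinct message-queue ID from the shared base port.
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(ENGINE_WORKER_QUEUE_PORT + LOCAL_DATA_PARALLEL_ID)

# The engine owns the deque and passes it to the connector, so both sides
# push and pop on the same in-process buffer.
splitwise_queue = deque()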
@@ -343,12 +352,6 @@ class LLMEngine:
                 if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
                     time.sleep(0.005)
                     continue
-                if self.engine_worker_queue.num_cache_infos() > 0:
-                    time.sleep(0.001)
-                    continue
-                if len(self.split_connector.current_request_ids) > 0:
-                    time.sleep(0.001)
-                    continue

                 num_prefill_batch = min(
                     int(self.resource_manager.available_batch()),
@@ -596,43 +599,42 @@ class LLMEngine:
                     for idx in sorted(processed_indices, reverse=True):
                         self.waiting_requests.pop(idx)

-                if not self.engine_worker_queue.disaggregate_queue_empty():
-                    items = self.engine_worker_queue.get_disaggregated_tasks()
-                    for item in items:
-                        role = item[0]
-                        tasks = item[1]
+                if len(self.splitwise_queue) > 0:
+                    items = self.splitwise_queue.pop()
+                    role = items[0]
+                    tasks = items[1]

-                        if role == "prefill":
-                            for task in tasks:
-                                task.max_tokens = task.min_tokens = 2
-                            self.insert_tasks(tasks)
+                    if role == "prefill":
+                        for task in tasks:
+                            task.max_tokens = task.min_tokens = 2
+                        self.insert_tasks(tasks)

-                        elif role == "decode":
-                            if hasattr(tasks[0], "finished"):
-                                if not isinstance(tasks, list):
-                                    tasks = [tasks]
-                                for task in tasks:
-                                    task.finished = False
-                                self.insert_tasks(tasks, allocated=True)
+                    elif role == "decode":
+                        if hasattr(tasks[0], "finished"):
+                            if not isinstance(tasks, list):
+                                tasks = [tasks]
+                            for task in tasks:
+                                task.finished = False
+                            self.insert_tasks(tasks, allocated=True)

-                                if self.cfg.innode_prefill_ports is not None:
-                                    self.scheduler.put_results(tasks)
+                            if self.cfg.innode_prefill_ports is not None:
+                                self.scheduler.put_results(tasks)

-                            else:
-                                if len(self.waiting_requests):
-                                    llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
-                                    self.waiting_requests.extend(tasks)
-                                else:
-                                    new_waiting = []
-                                    for task in tasks:
-                                        if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
-                                            self.insert_tasks([task])
-                                        else:
-                                            new_waiting.append(task)
+                        else:
+                            if len(self.waiting_requests):
+                                llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
+                                self.waiting_requests.extend(tasks)
+                            else:
+                                new_waiting = []
+                                for task in tasks:
+                                    if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
+                                        self.insert_tasks([task])
+                                    else:
+                                        new_waiting.append(task)

-                                    if new_waiting:
-                                        self.waiting_requests.extend(new_waiting)
-                                        llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                                if new_waiting:
+                                    self.waiting_requests.extend(new_waiting)
+                                    llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")

                 else:
                     time.sleep(0.001)
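Note: the hunk above replaces polling of the worker queue's disaggregated-task channel with a pop from the shared splitwise_queue; each entry is a (role, tasks) pair dispatched to the prefill or decode path. The sketch below illustrates that handoff shape only; it is not FastDeploy code, the producer side is not part of this hunk, so using appendleft() so that pop() consumes in FIFO order is an assumption, and producer/consume_once and the request ids are hypothetical names.

# Minimal sketch of a (role, tasks) handoff through a shared deque.
from collections import deque

splitwise_queue = deque()

def producer(role, tasks):
    # hypothetical stand-in for the connector pushing disaggregated tasks
    splitwise_queue.appendleft((role, tasks))

def consume_once():
    # mirrors the consumer side of the hunk: pop one entry and branch on role
    if len(splitwise_queue) > 0:
        role, tasks = splitwise_queue.pop()
        if role == "prefill":
            print(f"insert {len(tasks)} prefill task(s)")
        elif role == "decode":
            print(f"insert {len(tasks)} decode task(s)")

producer("prefill", ["req-1", "req-2"])   # hypothetical request ids
consume_once()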
@@ -842,7 +844,6 @@ class LLMEngine:
                 is_prefill = True
                 self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len

-        self.split_connector.send_cache_infos(tasks, current_id)
         if not is_decode:
             llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
             for task in tasks:
@@ -854,6 +855,8 @@ class LLMEngine:
         self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
         if is_prefill and self.cfg.scheduler_config.name != "splitwise":
             self.engine_worker_queue.available_prefill_instances.put(1)
+
+        self.split_connector.send_cache_infos(tasks, current_id)
         return True

     def task_is_finished(self, index):
@@ -1017,13 +1020,16 @@ class LLMEngine:
         except Exception as e:
             print(f"Error extracting sub services: {e}")

-        self.engine_worker_queue.cleanup()
+
+        for worker_queue in self.engine_worker_queue_server:
+            worker_queue.cleanup()
         if hasattr(self, "send_response_server") and self.send_response_server is not None:
             self.send_response_server.close()
         if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
             self.recv_request_server.close()
         if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
             self.recv_control_cmd_server.close()

         if hasattr(self, "dp_processed"):
             for p in self.dp_processed:
                 p.join()
@@ -1325,15 +1331,20 @@ class LLMEngine:
         """
         start queue service for engine worker communication
         """
-        address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
+
+        self.engine_worker_queue_server = list()
         if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
-            llm_logger.info(f"Starting engine worker queue server service at {address}")
-            self.engine_worker_queue_server = EngineWorkerQueue(
-                address=address,
-                is_server=True,
-                num_client=self.cfg.tensor_parallel_size,
-                local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
-            )
+            for i in range(self.cfg.parallel_config.data_parallel_size // self.cfg.nnode):
+                address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port + i)
+                llm_logger.info(f"Starting engine worker queue service at {address}")
+                self.engine_worker_queue_server.append(
+                    EngineWorkerQueue(
+                        address=address,
+                        is_server=True,
+                        num_client=self.cfg.tensor_parallel_size,
+                        local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
+                    )
+                )

         if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
             self.cache_task_queue = EngineCacheQueue(
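Note: instead of a single EngineWorkerQueue server, the hunk above starts one server per local data-parallel group, offsetting the port by the group index. The sketch below only shows how those addresses are derived; it is not FastDeploy code, and BASE_PORT, DATA_PARALLEL_SIZE, and NNODE are hypothetical stand-ins for the corresponding config fields.

# Minimal sketch: one listening address per local data-parallel group.
BASE_PORT = 6677          # stand-in for cfg.engine_worker_queue_port
DATA_PARALLEL_SIZE = 4    # stand-in for cfg.parallel_config.data_parallel_size
NNODE = 2                 # stand-in for cfg.nnode

addresses = [
    ("0.0.0.0", BASE_PORT + i)
    for i in range(DATA_PARALLEL_SIZE // NNODE)
]
print(addresses)   # [('0.0.0.0', 6677), ('0.0.0.0', 6678)]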
@@ -1348,6 +1359,7 @@ class LLMEngine:
                 local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
             )

+        address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,