[Feature] optimize expert parallel (#3196)

* optimize

* Update expert_service.py

* Update worker_process.py

* optimize
Author: ltd0924
Committed: 2025-08-05 17:34:24 +08:00 (committed by GitHub)
Parent: dcf9c2daff
Commit: b20ffe3697
7 changed files with 174 additions and 134 deletions


@@ -127,7 +127,7 @@ class CacheMessager:
         self.gpu_cache_kvs = gpu_cache_kvs
         self.rank = rank
         self.nranks = nranks
-        address = (pod_ip, engine_worker_queue_port)
+        address = (pod_ip, engine_worker_queue_port + local_data_parallel_id)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
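
For illustration, the addressing scheme this change relies on can be sketched as a small helper; the function name and argument names below are assumptions for the example, not FastDeploy API:

def worker_queue_address(pod_ip: str, base_port: int, local_data_parallel_id: int) -> tuple:
    """Each data-parallel rank talks to its own queue server on base_port + rank."""
    return (pod_ip, base_port + local_data_parallel_id)

# e.g. with base_port=8002: DP rank 0 -> (ip, 8002), rank 1 -> (ip, 8003), ...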


@@ -28,6 +28,7 @@ import time
 import traceback
 import uuid
 import weakref
+from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Tuple
@@ -125,9 +126,17 @@ class LLMEngine:
             cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
         )
-        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port)
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
+            self.cfg.engine_worker_queue_port + self.cfg.parallel_config.local_data_parallel_id
+        )

-        self.split_connector = SplitwiseConnector(cfg, self.scheduler, self.engine_worker_queue, self.resource_manager)
+        self.splitwise_queue = deque()
+        self.split_connector = SplitwiseConnector(
+            cfg,
+            self.scheduler,
+            self.engine_worker_queue,
+            self.resource_manager,
+            self.splitwise_queue,
+        )

         self.token_processor = TokenProcessor(
             cfg=self.cfg,
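
The new splitwise_queue is a plain collections.deque shared between the connector and the engine loop inside one process. A rough, self-contained sketch of that producer/consumer handoff (the function names and payloads are illustrative, not the engine's real interfaces):

from collections import deque
import threading
import time

handoff = deque()  # appended by the connector side, drained by the engine loop

def connector_thread():
    # producer: push (role, tasks) tuples as they arrive
    for i in range(3):
        handoff.append(("decode", [f"task-{i}"]))
        time.sleep(0.01)

def engine_loop():
    # consumer: poll the in-process deque instead of an IPC queue
    deadline = time.time() + 1.0
    while time.time() < deadline:
        if len(handoff) > 0:
            role, tasks = handoff.pop()
            print(role, tasks)
        else:
            time.sleep(0.001)

threading.Thread(target=connector_thread, daemon=True).start()
engine_loop()

Since collections.deque append/pop are atomic, the two threads can share it without an explicit lock, which is presumably why it replaces the heavier cross-process disaggregated-task queue here.
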
@@ -343,12 +352,6 @@ class LLMEngine:
                 if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
                     time.sleep(0.005)
                     continue
-                if self.engine_worker_queue.num_cache_infos() > 0:
-                    time.sleep(0.001)
-                    continue
-                if len(self.split_connector.current_request_ids) > 0:
-                    time.sleep(0.001)
-                    continue

                 num_prefill_batch = min(
                     int(self.resource_manager.available_batch()),
@@ -596,11 +599,10 @@ class LLMEngine:
                     for idx in sorted(processed_indices, reverse=True):
                         self.waiting_requests.pop(idx)

-                    if not self.engine_worker_queue.disaggregate_queue_empty():
-                        items = self.engine_worker_queue.get_disaggregated_tasks()
-                        for item in items:
-                            role = item[0]
-                            tasks = item[1]
+                    if len(self.splitwise_queue) > 0:
+                        items = self.splitwise_queue.pop()
+                        role = items[0]
+                        tasks = items[1]

                         if role == "prefill":
                             for task in tasks:
@@ -842,7 +844,6 @@ class LLMEngine:
                 is_prefill = True
                 self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len

-        self.split_connector.send_cache_infos(tasks, current_id)
         if not is_decode:
             llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
             for task in tasks:
@@ -854,6 +855,8 @@ class LLMEngine:
         self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
         if is_prefill and self.cfg.scheduler_config.name != "splitwise":
             self.engine_worker_queue.available_prefill_instances.put(1)
+
+        self.split_connector.send_cache_infos(tasks, current_id)
         return True

     def task_is_finished(self, index):
@@ -1017,13 +1020,16 @@ class LLMEngine:
         except Exception as e:
             print(f"Error extracting sub services: {e}")

-        self.engine_worker_queue.cleanup()
+        for worker_queue in self.engine_worker_queue_server:
+            worker_queue.cleanup()
         if hasattr(self, "send_response_server") and self.send_response_server is not None:
             self.send_response_server.close()
         if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
             self.recv_request_server.close()
         if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
             self.recv_control_cmd_server.close()
         if hasattr(self, "dp_processed"):
             for p in self.dp_processed:
                 p.join()
@@ -1325,15 +1331,20 @@ class LLMEngine:
         """
         start queue service for engine worker communication
         """
-        address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
+        self.engine_worker_queue_server = list()
         if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
-            llm_logger.info(f"Starting engine worker queue server service at {address}")
-            self.engine_worker_queue_server = EngineWorkerQueue(
+            for i in range(self.cfg.parallel_config.data_parallel_size // self.cfg.nnode):
+                address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port + i)
+                llm_logger.info(f"Starting engine worker queue service at {address}")
+                self.engine_worker_queue_server.append(
+                    EngineWorkerQueue(
                         address=address,
                         is_server=True,
                         num_client=self.cfg.tensor_parallel_size,
                         local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
                     )
+                )

         if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
             self.cache_task_queue = EngineCacheQueue(
@@ -1348,6 +1359,7 @@ class LLMEngine:
                 local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
             )

+        address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
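
A sketch of the resulting port layout when the engine starts one worker-queue server per data-parallel group hosted on the node; the helper name and the example values are made up for illustration:

def local_queue_addresses(master_ip: str, base_port: int, data_parallel_size: int, nnode: int) -> list:
    """One EngineWorkerQueue server per local DP group, on consecutive ports."""
    local_dp_groups = data_parallel_size // nnode
    return [(master_ip, base_port + i) for i in range(local_dp_groups)]

# e.g. data_parallel_size=8 over nnode=2 -> 4 servers on this node:
# [('127.0.0.1', 8002), ('127.0.0.1', 8003), ('127.0.0.1', 8004), ('127.0.0.1', 8005)]
print(local_queue_addresses("127.0.0.1", 8002, 8, 2))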


@@ -22,6 +22,7 @@ import threading
 import time
 import traceback
 import weakref
+from collections import deque

 import numpy as np
@@ -31,8 +32,7 @@ from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.output.token_processor import TokenProcessor
 from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
-from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
+from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger


 class ExpertService:
     """
@@ -53,6 +53,10 @@ class ExpertService:
         self.cfg = cfg
         start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % cfg.worker_num_per_node
         end_pos = start_pos + self.cfg.tensor_parallel_size
+        self.waiting_requests = []
+        self.disaggregate_queue = deque()
+        self.llm_logger = get_logger("expert_service", f"expert_service_{local_data_parallel_id}.log")
+
         if cfg.splitwise_role != "mixed":
             self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
             self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
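
Each ExpertService instance now logs to its own per-rank file through get_logger. A minimal standard-library sketch of what such a per-rank logger amounts to (FastDeploy's get_logger helper is only assumed to behave roughly like this):

import logging

def make_dp_logger(local_data_parallel_id: int) -> logging.Logger:
    """One named logger and one file handler per data-parallel rank."""
    logger = logging.getLogger(f"expert_service_{local_data_parallel_id}")
    if not logger.handlers:
        handler = logging.FileHandler(f"expert_service_{local_data_parallel_id}.log")
        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger

make_dp_logger(1).info("start expert service 1")
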
@@ -66,7 +70,7 @@ class ExpertService:
         self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id

-        address = (cfg.master_ip, cfg.engine_worker_queue_port)
+        address = (cfg.master_ip, cfg.engine_worker_queue_port + local_data_parallel_id)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
@@ -90,10 +94,7 @@ class ExpertService:
             self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]

         self.split_connector = SplitwiseConnector(
-            self.cfg,
-            self.scheduler,
-            self.engine_worker_queue,
-            self.resource_manager,
+            self.cfg, self.scheduler, self.engine_worker_queue, self.resource_manager, self.disaggregate_queue
         )

         self.token_processor = TokenProcessor(
@@ -127,7 +128,7 @@ class ExpertService:
         # assert not self.is_started, "The engine is already started."
         start_time = time.time()

-        llm_logger.info(f"start expert service {local_data_parallel_id}")
+        self.llm_logger.info(f"start expert service {local_data_parallel_id}")
         if self.cfg.splitwise_role != "mixed":
             self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
                 cache_config=self.cfg.cache_config,
@@ -177,9 +178,6 @@ class ExpertService:
                 if self.engine_worker_queue.num_tasks() > 0:
                     time.sleep(0.001)
                     continue
-                if len(self.split_connector.current_request_ids) > 0:
-                    time.sleep(0.001)
-                    continue

                 num_prefill_batch = min(
                     int(self.resource_manager.available_batch()),
@@ -200,7 +198,7 @@ class ExpertService:
                     continue

                 if self.cfg.splitwise_role != "mixed":
-                    llm_logger.info("Inserting splitwise tasks")
+                    self.llm_logger.info("Inserting splitwise tasks")
                     self.split_connector.send_splitwise_tasks(tasks, current_id)

                 current_id = (current_id + 1) % 100003
@@ -210,63 +208,72 @@ class ExpertService:
                     main_process_metrics.num_requests_waiting.dec(len(tasks))
                     main_process_metrics.num_requests_running.inc(len(tasks))
             except Exception as e:
-                err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
-                llm_logger.error(err_msg)
+                err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
+                self.llm_logger.error(err_msg)

     def split_mode_get_tasks(self):
         """
         Split mode get tasks
         """
-        waiting_requests = []

         def receiver_loop():
             while True:
                 try:
-                    if len(waiting_requests) > 0:
-                        for task in waiting_requests:
-                            if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
-                                self.insert_tasks([task])
-                                waiting_requests.remove(task)
-                            else:
-                                break
-                    if not self.engine_worker_queue.disaggregate_queue_empty():
-                        items = self.engine_worker_queue.get_disaggregated_tasks()
-                        for item in items:
-                            role = item[0]
-                            tasks = item[1]
+                    processed_indices = []
+                    for idx, task in enumerate(self.waiting_requests):
+                        if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
+                            self.insert_tasks([task])
+                            self.llm_logger.info(f"Resource available, processing task {task.request_id}")
+                            processed_indices.append(idx)
+                        else:
+                            self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
+                            break
+
+                    for idx in sorted(processed_indices, reverse=True):
+                        self.waiting_requests.pop(idx)
+
+                    if len(self.disaggregate_queue) > 0:
+                        items = self.disaggregate_queue.pop()
+                        role = items[0]
+                        tasks = items[1]
                         if role == "prefill":
-                            llm_logger.info("get prefill tasks")
                             for task in tasks:
                                 task.max_tokens = task.min_tokens = 2
                             self.insert_tasks(tasks)
                         elif role == "decode":
-                            llm_logger.info(f"get decode tasks {tasks}")
                             if hasattr(tasks[0], "finished"):
                                 if not isinstance(tasks, list):
                                     tasks = [tasks]
                                 for task in tasks:
                                     task.finished = False
-                                # self.scheduler.put_results(tasks)
                                 self.insert_tasks(tasks, allocated=True)
-                                if self.cfg.innode_prefill_ports is not None:
-                                    self.scheduler.put_results(tasks)
                             else:
-                                if len(waiting_requests):
-                                    for task in tasks:
-                                        waiting_requests.append(task)
+                                if len(self.waiting_requests):
+                                    self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
+                                    self.waiting_requests.extend(tasks)
                                 else:
-                                    for task in tasks:
-                                        if not self.resource_manager.is_resource_sufficient(
-                                            task.prompt_token_ids_len
-                                        ):
-                                            waiting_requests.append(task)
-                                        else:
-                                            self.insert_tasks([task])
+                                    new_waiting = []
+                                    for task in tasks:
+                                        if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
+                                            self.insert_tasks([task])
+                                        else:
+                                            new_waiting.append(task)
+                                    if new_waiting:
+                                        self.waiting_requests.extend(new_waiting)
+                                        self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
                     else:
                         time.sleep(0.001)
+                        continue
                 except Exception as e:
-                    llm_logger.error(f"get decode tasks error: {e}")
+                    self.llm_logger.error(f"Error in main loop: {e} {str(traceback.format_exc())}")
+                    time.sleep(0.1)

         threading.Thread(target=receiver_loop, daemon=True).start()
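
The waiting-queue drain above follows a common pattern: record the indices of entries that were handled, then pop them in reverse so the remaining indices stay valid. A self-contained sketch of that pattern (the drain_ready name and is_ready predicate are illustrative stand-ins for the resource check):

def drain_ready(waiting: list, is_ready) -> list:
    """Remove and return the leading entries that are ready; stop at the first one that is not."""
    processed_indices = []
    handled = []
    for idx, item in enumerate(waiting):
        if is_ready(item):
            handled.append(item)
            processed_indices.append(idx)
        else:
            break
    # pop in reverse so earlier indices are still valid after each pop
    for idx in sorted(processed_indices, reverse=True):
        waiting.pop(idx)
    return handled

queue = [1, 2, 7, 3]
print(drain_ready(queue, lambda x: x < 5), queue)  # [1, 2] [7, 3]
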
@@ -287,11 +294,11 @@ class ExpertService:
                 if task.request_id in self.token_processor.tokens_counter:
                     del self.token_processor.tokens_counter[task.request_id]
                 self.scheduler.put_results([task])
-                llm_logger.warning(
+                self.llm_logger.warning(
                     f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
                 )
                 continue
-            llm_logger.info(f"{cur_task_idx} {task.request_id}")
+            self.llm_logger.info(f"{cur_task_idx} {task.request_id}")
             cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
             self.token_processor.tokens_counter[task.request_id] = 1
             current_tasks.append(cur_task)
@@ -308,8 +315,10 @@ class ExpertService:
         available_batch = np.sum(self.resource_manager.stop_flags)
         if len(tasks) > available_batch:
-            llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
-            llm_logger.error("The exceeded part will be ignored!")
+            self.llm_logger.error(
+                "Inserting batch:{} exceeds the available batch:{}.".format(len(tasks), available_batch)
+            )
+            self.llm_logger.error("The exceeded part will be ignored!")
             tasks = tasks[:available_batch]

         req_ids = [t.request_id for t in tasks]
@@ -318,7 +327,7 @@ class ExpertService:
         if not tasks:
             error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
-            llm_logger.error(error_msg)
+            self.llm_logger.error(error_msg)
             raise EngineError(error_msg, error_code=500)
             return False
@@ -338,7 +347,7 @@ class ExpertService:
         for task in tasks:
             task.infer_start_time = time.time()
         if not is_decode:
-            llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
+            self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
         if not is_prefill and self.cfg.cache_config.enable_chunked_prefill:
             if not self.cfg.enable_mm:
                 self.update_requests_chunk_size(tasks)
@@ -356,7 +365,7 @@ class ExpertService:
             self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
             self.resource_manager.cache_manager.cache_ready_signal.clear()
         for p in self.cache_manager_processes:
-            llm_logger.info(f"Killing cache manager process {p.pid}")
+            self.llm_logger.info(f"Killing cache manager process {p.pid}")
             try:
                 os.killpg(p.pid, signal.SIGTERM)
             except:


@@ -22,7 +22,7 @@ import numpy as np
 from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
 from fastdeploy.metrics.metrics import main_process_metrics
-from fastdeploy.utils import llm_logger
+from fastdeploy.utils import get_logger, llm_logger


 class ResourceManager:
@@ -49,6 +49,12 @@ class ResourceManager:
         Initializes the engine with the given configuration and sets up necessary
         data structures to manage tasks and blocks.
         """
+        if local_data_parallel_id > 0:
+            self.logger = get_logger(
+                f"expert_service_{local_data_parallel_id}", f"expert_service_{local_data_parallel_id}.log"
+            )
+        else:
+            self.logger = llm_logger
         self.cfg = config.cache_config
         self.max_num_seqs = max_num_seqs
         self.stop_flags = [True] * max_num_seqs
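
ResourceManager picks its logger the same way: DP rank 0 keeps the shared engine log, higher ranks get their own file. A short standard-library sketch of that selection (the shared llm_logger below is a stand-in, not FastDeploy's object):

import logging

llm_logger = logging.getLogger("fastdeploy")  # stand-in for the shared default logger

def pick_logger(local_data_parallel_id: int) -> logging.Logger:
    # rank 0 logs to the shared engine log; other DP ranks use a per-rank logger
    if local_data_parallel_id > 0:
        return logging.getLogger(f"expert_service_{local_data_parallel_id}")
    return llm_logger

print(pick_logger(0) is llm_logger)  # True
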
@@ -58,7 +64,7 @@ class ResourceManager:
         self.req_dict = dict()
         # current batch status of the engine
         self.real_bsz = 0
-        llm_logger.info(f"{self.info()}")
+        self.logger.info(f"{self.info()}")

     def reset_cache_config(self, cfg):
         """
@@ -134,10 +140,10 @@ class ResourceManager:
         block_list = list()
         current_block_num = self.available_block_num()
         if block_num > current_block_num:
-            llm_logger.error(f"block_num:{block_num} > free_list len:{current_block_num}")
+            self.logger.error("block_num:{0} > free_list len:{1}".format(block_num, current_block_num))
             return block_list
         block_list = self.cache_manager.allocate_gpu_blocks(block_num)
-        llm_logger.debug(f"dispatch {len(block_list)} blocks.")
+        self.logger.debug(f"dispatch {len(block_list)} blocks.")
         return block_list

     def check_and_free_block_tables(self):
@@ -169,7 +175,7 @@ class ResourceManager:
         self.cache_manager.recycle_gpu_blocks(block_tables)
         cur_number = self.available_block_num()
         main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
-        llm_logger.info(f"recycle {req_id} {cur_number - ori_number} blocks.")
+        self.logger.info(f"recycle {req_id} {cur_number - ori_number} blocks.")

     def available_batch(self):
         """
@@ -248,12 +254,10 @@ class ResourceManager:
             if self.enable_prefix_cache:
                 cache_prepare_time = time.time()
                 common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
-                    task,
-                    self.cfg.block_size,
-                    self.cfg.dec_token_num,
+                    task, self.cfg.block_size, self.cfg.dec_token_num
                 )
                 if unique_block_ids is None:
-                    llm_logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
+                    self.logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
                     return

                 cached_len = self._record_request_cache_info(
@@ -294,7 +298,7 @@ class ResourceManager:
                 task.inference_time_cost = -1.0
                 task.tokens_all_num = 0
                 self.tasks_list[allocated_position] = task
-                llm_logger.info(
+                self.logger.info(
                     f"Allocate request: {task.request_id}, "
                     f"allocated_position:{allocated_position}, "
                     f"length of prompt token: {task.prompt_token_ids_len}"
@@ -308,10 +312,10 @@ class ResourceManager:
                 self.real_bsz = i + 1
                 break

-        llm_logger.info(
+        self.logger.info(
             f"Number of allocated requests: {len(tasks)}, number of " f"running requests in worker: {self.real_bsz}"
         )
-        llm_logger.info(f"{self.info()}")
+        self.logger.info(f"{self.info()}")
         main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())

         return processed_tasks
@@ -342,8 +346,8 @@ class ResourceManager:
         cached_len = len(common_block_ids) * self.cfg.block_size
         task.block_tables = common_block_ids + unique_block_ids
         task.need_block_tables = unique_block_ids
-        llm_logger.debug(f"common: {common_block_ids} ")
-        llm_logger.debug(f"unique: {unique_block_ids} ")
+        self.logger.debug(f"common: {common_block_ids} ")
+        self.logger.debug(f"unique: {unique_block_ids} ")
         return cached_len

     def info(self):


@@ -177,6 +177,8 @@ class OpenAIServingChat:
             for res in response:
                 if res.get("error_code", 200) != 200:
                     raise ValueError("{}".format(res["error_msg"]))

+                if res["finished"]:
+                    api_server_logger.info(f"chat completion finished: {request_id}")
                 self.engine_client.data_processor.process_response_dict(
                     res,


@@ -14,11 +14,11 @@
 # limitations under the License.
 """

-import json
 import time
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict

+import msgpack
 import zmq

 from fastdeploy import envs
@@ -34,7 +34,7 @@ class SplitwiseConnector:
     SplitwiseConnector class for managing and scheduling Splitwise tasks.
     """

-    def __init__(self, cfg, scheduler, worker_queue, resource_manager):
+    def __init__(self, cfg, scheduler, worker_queue, resource_manager, splitwise_queue):
         """
         Initialize the SplitwiseConnector instance.
@@ -51,6 +51,7 @@ class SplitwiseConnector:
         self.connect_innode_instances = {}
         self.temp_cache_info = dict()
         self.current_request_ids = dict()
+        self.splitwise_queue = splitwise_queue

         if self.cfg.cache_config.pd_comm_port is not None:
             self.zmq_ctx = zmq.Context()
@@ -406,13 +407,19 @@ class SplitwiseConnector:
         if msg_type == "decode" or msg_type == "prefill":
             payload = [output.to_dict() for output in payload]

-        json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")
+        req_ids = [task["request_id"] for task in payload]
+        logger.info(f"send message {msg_type} {req_ids}")
+        json_data = msgpack.packb({"type": msg_type, "payload": payload})

         return json_data

     def _deserialize_message(self, data: bytes):
         # JSON deserialization
-        message = json.loads(data.decode("utf-8"))
+        message = msgpack.unpackb(data)
+        req_ids = [task["request_id"] for task in message["payload"]]
+        logger.info(f"send message {message['type']} {req_ids}")

         return message["type"], message["payload"]

     def _process_message(self, message: bytes):
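
The connector now frames control messages with msgpack instead of JSON text. A round-trip sketch of that framing (requires the msgpack-python package; the helper names and the payload shape are simplified stand-ins for the connector's real messages):

import msgpack

def serialize(msg_type: str, payload: list) -> bytes:
    # binary msgpack framing, smaller and cheaper to decode than JSON text
    return msgpack.packb({"type": msg_type, "payload": payload})

def deserialize(data: bytes):
    message = msgpack.unpackb(data)
    return message["type"], message["payload"]

blob = serialize("prefill", [{"request_id": "req-1"}])
print(deserialize(blob))  # ('prefill', [{'request_id': 'req-1'}])
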
@@ -441,7 +448,9 @@ class SplitwiseConnector:
         """
         tasks_data = [Request.from_dict(task) for task in tasks]
-        self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data))
+        req_ids = [task["request_id"] for task in tasks]
+        self.splitwise_queue.append(("decode", tasks_data))
+        logger.debug(f"{req_ids} received prefill data")

     def _handle_decode(self, payload):
         """
@@ -460,4 +469,6 @@ class SplitwiseConnector:
                     finished=True,
                 )
             )
-        self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))
+        req_ids = [task["request_id"] for task in payload]
+        self.splitwise_queue.append(("decode", tasks))
+        logger.debug(f"{req_ids} received decode data")


@@ -150,7 +150,7 @@ class PaddleDisWorkerProc:
         # Initialize task queue
         task_address = (
             self.parallel_config.pod_ip,
-            self.parallel_config.engine_worker_queue_port,
+            self.parallel_config.engine_worker_queue_port + self.parallel_config.expert_parallel_rank,
         )
         self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
         self.task_queue = TaskQueue(
@@ -252,9 +252,11 @@ class PaddleDisWorkerProc:
             for req_dict, bsz in tasks:
                 num_running_requests = int(bsz)
                 req_dicts.extend(req_dict)

+            req_ids = [req.request_id for req in req_dicts]
             logger.info(
                 f"Rank: {self.local_rank}, num_running_requests: {num_running_requests}, "
-                f"num_insert_requests: {len(req_dicts)}"
+                f"num_insert_requests: {len(req_dicts)}, req_ids: {req_ids}"
             )

             # Process prefill inputs
             self.worker.preprocess_new_task(req_dicts)