Mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 09:31:35 +08:00
[Feature] optimize expert parallel (#3196)
* optimize
* Update expert_service.py
* Update worker_process.py
* optimize
@@ -127,7 +127,7 @@ class CacheMessager:
         self.gpu_cache_kvs = gpu_cache_kvs
         self.rank = rank
         self.nranks = nranks
-        address = (pod_ip, engine_worker_queue_port)
+        address = (pod_ip, engine_worker_queue_port + local_data_parallel_id)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
@@ -28,6 +28,7 @@ import time
 import traceback
 import uuid
 import weakref
+from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Tuple
@@ -125,9 +126,17 @@ class LLMEngine:
             cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
         )

-        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port)
-
-        self.split_connector = SplitwiseConnector(cfg, self.scheduler, self.engine_worker_queue, self.resource_manager)
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
+            self.cfg.engine_worker_queue_port + self.cfg.parallel_config.local_data_parallel_id
+        )
+        self.splitwise_queue = deque()
+        self.split_connector = SplitwiseConnector(
+            cfg,
+            self.scheduler,
+            self.engine_worker_queue,
+            self.resource_manager,
+            self.splitwise_queue,
+        )

         self.token_processor = TokenProcessor(
             cfg=self.cfg,
@@ -343,12 +352,6 @@ class LLMEngine:
                 if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
                     time.sleep(0.005)
                     continue
-                if self.engine_worker_queue.num_cache_infos() > 0:
-                    time.sleep(0.001)
-                    continue
-                if len(self.split_connector.current_request_ids) > 0:
-                    time.sleep(0.001)
-                    continue

                 num_prefill_batch = min(
                     int(self.resource_manager.available_batch()),
@@ -596,11 +599,10 @@ class LLMEngine:
                     for idx in sorted(processed_indices, reverse=True):
                         self.waiting_requests.pop(idx)

-                    if not self.engine_worker_queue.disaggregate_queue_empty():
-                        items = self.engine_worker_queue.get_disaggregated_tasks()
-                        for item in items:
-                            role = item[0]
-                            tasks = item[1]
+                    if len(self.splitwise_queue) > 0:
+                        items = self.splitwise_queue.pop()
+                        role = items[0]
+                        tasks = items[1]

                         if role == "prefill":
                             for task in tasks:
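With the change above, disaggregated (splitwise) tasks no longer travel through the cross-process engine worker queue; the SplitwiseConnector appends them to an in-process collections.deque and the engine loop drains it. The following is a minimal, self-contained sketch of that handoff pattern; the names splitwise_queue, connector_receive and engine_loop are illustrative, not FastDeploy APIs.

import time
from collections import deque
from threading import Thread

splitwise_queue = deque()  # shared by the connector (producer) and the engine loop (consumer)

def connector_receive(role, tasks):
    # producer side: mirrors self.splitwise_queue.append(("decode", tasks_data)) in the diff
    splitwise_queue.append((role, tasks))

def engine_loop():
    # consumer side: mirrors items = self.splitwise_queue.pop(); append()/pop()
    # on a deque take the most recent item (LIFO) and are thread-safe
    while True:
        if len(splitwise_queue) > 0:
            role, tasks = splitwise_queue.pop()
            print(f"dispatching {len(tasks)} {role} task(s)")
        else:
            time.sleep(0.001)

Thread(target=engine_loop, daemon=True).start()
connector_receive("prefill", ["req-0", "req-1"])
time.sleep(0.05)  # give the daemon thread a moment before the example exits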
@@ -842,7 +844,6 @@ class LLMEngine:
                 is_prefill = True
                 self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len

-        self.split_connector.send_cache_infos(tasks, current_id)
         if not is_decode:
             llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
             for task in tasks:
@@ -854,6 +855,8 @@ class LLMEngine:
         self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
         if is_prefill and self.cfg.scheduler_config.name != "splitwise":
             self.engine_worker_queue.available_prefill_instances.put(1)
+
+        self.split_connector.send_cache_infos(tasks, current_id)
         return True

     def task_is_finished(self, index):
@@ -1017,13 +1020,16 @@ class LLMEngine:
         except Exception as e:
             print(f"Error extracting sub services: {e}")

-        self.engine_worker_queue.cleanup()
+        for worker_queue in self.engine_worker_queue_server:
+            worker_queue.cleanup()
         if hasattr(self, "send_response_server") and self.send_response_server is not None:
             self.send_response_server.close()
         if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
             self.recv_request_server.close()
         if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
             self.recv_control_cmd_server.close()

         if hasattr(self, "dp_processed"):
             for p in self.dp_processed:
                 p.join()
@@ -1325,15 +1331,20 @@ class LLMEngine:
         """
         start queue service for engine worker communication
         """
-        address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
+        self.engine_worker_queue_server = list()
         if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
-            llm_logger.info(f"Starting engine worker queue server service at {address}")
-            self.engine_worker_queue_server = EngineWorkerQueue(
+            for i in range(self.cfg.parallel_config.data_parallel_size // self.cfg.nnode):
+                address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port + i)
+                llm_logger.info(f"Starting engine worker queue service at {address}")
+                self.engine_worker_queue_server.append(
+                    EngineWorkerQueue(
                         address=address,
                         is_server=True,
                         num_client=self.cfg.tensor_parallel_size,
                         local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
                     )
+                )

         if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
             self.cache_task_queue = EngineCacheQueue(
@@ -1348,6 +1359,7 @@
                 local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
             )

+        address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
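Taken together, the queue-service hunks give every data-parallel rank its own worker-queue port: the master starts one EngineWorkerQueue server per local rank at engine_worker_queue_port + i, and each client (CacheMessager, ExpertService, worker process) connects at engine_worker_queue_port + local_data_parallel_id. A rough sketch of that port layout, with made-up constants standing in for the real config fields:

# Hypothetical illustration of the port layout implied by the diff; the real
# EngineWorkerQueue and config classes live in FastDeploy and are not reproduced here.
BASE_PORT = 6677               # stands in for cfg.engine_worker_queue_port
DATA_PARALLEL_SIZE = 4         # cfg.parallel_config.data_parallel_size
NNODE = 1                      # cfg.nnode

# server side: one queue per local data-parallel rank
server_addresses = [
    ("0.0.0.0", BASE_PORT + i)
    for i in range(DATA_PARALLEL_SIZE // NNODE)
]

# client side: a rank with local_data_parallel_id=dp connects to its own port
def client_address(master_ip: str, local_data_parallel_id: int) -> tuple:
    return (master_ip, BASE_PORT + local_data_parallel_id)

print(server_addresses)                 # [('0.0.0.0', 6677), ..., ('0.0.0.0', 6680)]
print(client_address("127.0.0.1", 2))   # ('127.0.0.1', 6679)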
@@ -22,6 +22,7 @@ import threading
 import time
 import traceback
 import weakref
+from collections import deque

 import numpy as np
@@ -31,8 +32,7 @@ from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.output.token_processor import TokenProcessor
 from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
-from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
-
+from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger

 class ExpertService:
     """
@@ -53,6 +53,10 @@ class ExpertService:
         self.cfg = cfg
         start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % cfg.worker_num_per_node
         end_pos = start_pos + self.cfg.tensor_parallel_size
+        self.waiting_requests = []
+        self.disaggregate_queue = deque()
+
+        self.llm_logger = get_logger("expert_service", f"expert_service_{local_data_parallel_id}.log")
         if cfg.splitwise_role != "mixed":
             self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
         self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
@@ -66,7 +70,7 @@ class ExpertService:

         self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id

-        address = (cfg.master_ip, cfg.engine_worker_queue_port)
+        address = (cfg.master_ip, cfg.engine_worker_queue_port + local_data_parallel_id)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
@@ -90,10 +94,7 @@ class ExpertService:
             self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]

         self.split_connector = SplitwiseConnector(
-            self.cfg,
-            self.scheduler,
-            self.engine_worker_queue,
-            self.resource_manager,
+            self.cfg, self.scheduler, self.engine_worker_queue, self.resource_manager, self.disaggregate_queue
         )

         self.token_processor = TokenProcessor(
@@ -127,7 +128,7 @@ class ExpertService:
         # assert not self.is_started, "The engine is already started."
         start_time = time.time()

-        llm_logger.info(f"start expert service {local_data_parallel_id}")
+        self.llm_logger.info(f"start expert service {local_data_parallel_id}")
         if self.cfg.splitwise_role != "mixed":
             self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
                 cache_config=self.cfg.cache_config,
@@ -177,9 +178,6 @@ class ExpertService:
                     if self.engine_worker_queue.num_tasks() > 0:
                         time.sleep(0.001)
                         continue
-                    if len(self.split_connector.current_request_ids) > 0:
-                        time.sleep(0.001)
-                        continue

                     num_prefill_batch = min(
                         int(self.resource_manager.available_batch()),
@@ -200,7 +198,7 @@ class ExpertService:
                         continue

                     if self.cfg.splitwise_role != "mixed":
-                        llm_logger.info("Inserting splitwise tasks")
+                        self.llm_logger.info("Inserting splitwise tasks")
                         self.split_connector.send_splitwise_tasks(tasks, current_id)

                     current_id = (current_id + 1) % 100003
@@ -210,63 +208,72 @@ class ExpertService:
                     main_process_metrics.num_requests_waiting.dec(len(tasks))
                     main_process_metrics.num_requests_running.inc(len(tasks))
             except Exception as e:
-                err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
-                llm_logger.error(err_msg)
+                err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
+                self.llm_logger.error(err_msg)

     def split_mode_get_tasks(self):
         """
        Split mode get tasks
        """
-        waiting_requests = []

         def receiver_loop():
             while True:
                 try:
-                    if len(waiting_requests) > 0:
-                        for task in waiting_requests:
+                    processed_indices = []
+                    for idx, task in enumerate(self.waiting_requests):
                         if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                             self.insert_tasks([task])
-                            waiting_requests.remove(task)
+                            self.llm_logger.info(f"Resource available, processing task {task.request_id}")
+                            processed_indices.append(idx)
+                        else:
+                            self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
+                            break

-                    if not self.engine_worker_queue.disaggregate_queue_empty():
-                        items = self.engine_worker_queue.get_disaggregated_tasks()
-                        for item in items:
-                            role = item[0]
-                            tasks = item[1]
+                    for idx in sorted(processed_indices, reverse=True):
+                        self.waiting_requests.pop(idx)
+
+                    if len(self.disaggregate_queue) > 0:
+                        items = self.disaggregate_queue.pop()
+                        role = items[0]
+                        tasks = items[1]
                         if role == "prefill":
-                            llm_logger.info("get prefill tasks")
                             for task in tasks:
                                 task.max_tokens = task.min_tokens = 2
                             self.insert_tasks(tasks)
                         elif role == "decode":
-                            llm_logger.info(f"get decode tasks {tasks}")
                             if hasattr(tasks[0], "finished"):
                                 if not isinstance(tasks, list):
                                     tasks = [tasks]
                                 for task in tasks:
                                     task.finished = False
-                                    # self.scheduler.put_results(tasks)

                                 self.insert_tasks(tasks, allocated=True)

+                                if self.cfg.innode_prefill_ports is not None:
+                                    self.scheduler.put_results(tasks)
+
                             else:
-                                if len(waiting_requests):
+                                if len(self.waiting_requests):
+                                    self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
+                                    self.waiting_requests.extend(tasks)
                                 else:
+                                    new_waiting = []
                                     for task in tasks:
-                                        waiting_requests.append(task)
-                                else:
-                                    for task in tasks:
-                                        if not self.resource_manager.is_resource_sufficient(
-                                            task.prompt_token_ids_len
-                                        ):
-                                            waiting_requests.append(task)
-                                        else:
+                                        if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                                             self.insert_tasks([task])
+                                        else:
+                                            new_waiting.append(task)
+
+                                    if new_waiting:
+                                        self.waiting_requests.extend(new_waiting)
+                                        self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+
                     else:
                         time.sleep(0.001)
                         continue

                 except Exception as e:
-                    llm_logger.error(f"get decode tasks error: {e}")
+                    self.llm_logger.error(f"Error in main loop: {e} {str(traceback.format_exc())}")
                     time.sleep(0.1)

         threading.Thread(target=receiver_loop, daemon=True).start()
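The rewritten receiver loop first collects the indices of waiting requests it could schedule and only then pops them, in reverse order, so that removing an earlier element does not shift the index of a later one. A tiny standalone illustration of that pattern:

waiting_requests = ["req-0", "req-1", "req-2", "req-3"]

# pretend resources were available for the first two requests only
processed_indices = [0, 1]

# popping in reverse keeps the remaining indices valid while the list shrinks
for idx in sorted(processed_indices, reverse=True):
    waiting_requests.pop(idx)

print(waiting_requests)  # ['req-2', 'req-3']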
@@ -287,11 +294,11 @@ class ExpertService:
                     if task.request_id in self.token_processor.tokens_counter:
                         del self.token_processor.tokens_counter[task.request_id]
                     self.scheduler.put_results([task])
-                    llm_logger.warning(
+                    self.llm_logger.warning(
                         f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
                     )
                     continue
-                llm_logger.info(f"{cur_task_idx} {task.request_id}")
+                self.llm_logger.info(f"{cur_task_idx} {task.request_id}")
                 cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
                 self.token_processor.tokens_counter[task.request_id] = 1
                 current_tasks.append(cur_task)
@@ -308,8 +315,10 @@ class ExpertService:

         available_batch = np.sum(self.resource_manager.stop_flags)
         if len(tasks) > available_batch:
-            llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
-            llm_logger.error("The exceeded part will be ignored!")
+            self.llm_logger.error(
+                "Inserting batch:{} exceeds the available batch:{}.".format(len(tasks), available_batch)
+            )
+            self.llm_logger.error("The exceeded part will be ignored!")
             tasks = tasks[:available_batch]

         req_ids = [t.request_id for t in tasks]
@@ -318,7 +327,7 @@ class ExpertService:

         if not tasks:
             error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
-            llm_logger.error(error_msg)
+            self.llm_logger.error(error_msg)
             raise EngineError(error_msg, error_code=500)
             return False
@@ -338,7 +347,7 @@ class ExpertService:
         for task in tasks:
             task.infer_start_time = time.time()
         if not is_decode:
-            llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
+            self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
         if not is_prefill and self.cfg.cache_config.enable_chunked_prefill:
             if not self.cfg.enable_mm:
                 self.update_requests_chunk_size(tasks)
@@ -356,7 +365,7 @@ class ExpertService:
             self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
             self.resource_manager.cache_manager.cache_ready_signal.clear()
             for p in self.cache_manager_processes:
-                llm_logger.info(f"Killing cache manager process {p.pid}")
+                self.llm_logger.info(f"Killing cache manager process {p.pid}")
                 try:
                     os.killpg(p.pid, signal.SIGTERM)
                 except:
@@ -22,7 +22,7 @@ import numpy as np

 from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
 from fastdeploy.metrics.metrics import main_process_metrics
-from fastdeploy.utils import llm_logger
+from fastdeploy.utils import get_logger, llm_logger


 class ResourceManager:
@@ -49,6 +49,12 @@ class ResourceManager:
        Initializes the engine with the given configuration and sets up necessary
        data structures to manage tasks and blocks.
        """
+        if local_data_parallel_id > 0:
+            self.logger = get_logger(
+                f"expert_service_{local_data_parallel_id}", f"expert_service_{local_data_parallel_id}.log"
+            )
+        else:
+            self.logger = llm_logger
         self.cfg = config.cache_config
         self.max_num_seqs = max_num_seqs
         self.stop_flags = [True] * max_num_seqs
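ResourceManager now logs through a per-rank logger when local_data_parallel_id > 0, so each expert service writes to its own expert_service_<id>.log instead of interleaving with the shared llm_logger. A minimal sketch of the same idea using only the standard logging module (FastDeploy's get_logger helper is not reproduced here):

import logging

def make_rank_logger(local_data_parallel_id: int) -> logging.Logger:
    # one log file per data-parallel rank, e.g. expert_service_1.log
    name = f"expert_service_{local_data_parallel_id}"
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.FileHandler(f"{name}.log")
        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger

logger = make_rank_logger(1)
logger.info("resource manager initialized")  # lands in expert_service_1.log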
@@ -58,7 +64,7 @@ class ResourceManager:
         self.req_dict = dict()
         # current batch status of the engine
         self.real_bsz = 0
-        llm_logger.info(f"{self.info()}")
+        self.logger.info(f"{self.info()}")

     def reset_cache_config(self, cfg):
         """
@@ -134,10 +140,10 @@ class ResourceManager:
         block_list = list()
         current_block_num = self.available_block_num()
         if block_num > current_block_num:
-            llm_logger.error(f"block_num:{block_num} > free_list len:{current_block_num}")
+            self.logger.error("block_num:{0} > free_list len:{1}".format(block_num, current_block_num))
             return block_list
         block_list = self.cache_manager.allocate_gpu_blocks(block_num)
-        llm_logger.debug(f"dispatch {len(block_list)} blocks.")
+        self.logger.debug(f"dispatch {len(block_list)} blocks.")
         return block_list

     def check_and_free_block_tables(self):
@@ -169,7 +175,7 @@ class ResourceManager:
         self.cache_manager.recycle_gpu_blocks(block_tables)
         cur_number = self.available_block_num()
         main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
-        llm_logger.info(f"recycle {req_id} {cur_number - ori_number} blocks.")
+        self.logger.info(f"recycle {req_id} {cur_number - ori_number} blocks.")

     def available_batch(self):
         """
@@ -248,12 +254,10 @@ class ResourceManager:
             if self.enable_prefix_cache:
                 cache_prepare_time = time.time()
                 common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
-                    task,
-                    self.cfg.block_size,
-                    self.cfg.dec_token_num,
+                    task, self.cfg.block_size, self.cfg.dec_token_num
                 )
                 if unique_block_ids is None:
-                    llm_logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
+                    self.logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
                     return

                 cached_len = self._record_request_cache_info(
@@ -294,7 +298,7 @@ class ResourceManager:
                 task.inference_time_cost = -1.0
                 task.tokens_all_num = 0
                 self.tasks_list[allocated_position] = task
-                llm_logger.info(
+                self.logger.info(
                     f"Allocate request: {task.request_id}, "
                     f"allocated_position:{allocated_position}, "
                     f"length of prompt token: {task.prompt_token_ids_len}"
@@ -308,10 +312,10 @@ class ResourceManager:
                 self.real_bsz = i + 1
                 break

-        llm_logger.info(
+        self.logger.info(
             f"Number of allocated requests: {len(tasks)}, number of " f"running requests in worker: {self.real_bsz}"
         )
-        llm_logger.info(f"{self.info()}")
+        self.logger.info(f"{self.info()}")
         main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())

         return processed_tasks
@@ -342,8 +346,8 @@ class ResourceManager:
         cached_len = len(common_block_ids) * self.cfg.block_size
         task.block_tables = common_block_ids + unique_block_ids
         task.need_block_tables = unique_block_ids
-        llm_logger.debug(f"common: {common_block_ids} ")
-        llm_logger.debug(f"unique: {unique_block_ids} ")
+        self.logger.debug(f"common: {common_block_ids} ")
+        self.logger.debug(f"unique: {unique_block_ids} ")
         return cached_len

     def info(self):
@@ -177,6 +177,8 @@ class OpenAIServingChat:
                 for res in response:
                     if res.get("error_code", 200) != 200:
                         raise ValueError("{}".format(res["error_msg"]))
+                    if res["finished"]:
+                        api_server_logger.info(f"chat completion finished: {request_id}")

                     self.engine_client.data_processor.process_response_dict(
                         res,
@@ -14,11 +14,11 @@
 # limitations under the License.
 """

-import json
 import time
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict

+import msgpack
 import zmq

 from fastdeploy import envs
@@ -34,7 +34,7 @@ class SplitwiseConnector:
     SplitwiseConnector class for managing and scheduling Splitwise tasks.
     """

-    def __init__(self, cfg, scheduler, worker_queue, resource_manager):
+    def __init__(self, cfg, scheduler, worker_queue, resource_manager, splitwise_queue):
         """
         Initialize the SplitwiseConnector instance.
@@ -51,6 +51,7 @@ class SplitwiseConnector:
         self.connect_innode_instances = {}
         self.temp_cache_info = dict()
         self.current_request_ids = dict()
+        self.splitwise_queue = splitwise_queue

         if self.cfg.cache_config.pd_comm_port is not None:
             self.zmq_ctx = zmq.Context()
@@ -406,13 +407,19 @@ class SplitwiseConnector:
         if msg_type == "decode" or msg_type == "prefill":
             payload = [output.to_dict() for output in payload]

-        json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")
+        req_ids = [task["request_id"] for task in payload]
+        logger.info(f"send message {msg_type} {req_ids}")
+
+        json_data = msgpack.packb({"type": msg_type, "payload": payload})

         return json_data

     def _deserialize_message(self, data: bytes):

-        # JSON deserialization
-        message = json.loads(data.decode("utf-8"))
+        message = msgpack.unpackb(data)
+        req_ids = [task["request_id"] for task in message["payload"]]
+        logger.info(f"send message {message['type']} {req_ids}")
         return message["type"], message["payload"]

     def _process_message(self, message: bytes):
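Inter-instance messages switch from JSON text to msgpack binary framing, which skips the encode/decode round-trip through strings and yields smaller payloads. A hedged round-trip sketch with the msgpack package; the payload shape below is only illustrative:

import msgpack

payload = [{"request_id": "req-42", "prompt_token_ids_len": 128}]

# pack a typed message to bytes, as _serialize_message now does
data = msgpack.packb({"type": "prefill", "payload": payload})

# unpack it back on the receiving side, as _deserialize_message now does
message = msgpack.unpackb(data)
assert message["type"] == "prefill"
assert message["payload"][0]["request_id"] == "req-42"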
@@ -441,7 +448,9 @@ class SplitwiseConnector:
         """

         tasks_data = [Request.from_dict(task) for task in tasks]
-        self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data))
+        req_ids = [task["request_id"] for task in tasks]
+        self.splitwise_queue.append(("decode", tasks_data))
+        logger.debug(f"{req_ids} received prefill data")

     def _handle_decode(self, payload):
         """
@@ -460,4 +469,6 @@ class SplitwiseConnector:
                     finished=True,
                 )
             )
-        self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))
+        req_ids = [task["request_id"] for task in payload]
+        self.splitwise_queue.append(("decode", tasks))
+        logger.debug(f"{req_ids} received decode data")
@@ -150,7 +150,7 @@ class PaddleDisWorkerProc:
         # Initialize task queue
         task_address = (
             self.parallel_config.pod_ip,
-            self.parallel_config.engine_worker_queue_port,
+            self.parallel_config.engine_worker_queue_port + self.parallel_config.expert_parallel_rank,
         )
         self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
         self.task_queue = TaskQueue(
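On the worker side, each process now attaches to the task queue that matches its expert-parallel rank by offsetting the base port, mirroring the per-rank server layout shown earlier. A small hypothetical helper for that address calculation (pod_ip and base_port stand in for the parallel_config fields):

def task_queue_address(pod_ip: str, base_port: int, expert_parallel_rank: int) -> tuple:
    # mirrors engine_worker_queue_port + expert_parallel_rank from the diff
    return (pod_ip, base_port + expert_parallel_rank)

assert task_queue_address("10.0.0.5", 6677, 3) == ("10.0.0.5", 6680)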
@@ -252,9 +252,11 @@ class PaddleDisWorkerProc:
             for req_dict, bsz in tasks:
                 num_running_requests = int(bsz)
                 req_dicts.extend(req_dict)
+            req_ids = [req.request_id for req in req_dicts]
+
             logger.info(
                 f"Rank: {self.local_rank}, num_running_requests: {num_running_requests}, "
-                f"num_insert_requests: {len(req_dicts)}"
+                f"num_insert_requests: {len(req_dicts)}, req_ids: {req_ids}"
             )
             # Process prefill inputs
             self.worker.preprocess_new_task(req_dicts)