Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 16:48:03 +08:00
[Feature] Support ep pd with external module (#3194)
* Support external module
* Support external module
* Support external module
* Support external module
* refactor code to make it more clear
* refactor code to make it more clear
* refactor code to make it more clear
* refactor code to make it more clear
* fix according to review
* fix according to review
* fix according to review
* fix according to review
* fix according to review
* fix according to review
* fix bug
* fix bug
* fix bug
* merge

---------

Co-authored-by: root <root@tjdm-inf-sci-k8s-hzz2-h12ni8-0202.tjdm.baidu.com>
@@ -47,12 +47,14 @@ from fastdeploy.inter_communicator import (
    EngineCacheQueue,
    EngineWorkerQueue,
    IPCSignal,
    ZmqClient,
    ZmqIpcServer,
    ZmqTcpServer,
)
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.metrics.trace_util import start_span, start_span_request
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, envs, llm_logger

@@ -179,11 +181,64 @@ class LLMEngine:
        self.data_processor = self.input_processor.create_processor()

        if api_server_pid is not None:
            self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
            self.zmq_server.start_server()
            self.zmq_server.create_router()
            if envs.FD_ENABLE_INTERNAL_ADAPTER:
                self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
                self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
                self.external_adapter = InternalAdapter(
                    cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node
                )
            else:
                self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
                self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
            time.sleep(3)
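
The block above selects the engine's front-door transport: when envs.FD_ENABLE_INTERNAL_ADAPTER is set, requests and responses move over TCP sockets bound to FD_ZMQ_RECV_REQUEST_SERVER_PORT / FD_ZMQ_SEND_RESPONSE_SERVER_PORT so an external module can attach, otherwise IPC sockets named after the API server PID are used. A minimal pyzmq sketch of that selection, using plain zmq sockets and made-up endpoints rather than FastDeploy's ZmqTcpServer/ZmqIpcServer wrappers:

```python
# Sketch only: plain pyzmq sockets stand in for ZmqTcpServer/ZmqIpcServer;
# the port numbers and IPC paths below are illustrative, not FastDeploy's.
import zmq

def create_engine_sockets(api_server_pid: int, use_internal_adapter: bool):
    ctx = zmq.Context.instance()
    recv_requests = ctx.socket(zmq.PULL)     # requests flow in
    send_responses = ctx.socket(zmq.ROUTER)  # responses are addressed back to clients
    if use_internal_adapter:
        # TCP endpoints so an external module in another process/host can connect.
        recv_requests.bind("tcp://*:8201")   # stand-in for FD_ZMQ_RECV_REQUEST_SERVER_PORT
        send_responses.bind("tcp://*:8202")  # stand-in for FD_ZMQ_SEND_RESPONSE_SERVER_PORT
    else:
        # Local IPC endpoints keyed by the API server's PID.
        recv_requests.bind(f"ipc:///tmp/fd_recv_{api_server_pid}.sock")
        send_responses.bind(f"ipc:///tmp/fd_send_{api_server_pid}.sock")
    return recv_requests, send_responses
```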

        self.cfg.init_cache_info()

        role = self.cfg.splitwise_role
        host_ip = self.cfg.host_ip
        disaggregate = self.cfg.disaggregate_info
        request_queues_for_dp_ipc = (
            None  # Different dp has its own process, use multiprocessing.Queue to deliver requests for each dp
        )
        result_queue_for_dp_ipc = None
        if self.cfg.scheduler_config.name == "splitwise":
            self.scheduler.start(role, host_ip, disaggregate)
        elif self.cfg.scheduler_config.name == "dp":
            request_queues_for_dp_ipc = []
            result_queue_for_dp_ipc = multiprocessing.Queue()
            for i in range(self.cfg.parallel_config.data_parallel_size):
                request_queues_for_dp_ipc.append(multiprocessing.Queue())
            self.scheduler.start(
                self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
            )

        time.sleep(1)
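
Under the new "dp" scheduler each data-parallel rank lives in its own process, so the engine allocates one multiprocessing.Queue per rank for request delivery and a single shared queue for results (the request_queues_for_dp_ipc / result_queue_for_dp_ipc variables above). A rough sketch of that fan-out; the dispatch_request helper and its hashing policy are hypothetical, not the scheduler's actual routing:

```python
# Sketch of the per-DP-rank queue fan-out; the dispatch policy is illustrative.
import multiprocessing

data_parallel_size = 4  # assumed value for illustration

# One request queue per DP rank, one shared result queue.
request_queues = [multiprocessing.Queue() for _ in range(data_parallel_size)]
result_queue = multiprocessing.Queue()

def dispatch_request(request_id: str, payload: dict) -> None:
    # Hypothetical policy: map the request id onto a DP rank.
    rank = hash(request_id) % data_parallel_size
    request_queues[rank].put((request_id, payload))

def collect_result():
    # Every rank pushes onto the same result queue; the parent drains it.
    return result_queue.get()
```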
        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
            self.dp_processed = []
            for i in range(
                1,
                self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
            ):
                time.sleep(1)
                self.dp_processed.append(
                    multiprocessing.Process(
                        target=start_expert_service,
                        args=(
                            self.cfg,
                            i + self.cfg.node_rank * self.cfg.worker_num_per_node,
                            self.ipc_signal_suffix,
                            request_queues_for_dp_ipc,
                            result_queue_for_dp_ipc,
                        ),
                    )
                )
                llm_logger.info(
                    f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
                    + f" data parallel id {i}"
                )
                self.dp_processed[-1].start()

        if self.do_profile == 0 and (
            self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed"
        ):
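
When expert parallelism is on and data_parallel_size > 1, the engine forks one expert-service process per additional local DP rank and offsets the rank by node_rank * worker_num_per_node to form a global index; the new DP queues are now passed through as well. A stripped-down sketch of that spawning pattern, where run_expert_service is a placeholder for start_expert_service and the config numbers are assumed:

```python
# Placeholder worker standing in for fastdeploy's start_expert_service;
# the config values below are assumptions for the sake of a runnable example.
import multiprocessing

def run_expert_service(local_dp_rank: int, global_dp_rank: int) -> None:
    print(f"expert service up: local rank {local_dp_rank}, global rank {global_dp_rank}")

data_parallel_size = 4                              # assumed
nnode = 2                                           # assumed: ranks split across nodes
node_rank = 0                                       # assumed
worker_num_per_node = data_parallel_size // nnode   # assumed relationship

if __name__ == "__main__":
    processes = []
    # Local rank 0 runs in the current process; spawn the remaining local ranks.
    for i in range(1, data_parallel_size // nnode):
        p = multiprocessing.Process(
            target=run_expert_service,
            args=(i, i + node_rank * worker_num_per_node),
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
```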
@@ -238,44 +293,11 @@ class LLMEngine:
        # 单机逻辑 (single-node logic)
        self.engine_worker_queue.available_prefill_instances.put(1)
        self.split_mode_get_tasks()
        if self.cfg.scheduler_config.name == "splitwise":
        if self.cfg.scheduler_config.name == "splitwise" or self.cfg.scheduler_config.name == "dp":
            self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
            self.splitwise_receive_thread.daemon = True
            self.splitwise_receive_thread.start()

        self.cfg.init_cache_info()

        role = self.cfg.splitwise_role
        host_ip = self.cfg.host_ip
        disaggregate = self.cfg.disaggregate_info
        if self.cfg.scheduler_config.name == "splitwise":
            self.scheduler.start(role, host_ip, disaggregate)

        time.sleep(1)

        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
            self.dp_processed = []
            for i in range(
                1,
                self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
            ):
                time.sleep(1)
                self.dp_processed.append(
                    multiprocessing.Process(
                        target=start_expert_service,
                        args=(
                            self.cfg,
                            i + self.cfg.node_rank * self.cfg.worker_num_per_node,
                            self.ipc_signal_suffix,
                        ),
                    )
                )
                llm_logger.info(
                    f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
                    + f" data parallel id {i}"
                )
                self.dp_processed[-1].start()

        console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
        return True
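
In this hunk the receiver-thread condition is widened from just "splitwise" to also cover the "dp" scheduler, while the cache-info and scheduler startup that used to live here move earlier in the constructor. A bare-bones sketch of the daemon-thread launch pattern; receive_loop is a placeholder, not SplitwiseConnector.start_receiver:

```python
# Placeholder receiver loop; the real target is SplitwiseConnector.start_receiver.
import threading
import time

def receive_loop() -> None:
    while True:
        # poll for prefill/decode handoff messages here
        time.sleep(0.01)

scheduler_name = "dp"  # assumed config value
if scheduler_name in ("splitwise", "dp"):
    receiver = threading.Thread(target=receive_loop, daemon=True)
    receiver.start()  # daemon=True so the thread never blocks interpreter shutdown
```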
@@ -291,7 +313,7 @@ class LLMEngine:
                    time.sleep(0.005)
                    continue
                for request_id, contents in results.items():
                    self.zmq_server.send_multipart(request_id, contents)
                    self.send_response_server.send_response(request_id, contents)

            except Exception as e:
                llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
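
Responses now leave through send_response_server.send_response rather than the old zmq_server.send_multipart. On a ROUTER socket a reply has to be addressed with the client's identity frame; a hedged sketch of what such a send can look like in plain pyzmq (the framing and JSON payload here are assumptions, not necessarily the wrappers' wire format):

```python
# Illustrative only: addressing a reply on a ROUTER socket; the identity frame
# and JSON payload layout are assumptions, not FastDeploy's actual protocol.
import json
import zmq

def send_response(router: zmq.Socket, client_identity: bytes, request_id: str, contents: list) -> None:
    payload = json.dumps({"request_id": request_id, "contents": contents}).encode("utf-8")
    # ROUTER sends are multipart: [identity, payload]; the identity frame routes
    # the message back to the peer that originally connected.
    router.send_multipart([client_identity, payload])
```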
@@ -415,14 +437,18 @@ class LLMEngine:
        if self.api_server_pid is None:
            return

        if envs.FD_ENABLE_INTERNAL_ADAPTER:
            if self.cfg.splitwise_role == "decode":
                return

        added_requests: Dict[str, int] = dict()
        while self.running:
            try:
                block = True if len(added_requests) == 0 else False
                if not self.cfg.enable_mm:
                    err, data = self.zmq_server.receive_json_once(block)
                    err, data = self.recv_request_server.receive_json_once(block)
                else:
                    err, data = self.zmq_server.receive_pyobj_once(block)
                    err, data = self.recv_request_server.receive_pyobj_once(block)
                if err is not None:
                    llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}")
                    break
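
The insertion loop blocks on the receive only when nothing is pending (block is true only while added_requests is empty), and polls otherwise so queued work keeps moving; under the internal adapter a decode-role engine returns early, presumably because its requests arrive through the splitwise path instead. A small pyzmq sketch of the block-only-when-idle receive, assuming a PULL socket and JSON payloads (this is not the implementation of receive_json_once):

```python
# Sketch of "block only when idle" receiving on a PULL socket; the (err, data)
# return convention mirrors the diff, error handling is simplified.
import zmq

def receive_json_once(sock: zmq.Socket, block: bool):
    """Return (err, data); data is None when non-blocking and nothing is queued."""
    try:
        flags = 0 if block else zmq.NOBLOCK
        return None, sock.recv_json(flags=flags)
    except zmq.Again:
        return None, None    # nothing available right now, caller just loops
    except zmq.ZMQError as e:
        return str(e), None  # socket closed or misconfigured

# Usage inside the scheduler loop:
#   block = len(added_requests) == 0
#   err, data = receive_json_once(recv_request_socket, block)
```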
@@ -470,7 +496,7 @@ class LLMEngine:
                    )
                    # Since the request is not in scheduler
                    # Send result by zmq directly
                    self.zmq_server.send_multipart(request_id, error_result)
                    self.send_response_server.send_response(request_id, error_result)
            except Exception as e:
                llm_logger.error(
                    f"Error happend while receving new request from zmq, details={e}, "
@@ -989,8 +1015,12 @@ class LLMEngine:
            print(f"Error extracting sub services: {e}")

        self.engine_worker_queue.cleanup()
        if hasattr(self, "zmq_server") and self.zmq_server is not None:
            self.zmq_server.close()
        if hasattr(self, "send_response_server") and self.send_response_server is not None:
            self.send_response_server.close()
        if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
            self.recv_request_server.close()
        if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
            self.recv_control_cmd_server.close()
        if hasattr(self, "dp_processed"):
            for p in self.dp_processed:
                p.join()
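
Shutdown now has to close whichever of the optional ZMQ servers was actually created for the current configuration and join the per-DP expert-service processes, hence the hasattr guards above. A small generic sketch of the same guarded-cleanup idea (the helper itself is hypothetical; the attribute names come from the diff):

```python
# Hypothetical helper mirroring the guarded-cleanup pattern in the diff.
def close_optional_servers(engine) -> None:
    for attr in ("zmq_server", "send_response_server", "recv_request_server", "recv_control_cmd_server"):
        server = getattr(engine, attr, None)  # only some attributes exist per config
        if server is not None:
            server.close()
    for p in getattr(engine, "dp_processed", []):
        p.join()  # wait for the spawned expert-service processes to exit
```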