From 3dbe5596e623044acc77f6f39fefd04b7d9dd6ab Mon Sep 17 00:00:00 2001 From: kevin Date: Fri, 7 Nov 2025 15:42:29 +0800 Subject: [PATCH] [Feature] Support eplb for ep (#4786) * support eplb for ep * update code * update code * update code * update code * update code * update code * update code * update code * update code --- fastdeploy/config.py | 58 +- fastdeploy/engine/args_utils.py | 43 +- fastdeploy/engine/common_engine.py | 7 + fastdeploy/engine/engine.py | 1 + fastdeploy/entrypoints/engine_client.py | 324 ++++++++++- fastdeploy/entrypoints/openai/api_server.py | 33 ++ fastdeploy/eplb/__init__.py | 15 + fastdeploy/eplb/async_expert_loader.py | 427 +++++++++++++++ fastdeploy/eplb/eplb.py | 291 ++++++++++ fastdeploy/eplb/experts_manager.py | 503 ++++++++++++++++++ fastdeploy/eplb/utils.py | 160 ++++++ fastdeploy/inter_communicator/__init__.py | 2 + fastdeploy/inter_communicator/ipc_signal.py | 48 +- .../inter_communicator/ipc_signal_const.py | 26 + .../model_executor/models/ernie4_5_moe.py | 2 +- fastdeploy/rl/rollout_config.py | 2 + fastdeploy/worker/worker_process.py | 129 ++++- requirements.txt | 2 + 18 files changed, 2048 insertions(+), 25 deletions(-) create mode 100644 fastdeploy/eplb/__init__.py create mode 100644 fastdeploy/eplb/async_expert_loader.py create mode 100644 fastdeploy/eplb/eplb.py create mode 100644 fastdeploy/eplb/experts_manager.py create mode 100644 fastdeploy/eplb/utils.py create mode 100644 fastdeploy/inter_communicator/ipc_signal_const.py diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 1c52f27d2..fb838d045 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -124,7 +124,6 @@ class ModelConfig: self.max_model_len = 0 self.dtype = "" self.enable_logprob = False - self.enable_redundant_experts = False self.redundant_experts_num = 0 self.seed = 0 self.quantization = None @@ -247,6 +246,60 @@ class ModelConfig: logger.info("=============================================================") +class EPLBConfig: + """ + Configuration for EPLB manager. + """ + + def __init__( + self, + args, + ): + # enable eplb + self.enable_eplb: bool = False + # redundant experts num + self.redundant_experts_num: int = 0 + # expert ip shm size + self.redundant_expert_ip_shm_size: int = 1024 + # expert meta dir + self.redundant_expert_meta_dir: str = "/tmp/redundant_expert_meta" + # expert api user and password + self.redundant_expert_api_user: str = "" + self.redundant_expert_api_password: str = "" + # expert eplb strategy + self.redundant_expert_eplb_strategy: str = "" + # expert dump workload interval + self.redundant_expert_dump_workload_interval: int = 10 + # expert async load model shmem size gb + self.redundant_expert_async_load_model_shmem_size_gb: int = 0 + # expert enable schedule cordon + self.redundant_expert_enable_schedule_cordon: bool = True + # model use safetensors + self.model_use_safetensors: bool = True + # model use offline quant + self.model_use_offline_quant: bool = True + # moe quant type + self.moe_quant_type: str = "w4a8" + for key, value in args.items(): + if hasattr(self, key): + setattr(self, key, value) + + def to_json_string(self): + """ + Convert eplb_config to json string. + """ + return json.dumps({key: value for key, value in self.__dict__.items() if value is not None}) + + def print(self): + """ + Print all configuration information. 
+ """ + logger.info("EPLB Configuration Information :") + for k, v in self.__dict__.items(): + logger.info("{:<20}:{:<6}{}".format(k, "", v)) + logger.info("=============================================================") + + class ParallelConfig: """Configuration for the distributed execution.""" @@ -1141,6 +1194,7 @@ class FDConfig: reasoning_parser: str = None, guided_decoding_backend: Optional[str] = None, disable_any_whitespace: bool = False, + eplb_config: EPLBConfig = None, early_stop_config: Optional[Dict[str, Any]] = None, tool_parser: str = None, test_mode=False, @@ -1159,6 +1213,7 @@ class FDConfig: self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config self.decoding_config: DecodingConfig = decoding_config # type: ignore self.cache_config: CacheConfig = cache_config # type: ignore + self.eplb_config: Optional[EPLBConfig] = eplb_config self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config self.enable_attention_dp_balance = enable_attention_dp_balance self.attention_dp_time_out_iters = attention_dp_time_out_iters @@ -1386,6 +1441,7 @@ class FDConfig: or k == "scheduler_config" or k == "parallel_config" or k == "commit_config" + or k == "eplb_config" ): if v is not None: v.print() diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 7483a74da..e41154cd0 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -26,6 +26,7 @@ from fastdeploy import envs from fastdeploy.config import ( CacheConfig, EarlyStopConfig, + EPLBConfig, FDConfig, GraphOptimizationConfig, LoadConfig, @@ -397,6 +398,15 @@ class EngineArgs: Max waiting steps to sync all dp for prefill tasks available """ + enable_eplb: bool = False + """ + Flag to enable eplb + """ + eplb_config: Optional[Dict[str, Any]] = None + """ + Configuration for eplb. + """ + def __post_init__(self): """ Post-initialization processing to set default tokenizer if not provided. @@ -693,6 +703,18 @@ class EngineArgs: default=EngineArgs.enable_expert_parallel, help="Enable expert parallelism.", ) + parallel_group.add_argument( + "--enable-eplb", + action="store_true", + default=EngineArgs.enable_eplb, + help="Enable eplb.", + ) + model_group.add_argument( + "--eplb-config", + type=json.loads, + default=EngineArgs.eplb_config, + help="Config of eplb.", + ) # Load group load_group = parser.add_argument_group("Load Configuration") @@ -1022,7 +1044,17 @@ class EngineArgs: early_stop_args[k] = v return EarlyStopConfig(early_stop_args) - def create_engine_config(self) -> FDConfig: + def create_eplb_config(self) -> EPLBConfig: + """ + Create and retuan an EPLBConfig object based on the current settings. + """ + eplb_args = asdict(self) + if self.eplb_config is not None: + for k, v in self.eplb_config.items(): + eplb_args[k] = v + return EPLBConfig(eplb_args) + + def create_engine_config(self, port_availability_check: bool = True) -> FDConfig: """ Create and return a Config object based on the current settings. 
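+
+        When port_availability_check is False, the engine_worker_queue_port availability
+        assert is skipped, so the config can also be built from a process (such as the
+        OpenAI API server) that does not own that port.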
""" @@ -1063,6 +1095,7 @@ class EngineArgs: graph_opt_cfg = self.create_graph_optimization_config() graph_opt_cfg.update_use_cudagraph(self.use_cudagraph) moba_attention_config = self.create_moba_attention_config() + eplb_cfg = self.create_eplb_config() early_stop_cfg = self.create_early_stop_config() early_stop_cfg.update_enable_early_stop(self.enable_early_stop) @@ -1072,9 +1105,10 @@ class EngineArgs: if isinstance(self.engine_worker_queue_port, str): self.engine_worker_queue_port = self.engine_worker_queue_port.split(",") - assert is_port_available( - "0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id]) - ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use." + if port_availability_check: + assert is_port_available( + "0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id]) + ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use." return FDConfig( model_config=model_cfg, @@ -1084,6 +1118,7 @@ class EngineArgs: load_config=load_cfg, parallel_config=parallel_cfg, max_model_len=self.max_model_len, + eplb_config=eplb_cfg, max_num_seqs=self.max_num_seqs, speculative_config=speculative_cfg, max_num_batched_tokens=self.max_num_batched_tokens, diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 7e7ccb0fd..033caf313 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -33,6 +33,7 @@ from opentelemetry import trace from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1 +from fastdeploy.eplb.utils import init_eplb_signals from fastdeploy.input.preprocess import InputPreprocessor from fastdeploy.inter_communicator import ( EngineCacheQueue, @@ -132,6 +133,12 @@ class EngineSevice: ) self._init_worker_monitor_signals() + if self.cfg.eplb_config.enable_eplb: + current_suffix = int( + self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] + ) + init_eplb_signals(cfg, current_suffix) + self._finalizer = weakref.finalize(self, self._exit_sub_services) def start(self): diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 447226fca..b26337da3 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -461,6 +461,7 @@ class LLMEngine: f" --load_choices {self.cfg.load_config.load_choices}" f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'" f" --attention_dp_time_out_iters {self.cfg.attention_dp_time_out_iters}" + f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'" f" --ips {ips}" ) diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 02242a31c..6e0f6bd5e 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -19,6 +19,7 @@ import os import time import traceback import uuid +from http import HTTPStatus import numpy as np @@ -26,8 +27,9 @@ from fastdeploy import envs from fastdeploy.config import ModelConfig from fastdeploy.entrypoints.openai.utils import DealerConnectionManager from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS +from fastdeploy.eplb.utils import RedundantExpertWorkload from fastdeploy.input.preprocess import InputPreprocessor -from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient +from 
fastdeploy.inter_communicator import IPCSignal, RearrangeExpertStatus, ZmqIpcClient from fastdeploy.metrics.work_metrics import work_process_metrics from fastdeploy.multimodal.registry import MultimodalRegistry from fastdeploy.platforms import current_platform @@ -49,6 +51,7 @@ class EngineClient: port, limit_mm_per_prompt, mm_processor_kwargs, + config, # enable_mm=False, reasoning_parser=None, data_parallel_size=1, @@ -59,6 +62,7 @@ class EngineClient: ): import fastdeploy.model_executor.models # noqa: F401 + self.config = config architectures = ModelConfig({"model": model_name_or_path}).architectures[0] self.enable_prefix_caching = enable_prefix_caching if MultimodalRegistry.contains_model(architectures): @@ -92,6 +96,9 @@ class EngineClient: else: self.is_master = False + if self.config.eplb_config.enable_eplb and self.config.parallel_config.expert_parallel_rank == 0: + self.init_eplb_signals(ipc_signal_suffix=port) + array_size = min(max_chips_per_node, tensor_parallel_size) self.worker_healthy_live_recorded_time_array = np.zeros(shape=[array_size], dtype=np.int32) self.worker_healthy_live_signal = IPCSignal( @@ -115,6 +122,115 @@ class EngineClient: ) self.connection_initialized = False + def init_eplb_signals(self, ipc_signal_suffix): + """ + Initialize eplb signals. + """ + self.signal_clear_experts_token_stats_list = [] + self.local_experts_token_stats_array_list = [] + self.expert_tokens_stats_array_list = [] + self.signal_update_weight_from_disk_array_list = [] + self.update_weight_from_disk_result_list = [] + rearrange_experts_status = np.zeros([1], dtype=np.int32) + self.rearrange_experts_signal = IPCSignal( + name="rearrange_experts_status", + array=rearrange_experts_status, + dtype=np.int32, + suffix=ipc_signal_suffix, + create=False, + ) + + rearrange_experts_ips_size_array = np.zeros([1], dtype=np.int32) + self.rearrange_experts_ips_size_signal = IPCSignal( + name="rearrange_experts_ips_size", + array=rearrange_experts_ips_size_array, + dtype=np.int32, + suffix=ipc_signal_suffix, + create=False, + ) + + self.shm_rearrange_experts_ips_list = IPCSignal( + name="rearrange_experts_ips_list", + shm_size=self.config.eplb_config.redundant_expert_ip_shm_size, + suffix=ipc_signal_suffix, + create=False, + ) + + signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32) + self.signal_update_weight_from_tensor_array = IPCSignal( + name="signal_update_weight_from_tensor", + array=signal_update_weight_from_tensor, + dtype=np.int32, + suffix=ipc_signal_suffix, + create=False, + ) + + if envs.FD_ENABLE_MULTI_API_SERVER: + engine_worker_suffix = [ + self.config.parallel_config.engine_worker_queue_port[ + self.config.parallel_config.local_data_parallel_id + ] + ] + else: + engine_worker_suffix = self.config.parallel_config.engine_worker_queue_port + + for suffix_port in engine_worker_suffix: + signal_clear_experts_token_stats = np.zeros([1], dtype=np.int32) + self.signal_clear_experts_token_stats_list.append( + IPCSignal( + name="signal_clear_experts_token_stats", + array=signal_clear_experts_token_stats, + dtype=np.int32, + suffix=suffix_port, + create=False, + ) + ) + + signal_update_weight_from_disk = np.zeros([1], dtype=np.int32) + self.signal_update_weight_from_disk_array_list.append( + IPCSignal( + name="signal_update_weight_from_disk", + array=signal_update_weight_from_disk, + dtype=np.int32, + suffix=suffix_port, + create=False, + ) + ) + + result_update_weight_from_disk = np.zeros([1], dtype=np.int32) + self.update_weight_from_disk_result_list.append( + IPCSignal( + 
name="result_update_weight_from_disk", + array=result_update_weight_from_disk, + dtype=np.int32, + suffix=suffix_port, + create=False, + ) + ) + + experts_token_stats = np.zeros( + (self.config.model_config.num_hidden_layers, self.config.model_config.moe_num_experts), + dtype=np.int32, + ) + self.expert_tokens_stats_array_list.append( + IPCSignal( + name="all_experts_token_stats", + array=experts_token_stats, + dtype=np.int32, + suffix=suffix_port, + create=False, + ) + ) + self.local_experts_token_stats_array_list.append( + IPCSignal( + name="local_experts_token_stats", + array=experts_token_stats, + dtype=np.int32, + suffix=suffix_port, + create=False, + ) + ) + def create_zmq_client(self, model, mode): """ Create a ZMQ client. @@ -394,3 +510,209 @@ class EngineClient: def check_model_weight_status(self): return self.model_weights_status_signal.value[0] < 0 + + async def rearrange_experts(self, request_dict: dict): + """ + rearrange experts + Args: + request_dict (dict): request body + Returns: + tuple: response body, status code + """ + content, status_code = None, HTTPStatus.OK + eplb_config = self.config.eplb_config + + if not eplb_config.enable_eplb: + content = {"code": 1, "msg": "redundant expert is disabled"} + status_code = HTTPStatus.BAD_REQUEST + return content, status_code + + if ( + request_dict.get("user", "") != eplb_config.redundant_expert_api_user + or request_dict.get("passwd", "") != eplb_config.redundant_expert_api_password + ): + content = {"code": 1, "msg": "user or passwd is invalid"} + status_code = HTTPStatus.UNAUTHORIZED + return content, status_code + + if self.config.parallel_config.expert_parallel_rank != 0: + content = { + "code": 1, + "msg": f"actual rank {self.config.parallel_config.expert_parallel_rank}, expect rank 0", + } + status_code = HTTPStatus.BAD_REQUEST + return content, status_code + + action = request_dict.get("action", "") + api_server_logger.info(f"redundant_expert: rearrange_experts recv request, action {action}") + if action == "": + # action: start rearrange experts + # params: {'user': 'xxx', 'passwd': 'xxx', 'ips': ['10.54.99.77:8000', '10.54.99.77:8300']} + if self.rearrange_experts_signal.value[0] != RearrangeExpertStatus.FREE.value: + content = { + "code": 1, + "msg": f"rearrange is doing. 
actual status {self.rearrange_experts_signal.value[0]}, expect status {RearrangeExpertStatus.FREE.value}", + } + status_code = HTTPStatus.BAD_REQUEST + if "ips" not in request_dict and content is None: + content = {"code": 1, "msg": "ips in request is None"} + status_code = HTTPStatus.BAD_REQUEST + + if content is not None: + return content, status_code + + data_bytes = (";".join(request_dict["ips"])).encode("utf-8") + data_size = len(data_bytes) + if data_size > eplb_config.redundant_expert_ip_shm_size: + content = { + "code": 1, + "msg": f"actual ips size {data_size}, max limit {eplb_config.redundant_expert_ip_shm_size}", + } + status_code = HTTPStatus.INTERNAL_SERVER_ERROR + else: + self.rearrange_experts_ips_size_signal.value[0] = data_size + self.shm_rearrange_experts_ips_list.shm.buf[:data_size] = data_bytes + content = {"code": 0, "msg": "ok"} + status_code = HTTPStatus.OK + return content, status_code + elif action == "recv_expert_weight": + # action: receive global expert workload, and begin update weight from disk + # params: {'user': 'xxx', 'passwd': 'xxx', 'weight': (layers, experts)} + if "data" not in request_dict or not isinstance(request_dict["data"], list): + content = {"code": 1, "msg": "data not in request or data is not a list"} + status_code = HTTPStatus.BAD_REQUEST + + elif len(request_dict["data"]) != len(self.expert_tokens_stats_array_list): + content = { + "code": 1, + "msg": f"actual data length {len(request_dict['data'])}, expect length {len(self.expert_tokens_stats_array_list)}", + } + status_code = HTTPStatus.BAD_REQUEST + else: + weight = np.array(request_dict["data"], dtype=np.int32) + for idx in range(len(self.expert_tokens_stats_array_list)): + self.expert_tokens_stats_array_list[idx].value[:] = weight[:] + self.signal_update_weight_from_disk_array_list[idx].value[0] = 1 + + content = {"code": 0, "msg": "ok"} + status_code = HTTPStatus.OK + return content, status_code + elif action == "update_weight_from_tensor": + if self.cfg.scheduler_config.splitwise_role != "prefill" and content is None: + content = { + "code": 1, + "msg": f"actual role {self.cfg.scheduler_config.splitwise_role}, expect role prefill", + } + status_code = HTTPStatus.BAD_REQUEST + if self.rearrange_experts_signal.value[0] != RearrangeExpertStatus.LOAD_SUCC.value and content is None: + content = { + "code": 1, + "msg": f"actual status {self.rearrange_experts_signal.value[0]}, expect status {RearrangeExpertStatus.LOAD_SUCC.value}", + } + status_code = HTTPStatus.BAD_REQUEST + + if content is None: + self.signal_update_weight_from_tensor_array.value[0] = 1 + content = {"code": 0, "msg": "ok"} + status_code = HTTPStatus.OK + return content, status_code + else: + content = {"code": 1, "msg": f"invalid action {action}"} + status_code = HTTPStatus.BAD_REQUEST + return content, status_code + + async def get_per_expert_tokens_stats(self, request_dict: dict): + """ + get per expert tokens stats + + Args: + request_dict (dict): request body + Returns: + tuple: response body, status code + """ + content, status_code = None, HTTPStatus.OK + eplb_config = self.config.eplb_config + + if not eplb_config.enable_eplb: + content = {"code": 1, "msg": "redundant expert is disabled"} + status_code = HTTPStatus.BAD_REQUEST + return content, status_code + + if ( + request_dict.get("user", "") != eplb_config.redundant_expert_api_user + or request_dict.get("passwd", "") != eplb_config.redundant_expert_api_password + ): + content = {"code": 1, "msg": "user or passwd is invalid"} + status_code = 
HTTPStatus.UNAUTHORIZED + return content, status_code + + if self.config.parallel_config.expert_parallel_rank != 0: + content = { + "code": 1, + "msg": f"actual rank {self.config.parallel_config.expert_parallel_rank}, expect rank 0", + } + status_code = HTTPStatus.BAD_REQUEST + return content, status_code + + if "clear_stat" in request_dict and request_dict["clear_stat"]: + for clear_experts_token_stats in self.signal_clear_experts_token_stats_list: + clear_experts_token_stats.value[0] = 1 + + local_experts_list = [] + for local_experts_token_stats in self.local_experts_token_stats_array_list: + local_experts_list.append(local_experts_token_stats.value.tolist()) + content = {"code": 0, "msg": "ok", "data": local_experts_list} + status_code = HTTPStatus.OK + return content, status_code + + async def check_redundant(self, request_dict: dict): + """ + check redundant + Args: + request_dict (dict): request body + Returns: + tuple: response body, status code + """ + content, status_code = None, HTTPStatus.OK + eplb_config = self.config.eplb_config + + if not eplb_config.enable_eplb: + content = {"code": 1, "msg": "redundant expert is disabled"} + status_code = HTTPStatus.BAD_REQUEST + return content, status_code + + if ( + request_dict.get("user", "") != eplb_config.redundant_expert_api_user + or request_dict.get("passwd", "") != eplb_config.redundant_expert_api_password + ): + content = {"code": 1, "msg": "user or passwd is invalid"} + status_code = HTTPStatus.UNAUTHORIZED + return content, status_code + + if self.config.parallel_config.expert_parallel_rank != 0: + content = { + "code": 1, + "msg": f"actual rank {self.config.parallel_config.expert_parallel_rank}, expect rank 0", + } + status_code = HTTPStatus.BAD_REQUEST + return content, status_code + + action = request_dict.get("action", "") + if action == "": + status = "unknown" + try: + status = RearrangeExpertStatus(self.rearrange_experts_signal.value[0]).name + except: + pass + content = {"code": 0, "msg": "ok", "status": status} + get_workloads = False if "check_get_workloads" not in request_dict else request_dict["check_get_workloads"] + if get_workloads: + content["data"], content["msg"] = RedundantExpertWorkload(eplb_config.redundant_expert_meta_dir).load() + status_code = HTTPStatus.OK + elif action == "check_load_weight_result": + update_weight_from_disk_list = [] + for update_weight_result in self.update_weight_from_disk_result_list: + update_weight_from_disk_list.append(update_weight_result.value[0].tolist()) + content = {"code": 0, "msg": "ok", "data": update_weight_from_disk_list} + status_code = HTTPStatus.OK + return content, status_code diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 1d6dc65af..9c688aafb 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -155,6 +155,8 @@ async def lifespan(app: FastAPI): verification = False model_paths = [ModelPath(name=served_model_names, model_path=args.model, verification=verification)] + engine_args = EngineArgs.from_cli_args(args) + config = engine_args.create_engine_config(port_availability_check=False) engine_client = EngineClient( model_name_or_path=args.model, tokenizer=args.tokenizer, @@ -171,6 +173,7 @@ async def lifespan(app: FastAPI): workers=args.workers, tool_parser=args.tool_call_parser, enable_prefix_caching=args.enable_prefix_caching, + config=config, ) await engine_client.connection_manager.initialize() app.state.dynamic_load_weight = 
args.dynamic_load_weight @@ -408,6 +411,36 @@ def clear_load_weight(request: Request) -> Response: return Response(content="Dynamic Load Weight Disabled.", status_code=404) +@app.post("/rearrange_experts") +async def rearrange_experts(request: Request): + """ + rearrange experts + """ + request_dict = await request.json() + content, status_code = await app.state.engine_client.rearrange_experts(request_dict=request_dict) + return JSONResponse(content, status_code=status_code) + + +@app.post("/get_per_expert_tokens_stats") +async def get_per_expert_tokens_stats(request: Request): + """ + get per expert tokens stats + """ + request_dict = await request.json() + content, status_code = await app.state.engine_client.get_per_expert_tokens_stats(request_dict=request_dict) + return JSONResponse(content, status_code=status_code) + + +@app.post("/check_redundant") +async def check_redundant(request: Request): + """ + check redundant + """ + request_dict = await request.json() + content, status_code = await app.state.engine_client.check_redundant(request_dict=request_dict) + return JSONResponse(content, status_code=status_code) + + def launch_api_server() -> None: """ 启动http服务 diff --git a/fastdeploy/eplb/__init__.py b/fastdeploy/eplb/__init__.py new file mode 100644 index 000000000..31be300c1 --- /dev/null +++ b/fastdeploy/eplb/__init__.py @@ -0,0 +1,15 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" diff --git a/fastdeploy/eplb/async_expert_loader.py b/fastdeploy/eplb/async_expert_loader.py new file mode 100644 index 000000000..14ca99901 --- /dev/null +++ b/fastdeploy/eplb/async_expert_loader.py @@ -0,0 +1,427 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import ctypes +import os +import time +import traceback +from typing import List, Tuple + +import numpy as np +import paddle + +from fastdeploy.config import EPLBConfig + +REARRANGE_EXPERT_MAGIC_NUM = 147183647 +REARRANGE_ORIGINATOR_EP_RANK = 0 +CHECK_TIME_INTERNAL = 3 +HTTP_RETRY_NUM = 5 +CHECK_TIMEOUT = 120 + +libc = ctypes.CDLL(None) + +libc.mmap.argtypes = [ + ctypes.c_void_p, # void *addr + ctypes.c_size_t, # size_t length + ctypes.c_int, # int prot + ctypes.c_int, # int flags + ctypes.c_int, # int fd + ctypes.c_size_t, # off_t offset +] +libc.mmap.restype = ctypes.c_void_p +libc.munmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t] +libc.munmap.restype = ctypes.c_int + +PROT_READ = 0x1 +PROT_WRITE = 0x2 +MAP_SHARED = 0x01 +MAP_ANONYMOUS = 0x20 +MAP_FAILED = -1 + +G = 1024**3 +TOTAL_MODEL_SIZE = 350 +MAIN_MODEL_REDUNDANT_SHM_SIZE = 5 + +MODEL_MAIN_NAME = "eplb_main" + + +def create_mmap(model_name: List, ep_rank: int, ep_size: int, shm_uuid: str, eplb_config: EPLBConfig, logger=None): + """create_mmap""" + flags = MAP_SHARED + prot = PROT_READ | PROT_WRITE + + main_size = 0 + if eplb_config.redundant_expert_async_load_model_shmem_size_gb == 0: + main_size = TOTAL_MODEL_SIZE // ep_size + else: + main_size = eplb_config.redundant_expert_async_load_model_shmem_size_gb + main_size = main_size * G + + mmap_infos = {} + + from cuda import cudart + + for name in model_name: + expert_weight_file = f"/dev/shm/{name}_rank_{ep_rank}_expert_weight_{shm_uuid}" + shm_size = main_size + + if not os.path.isfile(expert_weight_file): + open(expert_weight_file, "wb").close() + shm_fd = os.open(expert_weight_file, os.O_RDWR) + os.ftruncate(shm_fd, shm_size) + if logger is not None: + logger.info(f"redundant_expert: create_mmap file {expert_weight_file}, fd {shm_fd}, size {shm_size}") + + shm_ptr = libc.mmap(0, ctypes.c_size_t(shm_size), prot, flags, shm_fd, 0) + if shm_ptr == MAP_FAILED: + raise OSError(f"redundant_expert: mmap {expert_weight_file} failed: {ctypes.get_errno()}") + + shm_ptr = ctypes.cast(shm_ptr, ctypes.POINTER(ctypes.c_int8)) + addr = ctypes.addressof(shm_ptr.contents) + + # Register memory with CUDA + (ret,) = cudart.cudaHostRegister(addr, shm_size, 0) + if ret != cudart.cudaError_t.cudaSuccess: + raise RuntimeError( + f"cudaHostRegister failed: {cudart.cudaGetErrorString(ret)}, " + f" address {hex(addr)} size {shm_size}, ret: {ret}" + ) + + mmap_infos[name] = shm_ptr + + return mmap_infos + + +def save_tensor_to_shm_mem(cached_weights, file_path, logger=None): + """save_tensor_to_shm_mem""" + tensor_infos = [] + offset = 0 + if not os.path.exists(file_path): + raise OSError("File is not exist") + + shm_size = os.path.getsize(file_path) + + for name, w in cached_weights: + size = w.numel().item() * w.element_size() + # logger.info(f"redundant_expert: save tensor to {name} offset: {offset} size: {size}") + w_ptr = ctypes.string_at(w.data_ptr(), size) + with open(file_path, "r+b") as file: + file.seek(offset) + if offset + size > shm_size: + raise IOError( + f"redundant_expert: Exceeded {file_path} file's size. " + + "Should set a bigger value using env variable." + ) + n = file.write(w_ptr) + assert n == size + tensor_infos.append((name, offset, size, w.shape, w.dtype)) + + offset += size + + sz = offset / 1024 / 1024 / 1024 + if logger is not None: + logger.info(f"redundant_expert: save_tensor_to_shm_mem success. 
file {file_path} size {sz}G") + + return tensor_infos + + +def load_tensor_from_shm_mem(tensor_infos, shm_ptr, logger=None): + """load_tensor_from_shm_mem""" + # weights_dict = {} + weights_dict = [] + for name, offset, size, shape, dtype in tensor_infos: + # 计算共享内存中张量的地址 + w_addr = ctypes.cast(shm_ptr, ctypes.c_void_p).value + offset + w_ptr = ctypes.cast(w_addr, ctypes.POINTER(ctypes.c_byte)) + # 先读取为字节数组,再通过视图转换成适当类型 + np_array = np.ctypeslib.as_array(w_ptr, shape=(size,)) + + if dtype == paddle.float32: + tmp = np_array.view(np.float32) + tensor = paddle.Tensor(tmp, dtype=paddle.float32, place=paddle.CPUPlace(), zero_copy=True) + elif dtype == paddle.uint8: + tmp = np_array.view(np.uint8) + tensor = paddle.Tensor(tmp, dtype=paddle.uint8, place=paddle.CPUPlace(), zero_copy=True) + elif dtype == paddle.int8: + tmp = np_array.view(np.int8) + tensor = paddle.Tensor(tmp, dtype=paddle.int8, place=paddle.CPUPlace(), zero_copy=True) + elif dtype == paddle.bfloat16: + # NumPy 不支持 bfloat16,因此先以 uint16 读取原始数据,再用 Paddle cast 为 bfloat16 + tmp = np_array.view(np.uint16) + tensor = paddle.Tensor(tmp, dtype=paddle.bfloat16, place=paddle.CPUPlace(), zero_copy=True) + else: + raise TypeError(f"Unsupported dtype: {dtype}") + + assert w_addr == tensor.data_ptr() + # weights_dict[name] = tensor.view(shape) + weights_dict.append((name, tensor.view(shape))) + + if logger is not None: + logger.info("redundant_expert: load_tensor_from_shm_mem succ") + return weights_dict + + +class AsyncEPLoader(object): + """Aynsc Expert loader""" + + def __init__( + self, + model_dir, + eplb_config, + rank=8, + expert_per_rank=8, + moe_layer_start_index=3, + moe_quant_type="", + logger=None, + ): + """ + __init__ + """ + self.model_path = model_dir + self.eplb_config = eplb_config + + self.expert_per_rank = expert_per_rank + self.moe_layer_start_index = moe_layer_start_index + self.ep_rank = rank + self.moe_quant_type = moe_quant_type + + self.old_model_ep_rank_to_expert_id_list = None + self.new_model_ep_rank_to_expert_id_list = None + + self.cached_weights = [] + # self.state_dicts = {} + self.moe_file_names = [] + + self.logger = logger + + def reset(self): + """ + reset + """ + self.old_model_ep_rank_to_expert_id_list = None + self.new_model_ep_rank_to_expert_id_list = None + self.cached_weights = [] + self.moe_file_names = [] + + def load_experts_weight_from_disk(self): + """ + return value: (all_succ whether_load_weight exist_fatal_error message), + exist_fatal_error means all rank need restart + """ + ep_rank = self.ep_rank + start_idx = ep_rank * self.expert_per_rank + end_idx = start_idx + self.expert_per_rank + try: + old_expert_ids_all = self.old_model_ep_rank_to_expert_id_list[:, start_idx:end_idx] + new_expert_ids_all = self.new_model_ep_rank_to_expert_id_list[:, start_idx:end_idx] + need_to_reload = list() + for layer_id in range(len(old_expert_ids_all)): + if layer_id < self.moe_layer_start_index: + continue + new_expert_ids = new_expert_ids_all[layer_id] + old_expert_ids = old_expert_ids_all[layer_id] + if len(new_expert_ids) != len(old_expert_ids): + message = f"redundant_expert: new_expert_ids length not equal to old_expert_ids \ + length layer_id: {layer_id}" + # this is very dangerous and unepxpected, should be fixed + return False, message + # TODO: 按需加载,过滤重复专家 + self.logger.info( + f"redundant_expert: rank {ep_rank} layer {layer_id} old_experts {old_expert_ids}" + + f" new_experts {new_expert_ids}" + ) + need_to_reload.extend([(layer_id, expert_id) for expert_id in new_expert_ids]) + + succ = True + 
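+            # need_to_reload now lists, for each MoE layer, the (layer_id, expert_id) pairs this
+            # rank hosts under the new placement; the weights are fetched from disk below, through
+            # the safetensors loader when model_use_safetensors is set, or per-expert bf16 files otherwise.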
message = "" + if len(need_to_reload) > 0: + if self.eplb_config.model_use_safetensors: + succ, message = self.load_safetensor_fp8_from_disk(need_to_reload) + else: + succ, message = self.load_weight_bf16_from_disk(need_to_reload) + if not succ: + self.logger.info( + f"redundant_expert: load_experts_weight_from_disk fail. rank {ep_rank}, error: {message}" + ) + new_message = f"redundant_expert: load_experts_weight_from_disk fail. rank {ep_rank}, error: {message}" + return False, new_message + self.logger.info(f"redundant_expert: load_experts_weight_from_disk success. rank {ep_rank}") + return True, "redundant_expert: load_experts_weight_from_disk success" + except Exception as e: + message = f"redundant_expert: Failed to load_experts_weight_from_disk ep_rank {ep_rank} excep: {e}" + error_message = traceback.format_exc() + self.logger.error(f"redundant_expert: message: {message} traceback: {error_message}") + return False, message + + def load_weight_bf16_from_disk(self, need_to_reload: List[Tuple[int, int]]): + """load_weight_bf16_from_disk""" + try: + ckpt_up_gate_proj_name = "up_gate_proj" + ckpt_down_proj_name = "down_proj" + for layer_id, expert_id in need_to_reload: + for weight_name in [ckpt_up_gate_proj_name, ckpt_down_proj_name]: + ckpt_file_name = f"ernie.layers.{layer_id}.mlp.experts.{expert_id}.{weight_name}.weight" + if ckpt_file_name not in self.moe_file_names: + self.logger.info(f"record redundant_expert: {ckpt_file_name}") + self.moe_file_names.append(ckpt_file_name) + + last_device = paddle.device.get_device() + paddle.set_device("cpu") + + for file_name in self.moe_file_names: + # 判断文件是否存在 + if not os.path.exists(self.model_path + "/merged_tp1_state_split/" + file_name): + # self.logger.info(f"redundant_expert: {file_name} not exist.") + continue + # self.logger.info(f"redundant_expert: Loading expert weights: {file_name}.") + self.state_dicts[file_name] = paddle.load(self.model_path + "/merged_tp1_state_split/" + file_name) + + paddle.set_device(last_device) + self.logger.info("redundant_expert: Loading expert weights end.") + return True, "redundant_expert: Succeeded to loading expert weights." + except Exception as e: + message = f"redundant_expert: Failed to get weights iterator: {e}." 
+ return False, message + + def load_safetensor_fp8_from_disk(self, need_to_reload: List[Tuple[int, int]]): + """load_safetensor_fp8_from_disk""" + """ + ernie.layers.52.mlp.experts.58.up_gate_proj.quant_weight + ernie.layers.52.mlp.experts.58.up_gate_proj.weight_scale + ernie.layers.52.mlp.experts.58.down_proj.quant_weight + ernie.layers.52.mlp.experts.58.down_proj.weight_scale + """ + up_gate_down = ["up_gate_proj", "down_proj"] + quant_weight_scale = ["quant_weight", "weight_scale"] + if self.moe_quant_type == "w4a8": + quant_weight_scale = ["quant_weight"] + ckpt_name = [ + (f"ernie.layers.{layer_id}.mlp.experts.{expert_id}.{proj_name}.{quant_name}") + for layer_id, expert_id in need_to_reload + for proj_name in up_gate_down + for quant_name in quant_weight_scale + ] + ckpt_name_to_safetensor_file = load_ep_checkpoint(self.model_path) + hf_weights_files = list(set(ckpt_name_to_safetensor_file.values())) + state_dicts = {} + + last_device = paddle.device.get_device() + paddle.set_device("cpu") + + from safetensors import safe_open + + for st_file in hf_weights_files: + with safe_open(st_file, framework="np", device="cpu") as f: + for name in f.keys(): + if name in ckpt_name: + weight = f.get_tensor(name) + state_dicts[name] = paddle.Tensor(weight, zero_copy=True) + weights_list = [] + for name in ckpt_name: + weights_list.append((name, state_dicts[name])) + self.cached_weights = weights_list + + paddle.set_device(last_device) + return True, "load_expert_weight_from_disk_safetensor success" + + +def load_ep_checkpoint(model_path): + """ + load ep checkpoint + """ + file_path = os.path.join(model_path, "model.safetensors.index.json") + if not os.path.exists(file_path): + return {} + import json + + with open(file_path, "r") as f: + weight_map = json.load(f)["weight_map"] + state_dict = {k: os.path.join(model_path, v) for k, v in weight_map.items()} + return state_dict + + +def load_model_weights_process( + rank: int, + model_dir: str, + expert_per_rank: int, + moe_layer_start_index: int, + moe_quant_type: str, + shm_uuid: str, + eplb_config: EPLBConfig, + data_conn, + mg_conn, +): + """ + load_model_weights_process + """ + import faulthandler + + from setproctitle import setproctitle + + setproctitle(f"eplb::async_load_model_{rank}") + faulthandler.enable() + from fastdeploy.utils import get_logger + + logger = get_logger("eplb_async_loader", "eplb_{0}.log".format(rank)) + logger.info("redundant_expert: load_model_weights_process start") + + paddle.set_device("cpu") + ep_loader = AsyncEPLoader( + model_dir=model_dir, + rank=rank, + expert_per_rank=expert_per_rank, + moe_layer_start_index=moe_layer_start_index, + moe_quant_type=moe_quant_type, + logger=logger, + eplb_config=eplb_config, + ) + + while True: + ep_loader.reset() + data = mg_conn.recv() + + result = True + weight_infos = [] + try: + ep_loader.old_model_ep_rank_to_expert_id_list = data["old_model_ep_rank_to_expert_id_list"] + ep_loader.new_model_ep_rank_to_expert_id_list = data["new_model_ep_rank_to_expert_id_list"] + + begin_time_disk = int(time.time()) + success, message = ep_loader.load_experts_weight_from_disk() + begin_time_shm = int(time.time()) + logger.info( + "redundant_expert: async load load_weight_from_disk, " + + f"succ {success}, cost {begin_time_shm-begin_time_disk}s" + ) + if success: + model_name = MODEL_MAIN_NAME + file_path = f"/dev/shm/{model_name}_rank_{rank}_expert_weight_{shm_uuid}" + weight_infos = save_tensor_to_shm_mem(ep_loader.cached_weights, file_path, logger) + logger.info( + "redundant_expert: 
async load save_tensor_to_shm_mem, " + + f"tensor nums {len(weight_infos)}, cost {int(time.time()-begin_time_shm)}s" + ) + else: + logger.error(f"redundant_expert: async load load_weight_from_disk failed, error {message}") + result = False + + except Exception as e: + logger.error(f"redundant_expert: async load weights failed, rank {rank} error {e}") + result = False + weight_infos = [] + finally: + request_data = {"result": result, "weights": weight_infos} + data_conn.send(request_data) diff --git a/fastdeploy/eplb/eplb.py b/fastdeploy/eplb/eplb.py new file mode 100644 index 000000000..827b878d8 --- /dev/null +++ b/fastdeploy/eplb/eplb.py @@ -0,0 +1,291 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Tuple + +import numpy as np + + +def balanced_packing(weight: np.ndarray, num_packs: int) -> Tuple[np.ndarray, np.ndarray]: + """ + Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs + are as balanced as possible. + Parameters: + weight: [X, n], the weight of each item + num_packs: number of packs + Returns: + pack_index: [X, n], the pack index of each item + rank_in_pack: [X, n], the rank of the item in the pack + """ + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + + if groups_per_pack == 1: + pack_index = np.arange(weight.shape[-1], dtype=np.int32).reshape(1, -1).repeat(num_layers, axis=0) + rank_in_pack = np.zeros_like(weight, dtype=np.int32) + return pack_index, rank_in_pack + + indices = np.argsort(-weight.astype(np.float32), axis=-1) + pack_index = np.full_like(weight, fill_value=-1, dtype=np.int32) + rank_in_pack = np.full_like(pack_index, fill_value=-1) + for i in range(num_layers): + pack_weights = [0] * num_packs + pack_items = [0] * num_packs + for group in indices[i]: + pack = min( + (i for i in range(num_packs) if pack_items[i] < groups_per_pack), + key=pack_weights.__getitem__, + ) + assert pack_items[pack] < groups_per_pack + pack_index[i, group] = pack + rank_in_pack[i, group] = pack_items[pack] + pack_weights[pack] += weight[i, group] + pack_items[pack] += 1 + return pack_index, rank_in_pack + + +def replicate_experts(weight: np.ndarray, num_phy: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized. 
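+    Replication is greedy: each extra physical slot is assigned to the logical expert with the
+    highest per-replica load (weight / current replica count). For example, weight=[[9, 3]] and
+    num_phy=3 yields phy2log=[[0, 1, 0]] and logcnt=[[2, 1]], since expert 0 carries the most load.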
+ Parameters: + weight: [X, num_log] + num_phy: total number of experts after replication + Returns: + phy2log: [X, num_phy], logical expert id of each physical expert + rank: [X, num_phy], the replica rank + logcnt: [X, num_log], number of replicas for each logical expert + """ + n, num_log = weight.shape + num_redundant = num_phy - num_log + assert num_redundant >= 0 + phy2log = np.arange(num_phy, dtype=np.int32).reshape(1, -1).repeat(n, axis=0) + rank = np.zeros((n, num_phy), dtype=np.int32) + logcnt = np.ones((n, num_log), dtype=np.int32) + arangen = np.arange(n, dtype=np.int32) + for i in range(num_log, num_phy): + redundant_indices = np.argmax(weight / logcnt, axis=-1) + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arangen, redundant_indices] + logcnt[arangen, redundant_indices] += 1 + return phy2log, rank, logcnt + + +def rebalance_experts_intra_node( + weight: np.ndarray, + num_physical_experts: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +): + """ + Parameters: + weight: [num_moe_layers, num_logical_experts] + num_physical_experts: number of physical experts after replication + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + Returns: + physical_to_logical_map: [num_moe_layers, num_physical_experts] + logical_to_physical_map: [num_moe_layers, num_logical_experts, X] + logical_count: [num_moe_layers, num_logical_experts] + """ + + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + num_redundant_experts = num_physical_experts - num_logical_experts + assert num_redundant_experts >= 0 + + assert num_gpus % num_nodes == 0 + num_gpus_per_node = num_gpus // num_nodes + + assert num_physical_experts % num_gpus == 0 + num_physical_experts_per_gpu = num_physical_experts // num_gpus + assert num_physical_experts % num_nodes == 0 + num_physical_experts_per_node = num_physical_experts // num_nodes + + assert num_logical_experts % num_physical_experts_per_node == 0 + # num_logical_nodes = num_logical_experts // num_physical_experts_per_node + assert num_redundant_experts % num_physical_experts_per_node == 0 + # num_redundant_nodes = num_redundant_experts // num_physical_experts_per_node + + def inverse(perm: np.ndarray) -> np.ndarray: + inv = np.empty_like(perm) + inv[np.arange(perm.shape[0])[:, None], perm] = np.arange(perm.shape[1], dtype=np.int32).reshape(1, -1) + return inv + + # Step 1: generate redundant experts by weight. 
+ # shape of tmp2log, tmprank is [num_layers, num_physical_experts] + # shape of logcnt is [num_layers, num_logical_experts] + tmp2log, tmprank, logcnt = replicate_experts(weight, num_physical_experts) + + # Step 2: compute num_tokens of physical experts + # shape of tokens_per_tmp is [num_layers * num_nodes, num_physical_experts_per_node] + tokens_per_tmp = np.take_along_axis(weight / logcnt, tmp2log, axis=-1).reshape(-1, num_physical_experts_per_node) + + # STEP 3: take load balance of gpu cards in node + # shape of gpu_index, rank_in_gpu, tmp2phy, phy2tmp is [num_layers * num_nodes, num_physical_experts_per_node] + gpu_index, rank_in_gpu = balanced_packing(tokens_per_tmp, num_gpus_per_node) + tmp2phy = gpu_index * num_physical_experts_per_gpu + rank_in_gpu + phy2tmp = inverse(tmp2phy) + + # STEP 4: generate final phy2log mapping + tmp2log = tmp2log.reshape(-1, num_physical_experts_per_node) + tmprank = tmprank.reshape(-1, num_physical_experts_per_node) + phy2log = np.take_along_axis(tmp2log, phy2tmp, axis=-1).reshape(-1, num_physical_experts) + phyrank = np.take_along_axis(tmprank, phy2tmp, axis=-1).reshape(-1, num_physical_experts) + return phy2log, phyrank, logcnt + + +def rebalance_experts_hierarchical( + weight: np.ndarray, + num_physical_experts: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +): + """ + Parameters: + weight: [num_moe_layers, num_logical_experts] + num_physical_experts: number of physical experts after replication + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + Returns: + physical_to_logical_map: [num_moe_layers, num_physical_experts] + logical_to_physical_map: [num_moe_layers, num_logical_experts, X] + logical_count: [num_moe_layers, num_logical_experts] + """ + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + def inverse(perm: np.ndarray) -> np.ndarray: + inv = np.empty_like(perm) + inv[np.arange(perm.shape[0])[:, None], perm] = np.arange(perm.shape[1], dtype=np.int32).reshape(1, -1) + return inv + + # Step 1: pack groups to nodes + tokens_per_group = weight.reshape(num_layers, num_groups, group_size).sum(axis=-1) + group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes) + log2mlog = ( + ((group_pack_index * groups_per_node + group_rank_in_pack) * group_size)[:, :, None] + + np.arange(group_size, dtype=np.int32) + ).reshape(num_layers, -1) + mlog2log = inverse(log2mlog) + + # Step 2: construct redundant experts within nodes + tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=-1).reshape(-1, num_logical_experts // num_nodes) + phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes) + + # Step 3: pack physical_experts to GPUs + tokens_per_phy = np.take_along_axis(tokens_per_mlog / mlogcnt, phy2mlog, axis=-1) + pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes) + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = inverse(phy2pphy) + + pphy2mlog = np.take_along_axis(phy2mlog, pphy2phy, axis=-1) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = ( + 
pphy2mlog.reshape(num_layers, num_nodes, -1) + + np.arange(0, num_logical_experts, num_logical_experts // num_nodes, dtype=np.int32).reshape(1, -1, 1) + ).reshape(num_layers, -1) + pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=-1) + pphyrank = np.take_along_axis(phyrank, pphy2phy, axis=-1).reshape(num_layers, -1) + logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=-1) + return pphy2log, pphyrank, logcnt + + +def rebalance_experts( + weight: np.ndarray, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_gpus: int, + eplb_strategy: str = "", +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Entry point for expert-parallelism load balancer. + Parameters: + weight: [layers, num_logical_experts], the load statistics for all logical experts + num_replicas: number of physical experts, must be a multiple of `num_gpus` + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + Returns: + physical_to_logical_map: [layers, num_replicas], the expert index of each replica + logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert + expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert + """ + num_layers, num_logical_experts = weight.shape + weight = weight.astype(np.float32) + if eplb_strategy == "balance_intra_node": + phy2log, phyrank, logcnt = rebalance_experts_intra_node(weight, num_replicas, num_groups, num_nodes, num_gpus) + else: + if num_groups % num_nodes == 0: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = replicate_experts(weight, num_replicas) + maxlogcnt = logcnt.max() + log2phy = np.full((num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int32) + np.put_along_axis( + log2phy.reshape(num_layers, -1)[:, :, None], + (phy2log * maxlogcnt + phyrank)[:, :, None], + np.arange(num_replicas, dtype=np.int32).reshape(1, -1).repeat(num_layers, axis=0)[:, :, None], + axis=1, + ) + return phy2log, log2phy, logcnt + + +__all__ = ["rebalance_experts"] + + +def main(): + """ + main + """ + num_hidden_layers = 3 + num_expert = 64 + num_groups = 8 + + num_replicas = 64 + num_nodes = 4 + num_gpus = 4 * 8 + + model_tokens_per_expert_stats_list = np.random.randint(low=1, high=10, size=(num_hidden_layers, num_expert)) + + phy2log, phyrank, logcnt = rebalance_experts( + model_tokens_per_expert_stats_list, + num_replicas, + num_groups, + num_nodes, + num_gpus, + ) + print(phy2log) + print(phyrank) + print(logcnt) + + +if __name__ == "__main__": + main() diff --git a/fastdeploy/eplb/experts_manager.py b/fastdeploy/eplb/experts_manager.py new file mode 100644 index 000000000..c8a2ea197 --- /dev/null +++ b/fastdeploy/eplb/experts_manager.py @@ -0,0 +1,503 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import threading +import time +from http import HTTPStatus +from multiprocessing import Pipe, Process + +import numpy as np +import requests + +from fastdeploy.config import FDConfig +from fastdeploy.eplb.async_expert_loader import load_model_weights_process +from fastdeploy.eplb.eplb import rebalance_experts +from fastdeploy.eplb.utils import RedundantExpertWorkload +from fastdeploy.inter_communicator import IPCSignal, RearrangeExpertStatus +from fastdeploy.utils import get_logger + + +class RedundantExpertManager: + """ + RedundantExpertManger + """ + + def __init__( + self, + rank: int = 0, + ep_size: int = 32, + fd_config: FDConfig = None, + ipc_signal_suffix: int = 0, + ): + self.logger = get_logger("eplb_expert_manager", "eplb_{0}.log".format(rank)) + + self.rank = rank + self.ep_size = ep_size + self.fd_config = fd_config + self.eplb_config = fd_config.eplb_config + self.api_user = self.eplb_config.redundant_expert_api_user + self.api_passwd = self.eplb_config.redundant_expert_api_password + self.num_redundant_experts = self.eplb_config.redundant_experts_num + self.num_hidden_layers = self.fd_config.model_config.num_hidden_layers + self.num_logical_experts = self.fd_config.model_config.moe_num_experts + self.ipc_signal_suffix = ipc_signal_suffix + + self.num_replicas = self.num_logical_experts + self.num_redundant_experts + self.num_groups = self.num_logical_experts + self.num_nodes = max(ep_size // 8, 1) + self.num_gpus = ep_size + self.expert_per_rank = self.num_replicas // ep_size + assert ( + self.num_replicas % ep_size == 0 + ), f"num_replicas must be divisible by ep_size, \ + but got num_replicas = {self.num_replicas}, ep_size = {ep_size}" + + self.model_ep_rank_to_expert_id_list = np.full( + ( + self.num_hidden_layers, + self.num_logical_experts + self.num_redundant_experts, + ), + -1, + dtype=np.int32, + ) + self.model_expert_id_to_ep_rank_array = np.full( + ( + self.num_hidden_layers, + self.num_logical_experts, + self.num_redundant_experts + 1, + ), + -1, + dtype=np.int32, + ) + self.model_expert_in_rank_num_list = np.zeros( + (self.num_hidden_layers, self.num_logical_experts), dtype=np.int32 + ) + + # backup info + self.last_model_ep_rank_to_expert_id_list = np.full( + ( + self.num_hidden_layers, + self.num_logical_experts + self.num_redundant_experts, + ), + -1, + dtype=np.int32, + ) + self.last_model_expert_id_to_ep_rank_array = np.full( + ( + self.num_hidden_layers, + self.num_logical_experts, + self.num_redundant_experts + 1, + ), + -1, + dtype=np.int32, + ) + self.last_model_expert_in_rank_num_list = np.zeros( + (self.num_hidden_layers, self.num_logical_experts), dtype=np.int32 + ) + + self.model_tokens_per_expert_stats_list = np.ones( + (self.num_hidden_layers, self.num_logical_experts), dtype=np.int32 + ) + self.caculate_expert_rank_table(True) + + self.dp_rank_address = None + self.need_allgather_load_weight_result = False + self.load_weight_begin_ts = 0 + self.load_weight_timeout = 300 # 5min + self.need_rearrange_expert = False + self.need_update_expert_tokens_stat = True + self.http_timeout = 1 + # 重置重排状态: 'done' -> 'free' + 
self.rearrange_end_ts = 0 + self.rearrange_reset_interval = 300 + + self.tensor_infos = None + + self.parent_data_conn, child_data_conn = Pipe() + self.parent_mg_conn, child_mg_conn = Pipe() + Process( + target=load_model_weights_process, + name=f"eplb::async_load_model_{rank}", + args=( + self.rank, + self.fd_config.model_config.model, + self.expert_per_rank, + self.fd_config.model_config.moe_layer_start_index, + self.eplb_config.moe_quant_type, + self.ipc_signal_suffix, + self.eplb_config, + child_data_conn, + child_mg_conn, + ), + ).start() + child_data_conn.close() + child_mg_conn.close() + + listen_signal_thread = threading.Thread(target=self.listen_rearrange_expert_signal, args=(), daemon=True) + listen_signal_thread.start() + + self.logger.info( + f"redundant_expert: RedundantExpertManager init success, rank {rank}, \ + strategy {self.eplb_config.redundant_expert_eplb_strategy}" + ) + + # def get_unique_name(self, name): + # return f"{envs.get_unique_name(name + '_dprank_' + str(self.rank))}" + + def get_ep_rank_to_expert_id_list(self): + """ + get_ep_rank_to_expert_id_list + """ + return ( + self.model_ep_rank_to_expert_id_list, + self.model_expert_id_to_ep_rank_array, + self.model_expert_in_rank_num_list, + ) + + def listen_rearrange_expert_signal(self): + """ + listen_rearrange_expert_signal + """ + if self.rank == 0: + rearrange_experts_ips_size_array = np.zeros([1], dtype=np.int32) + rearrange_experts_ips_size_signal = IPCSignal( + name="rearrange_experts_ips_size", + array=rearrange_experts_ips_size_array, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=False, + ) + + shm_rearrange_experts_ips_list = IPCSignal( + name="rearrange_experts_ips_list", + shm_size=self.eplb_config.redundant_expert_ip_shm_size, + suffix=self.ipc_signal_suffix, + create=False, + ) + + rearrange_experts_status = np.zeros([1], dtype=np.int32) + rearrange_experts_signal = IPCSignal( + name="rearrange_experts_status", + array=rearrange_experts_status, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=False, + ) + + signal_update_weight_from_disk = np.zeros([1], dtype=np.int32) + signal_update_weight_from_disk_array = IPCSignal( + name="signal_update_weight_from_disk", + array=signal_update_weight_from_disk, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=False, + ) + + experts_token_stats = np.zeros( + (self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.moe_num_experts), + dtype=np.int32, + ) + shm_all_experts_token_stats = IPCSignal( + name="all_experts_token_stats", + array=experts_token_stats, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=False, + ) + + while True: + if self.rank == 0: + now = int(time.time()) + if rearrange_experts_ips_size_signal.value[0] > 0: + # step 1. 
all reduce experts token stats + address = bytes( + shm_rearrange_experts_ips_list.shm.buf[: rearrange_experts_ips_size_signal.value[0]] + ).decode("utf-8") + self.logger.info(f"redundant_expert: all rank ips {address}") + rearrange_experts_ips_size_signal.value[0] = 0 + rearrange_experts_signal.value[0] = RearrangeExpertStatus.DOING.value + + self.dp_rank_address = address.strip().split(";") + if self.allreduce_experts_stat(): + self.need_allgather_load_weight_result = True + self.load_weight_begin_ts = now + self.logger.info("redundant_expert: all-reduce experts stats success") + else: + rearrange_experts_signal.value[0] = RearrangeExpertStatus.FREE.value + self.logger.warning("redundant_expert: all-reduce experts stats fail") + elif self.need_allgather_load_weight_result and self.allreduce_load_weight_result(): + # step 3. all reduce the result of load weight from disk + self.need_allgather_load_weight_result = False + rearrange_experts_signal.value[0] = RearrangeExpertStatus.LOAD_SUCC.value + self.rearrange_end_ts = now + if rearrange_experts_signal.value[0] > 1 and ( + now - self.rearrange_end_ts > self.rearrange_reset_interval + ): + # reset rearrange status + rearrange_experts_signal.value[0] = RearrangeExpertStatus.FREE.value + + if signal_update_weight_from_disk_array.value[0] == 1: + # step 2. async load weight: disk -> memory + self.model_tokens_per_expert_stats_list[:] = shm_all_experts_token_stats.value[:] + self.caculate_expert_rank_table() + self.update_weight_from_disk() + signal_update_weight_from_disk_array.value[0] = 0 + time.sleep(0.5) + + def caculate_expert_rank_table(self, is_init=False): + """ + caculate_expert_rank_table + """ + num_groups = self.num_groups + num_nodes = self.num_nodes + num_gpus = self.num_gpus + eplb_strategy = self.eplb_config.redundant_expert_eplb_strategy + if is_init: + num_groups = 1 + num_nodes = 2 + num_gpus = 2 * 8 + eplb_strategy = "" + # eplb + rank_expert_list, logical_to_physical_map, expert_count = rebalance_experts( + self.model_tokens_per_expert_stats_list, + self.num_replicas, + num_groups, + num_nodes, + num_gpus, + eplb_strategy, + ) + + # backup info + self.last_model_ep_rank_to_expert_id_list[:] = self.model_ep_rank_to_expert_id_list[:] + self.last_model_expert_id_to_ep_rank_array[:] = self.model_expert_id_to_ep_rank_array[:] + self.last_model_expert_in_rank_num_list[:] = self.model_expert_in_rank_num_list[:] + + # update model info + self.model_ep_rank_to_expert_id_list[:] = rank_expert_list[:] + self.model_expert_id_to_ep_rank_array.fill(-1) + self.model_expert_id_to_ep_rank_array[..., : logical_to_physical_map.shape[-1]] = logical_to_physical_map[:] + self.model_expert_in_rank_num_list[:] = expert_count[:] + + if self.rank == 0: + workload = RedundantExpertWorkload() + workload.tokens_per_expert_stats_list = self.model_tokens_per_expert_stats_list.tolist() + workload.ep_rank_to_expert_id_list = rank_expert_list.tolist() + workload.expert_id_to_ep_rank_array = logical_to_physical_map.tolist() + workload.expert_in_rank_num_list = expert_count.tolist() + self.logger.info(workload.dump()) + + def update_weight_from_disk(self): + """ + update_weight_from_disk + """ + begin_time = time.time() + result_update_weight_from_disk = np.zeros([1], dtype=np.int32) + update_weight_from_disk_result = IPCSignal( + name="result_update_weight_from_disk", + array=result_update_weight_from_disk, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=False, + ) + update_weight_from_disk_result.value[0] = 0 + + 
self.logger.info(f"redundant_expert: update_weight_from_disk send to async process, rank {self.rank}")
+        self.parent_mg_conn.send(
+            {
+                "old_model_ep_rank_to_expert_id_list": self.last_model_ep_rank_to_expert_id_list,
+                "new_model_ep_rank_to_expert_id_list": self.model_ep_rank_to_expert_id_list,
+            }
+        )
+        self.logger.info(f"redundant_expert: update_weight_from_disk recv from async process, rank {self.rank}")
+        response = self.parent_data_conn.recv()
+        self.tensor_infos = response["weights"]
+
+        # Record the result of the asynchronous disk-to-memory weight load
+        update_weight_from_disk_result.value[0] = 1 if response["result"] else -1
+        self.logger.info(
+            "redundant_expert: update_weight_from_disk end, rank"
+            + f" {self.rank} {response['result']}, cost {int(time.time() - begin_time)}s"
+        )
+
+    def allreduce_experts_stat(self):
+        """
+        Aggregate the per-expert token workload across all DP ranks and broadcast the merged stats back.
+        """
+        if not self.allgather_expert_token_stats():
+            return False
+        return self.broadcast_expert_token_stats()
+
+    def allgather_expert_token_stats(self):
+        """
+        allgather_expert_token_stats
+        """
+        expert_token_stats = np.zeros((self.num_hidden_layers, self.num_logical_experts), dtype=np.int32)
+        success_count = 0
+        for addr in self.dp_rank_address:
+            try:
+                # TODO: retry the request on failure
+                params = {"user": self.api_user, "passwd": self.api_passwd}
+                res = requests.post(
+                    f"http://{addr}/get_per_expert_tokens_stats",
+                    json=params,
+                    timeout=self.http_timeout,
+                )
+                if res.status_code != HTTPStatus.OK:
+                    self.logger.warning(
+                        "redundant_expert: allgather_expert_token_stats fail. "
+                        + f"addr {addr}, res {res.status_code} {res.json()}"
+                    )
+                    break
+
+                for meta_data in res.json()["data"]:
+                    expert_token_stats += np.array(meta_data, dtype=np.int32)
+                success_count += 1
+            except Exception as e:
+                self.logger.error(f"redundant_expert: allgather_expert_token_stats fail. addr {addr}, error {e}")
+        if success_count == len(self.dp_rank_address):
+            self.need_rearrange_expert = True
+            self.model_tokens_per_expert_stats_list[:] = expert_token_stats[:]
+            self.logger.info("redundant_expert: allgather_expert_token_stats success")
+            return True
+        self.logger.info(
+            "redundant_expert: allgather_expert_token_stats fail. "
+            + f"succ {success_count} total {len(self.dp_rank_address)}"
+        )
+        return False
+
+    def broadcast_expert_token_stats(self):
+        """
+        broadcast_expert_token_stats
+        """
+        success_count = 0
+        for addr in self.dp_rank_address:
+            try:
+                params = {
+                    "user": self.api_user,
+                    "passwd": self.api_passwd,
+                    "action": "recv_expert_weight",
+                    "data": self.model_tokens_per_expert_stats_list.tolist(),
+                }
+                res = requests.post(
+                    f"http://{addr}/rearrange_experts",
+                    json=params,
+                    timeout=self.http_timeout,
+                )
+                if res.status_code != HTTPStatus.OK:
+                    self.logger.warning(
+                        "redundant_expert: broadcast_expert_token_stats fail. "
+                        + f"addr {addr}, res {res.status_code} {res.json()}"
+                    )
+                    break
+                success_count += 1
+            except Exception as e:
+                self.logger.error(
+                    f"redundant_expert: broadcast_expert_token_stats request fail. 
addr {addr}, error {e}" + ) + if success_count == len(self.dp_rank_address): + self.logger.info("redundant_expert: broadcast_expert_token_stats success") + return True + self.logger.info( + "redundant_expert: broadcast_expert_token_stats failed, " + + f"succ {success_count} total {len(self.dp_rank_address)}" + ) + return False + + def allreduce_load_weight_result(self): + """ + 权重加载结果 + """ + if int(time.time()) - self.load_weight_begin_ts > self.load_weight_timeout: + self.logger.info(f"redundant_expert: allreduce_load_weight_result timeout {self.load_weight_timeout}s") + return True + + all_success, exist_fail = self.allgather_load_weight_result() + if exist_fail: + # 如果有DP权重加载异常,结束本次重排 + self.logger.warning("redundant_expert: allreduce_load_weight_result exist fail, terminate this rearrange") + return True + if not all_success: + self.logger.info("redundant_expert: allreduce_load_weight_result waiting") + return False + # self.broadcast_load_weight_success() + if not exist_fail and all_success: + # prefill需要等待调度屏蔽 + if ( + self.fd_config.splitwise_role == "decode" + or not self.eplb_config.redundant_expert_enable_schedule_cordon + ): + self.logger.info("redundant_expert: allreduce_load_weight_result success, notify infer.py") + signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32) + signal_update_weight_from_tensor_array = IPCSignal( + name="signal_update_weight_from_tensor", + array=signal_update_weight_from_tensor, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=False, + ) + signal_update_weight_from_tensor_array.value[0] = 1 + return True + + def allgather_load_weight_result(self): + """ + allgather_load_weight_result + """ + all_success, exist_fail = False, False + + success_count, fail_count = 0, 0 + for addr in self.dp_rank_address: + try: + params = { + "user": self.api_user, + "passwd": self.api_passwd, + "action": "check_load_weight_result", + } + res = requests.post( + f"http://{addr}/check_redundant", + json=params, + timeout=self.http_timeout, + ) + if res.status_code != HTTPStatus.OK: + self.logger.warning( + "redundant_expert: allgather_load_weight_result fail. " + + f"addr {addr}, res {res.status_code} {res.json()}" + ) + break + result_list = res.json()["data"] + self.logger.info( + f"redundant_expert: allgather_load_weight_result success. addr {addr}, result_list {result_list}" + ) + for result in result_list: + if result == 1: + success_count += 1 + elif result == -1: + fail_count += 1 + self.logger.error( + f"redundant_expert: allgather_load_weight_result fail. addr {addr}, result {result}" + ) + exist_fail = True + except Exception as e: + self.logger.error(f"redundant_expert: allgather_load_weight_result error. addr {addr}, error {e}") + + if fail_count > 0: + self.logger.info( + "redundant_expert: allgather_load_weight_result not all ready, " + + f"succ {success_count} fail {fail_count} total {len(self.dp_rank_address)}" + ) + else: + self.logger.info("redundant_expert: allgather_load_weight_result all success") + all_success = True + return all_success, exist_fail diff --git a/fastdeploy/eplb/utils.py b/fastdeploy/eplb/utils.py new file mode 100644 index 000000000..a4691b6fd --- /dev/null +++ b/fastdeploy/eplb/utils.py @@ -0,0 +1,160 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import json +import os +import time + +import numpy as np + +from fastdeploy.config import FDConfig +from fastdeploy.inter_communicator import IPCSignal + + +class RedundantExpertWorkload: + """Redundant Expert Workload""" + + def __init__(self, redundant_expert_meta_dir="/tmp/redundant_expert_meta"): + self.update_timestamp = time.time() + self.tokens_per_expert_stats_list = None + self.ep_rank_to_expert_id_list = None + self.expert_id_to_ep_rank_array = None + self.expert_in_rank_num_list = None + self.cost_milliseconds = 0 + self.meta_file_name = f"{redundant_expert_meta_dir}/rearrange-experts.json" + if not os.path.exists(redundant_expert_meta_dir): + os.makedirs(redundant_expert_meta_dir, exist_ok=True) + + def __json__(self): + return self.__dict__ + + def dump(self): + """Dump the object to a JSON file.""" + begin = time.time() + try: + with open(self.meta_file_name, "w") as fout: + json.dump(self.__dict__, fout) + except Exception as e: + return f"redundant_expert: dump expert workload failed, {e}" + cost_time = int((time.time() - begin) * 1000 * 1000) + return f"redundant_expert: dump expert workload result in {cost_time} us" + + def load(self): + """Load the object from a JSON file.""" + if not os.path.exists(self.meta_file_name): + return {}, f"redundant_expert: file {self.meta_file_name} is not exists" + try: + with open(self.meta_file_name, "r") as fin: + meta = json.load(fin) + self.__dict__.update(meta) + return self.__json__(), "ok" + except Exception as e: + return {}, f"redundant_expert: load file {self.meta_file_name} failed, {e}" + + +def init_eplb_signals(config: FDConfig, ipc_signal_suffix): + """ + Initialize shared memory to indicate eplb status + """ + if config.parallel_config.local_data_parallel_id == 0: + # rearrange_experts_status Record the expert's rearrangement status + rearrange_experts_array = np.zeros([1], dtype=np.int32) + _ = IPCSignal( + name="rearrange_experts_status", + array=rearrange_experts_array, + dtype=np.int32, + suffix=ipc_signal_suffix, + create=True, + ) + + # Record all DP rank IPs when receiving expert rearrangement requests + rearrange_experts_ips_size_array = np.zeros([1], dtype=np.int32) + _ = IPCSignal( + name="rearrange_experts_ips_size", + array=rearrange_experts_ips_size_array, + dtype=np.int32, + suffix=ipc_signal_suffix, + create=True, + ) + _ = IPCSignal( + name="rearrange_experts_ips_list", + shm_size=config.eplb_config.redundant_expert_ip_shm_size, + suffix=ipc_signal_suffix, + create=True, + ) + + # Receive signals for updating weights + signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32) + _ = IPCSignal( + name="signal_update_weight_from_tensor", + array=signal_update_weight_from_tensor, + dtype=np.int32, + suffix=ipc_signal_suffix, + create=True, + ) + + # Record expert workload + experts_token_stats = np.zeros( + (config.model_config.num_hidden_layers, config.model_config.moe_num_experts), + dtype=np.int32, + ) + _ = IPCSignal( + name="all_experts_token_stats", + array=experts_token_stats, + dtype=np.int32, + suffix=ipc_signal_suffix, + create=True, + ) + _ = IPCSignal( + 
name="local_experts_token_stats", + array=experts_token_stats, + dtype=np.int32, + suffix=ipc_signal_suffix, + create=True, + ) + + # Receive signals for loading weights + signal_update_weight_from_disk = np.zeros([1], dtype=np.int32) + _ = IPCSignal( + name="signal_update_weight_from_disk", + array=signal_update_weight_from_disk, + dtype=np.int32, + suffix=ipc_signal_suffix, + create=True, + ) + + # Receive signals for clearing expert loads + clear_experts_token_stats = np.zeros([1], dtype=np.int32) + _ = IPCSignal( + name="signal_clear_experts_token_stats", + array=clear_experts_token_stats, + dtype=np.int32, + suffix=ipc_signal_suffix, + create=True, + ) + + result_update_weight_from_disk = np.zeros([1], dtype=np.int32) + _ = IPCSignal( + name="result_update_weight_from_disk", + array=result_update_weight_from_disk, + dtype=np.int32, + suffix=ipc_signal_suffix, + create=True, + ) + + +if __name__ == "__main__": + print(RedundantExpertWorkload("/tmp").load()) diff --git a/fastdeploy/inter_communicator/__init__.py b/fastdeploy/inter_communicator/__init__.py index 373702edb..5e767562f 100644 --- a/fastdeploy/inter_communicator/__init__.py +++ b/fastdeploy/inter_communicator/__init__.py @@ -17,6 +17,7 @@ from .engine_cache_queue import EngineCacheQueue from .engine_worker_queue import EngineWorkerQueue from .ipc_signal import IPCSignal, shared_memory_exists +from .ipc_signal_const import RearrangeExpertStatus from .zmq_client import ZmqIpcClient from .zmq_server import ZmqIpcServer, ZmqTcpServer @@ -28,4 +29,5 @@ __all__ = [ "ZmqTcpServer", "ZmqIpcServer", "shared_memory_exists", + "RearrangeExpertStatus", ] diff --git a/fastdeploy/inter_communicator/ipc_signal.py b/fastdeploy/inter_communicator/ipc_signal.py index 075f1a461..d40d1cce9 100644 --- a/fastdeploy/inter_communicator/ipc_signal.py +++ b/fastdeploy/inter_communicator/ipc_signal.py @@ -55,10 +55,11 @@ class IPCSignal: def __init__( self, name: str, - array: np.ndarray, - dtype: np.dtype, + array: np.ndarray = None, + dtype: np.dtype = None, suffix: int = None, create: bool = True, + shm_size: int = None, ) -> None: """Initialize or connect to a shared memory block. @@ -72,23 +73,36 @@ class IPCSignal: Raises: AssertionError: If create=True but memory already exists, or dtype mismatch. 
""" - assert isinstance(array, np.ndarray), "Input must be a numpy array" - assert dtype == array.dtype, "Specified dtype must match array dtype" + if dtype is None or array is None: + assert shm_size is not None, "shm_size must be specified if array and dtype are None" - # Set a suffix for name to avoid name conflict while there are multiple engine launched - if suffix is not None: - name = name + f".{suffix}" - - if create: - if shared_memory_exists(name): - llm_logger.warning(f"ShareMemory: {name} already exists, delete it") - SharedMemory(name=name, create=False).unlink() - self.shm = SharedMemory(create=True, size=array.nbytes, name=name) - self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf) - self.value[:] = array # Initialize with input array data + if create: + llm_logger.debug(f"creating ipc signal: {name}") + if shared_memory_exists(name): + llm_logger.warning(f"ShareMemory: {name} already exists, delete it") + SharedMemory(name=name, create=False).unlink() + self.shm = SharedMemory(create=True, size=shm_size, name=name) + else: + llm_logger.debug(f"attaching ipc signal: {name}") + self.shm = SharedMemory(name=name) else: - self.shm = SharedMemory(name=name) - self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf) + assert isinstance(array, np.ndarray), "Input must be a numpy array" + assert dtype == array.dtype, "Specified dtype must match array dtype" + + # Set a suffix for name to avoid name conflict while there are multiple engine launched + if suffix is not None: + name = name + f".{suffix}" + + if create: + if shared_memory_exists(name): + llm_logger.warning(f"ShareMemory: {name} already exists, delete it") + SharedMemory(name=name, create=False).unlink() + self.shm = SharedMemory(create=True, size=array.nbytes, name=name) + self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf) + self.value[:] = array # Initialize with input array data + else: + self.shm = SharedMemory(name=name) + self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf) def clear(self) -> None: """Release system resources and unlink the shared memory block.""" diff --git a/fastdeploy/inter_communicator/ipc_signal_const.py b/fastdeploy/inter_communicator/ipc_signal_const.py new file mode 100644 index 000000000..134a32496 --- /dev/null +++ b/fastdeploy/inter_communicator/ipc_signal_const.py @@ -0,0 +1,26 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from dataclasses import dataclass +from enum import Enum + + +@dataclass +class RearrangeExpertStatus(Enum): + FREE = 0 + DOING = 1 + LOAD_SUCC = 2 # load weight from disk success + DONE = 3 diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index ca860bb30..96758472d 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -388,7 +388,7 @@ class Ernie4_5_Model(nn.Layer): fd_config.model_config.pretrained_config.prefix_name = "ernie" self.fd_config = fd_config self.redundant_table_manger = None - if fd_config.model_config.enable_redundant_experts is True: + if fd_config.eplb_config.enable_eplb is True: self.redundant_table_manger = RedundantExpertManger( n_routed_experts=fd_config.model_config.moe_num_experts, num_hidden_layers=fd_config.model_config.num_hidden_layers, diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index 967c12ca3..940d37821 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -66,6 +66,7 @@ class RolloutModelConfig: num_nextn_predict_layers: int = 0, enable_attention_dp_balance: bool = False, attention_dp_time_out_iters: int = 0, + eplb_config: str = {}, ): # Required parameters self.model = model_name_or_path @@ -115,6 +116,7 @@ class RolloutModelConfig: self.num_nextn_predict_layers = num_nextn_predict_layers self.enable_attention_dp_balance = enable_attention_dp_balance self.attention_dp_time_out_iters = attention_dp_time_out_iters + self.eplb_config = eplb_config def __str__(self): return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items()) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index e83bd9350..fb312db24 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -30,6 +30,7 @@ from fastdeploy.config import ( DecodingConfig, DeviceConfig, EarlyStopConfig, + EPLBConfig, ErnieArchitectures, FDConfig, GraphOptimizationConfig, @@ -40,9 +41,16 @@ from fastdeploy.config import ( SpeculativeConfig, ) from fastdeploy.engine.request import RequestType +from fastdeploy.eplb.async_expert_loader import ( + MODEL_MAIN_NAME, + REARRANGE_EXPERT_MAGIC_NUM, + create_mmap, + load_tensor_from_shm_mem, +) +from fastdeploy.eplb.experts_manager import RedundantExpertManager from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue -from fastdeploy.inter_communicator import IPCSignal +from fastdeploy.inter_communicator import IPCSignal, RearrangeExpertStatus from fastdeploy.model_executor.layers.quantization import get_quantization_config from fastdeploy.platforms import current_platform from fastdeploy.utils import get_logger, parse_quantization @@ -151,6 +159,7 @@ class PaddleDisWorkerProc: self.fd_config = fd_config self.parallel_config = fd_config.parallel_config self.cache_config = fd_config.cache_config + self.eplb_config = fd_config.eplb_config # TODO(gongshaotian): Use worker factory to get worker self.worker = get_worker(fd_config=fd_config, local_rank=self.local_rank, rank=self.ranks) @@ -249,6 +258,18 @@ class PaddleDisWorkerProc: create=False, ) + def update_weights_from_tensor(self, mmap_infos): + """ + update_weights_from_tensor + """ + state_dicts = load_tensor_from_shm_mem(self.experts_manager.tensor_infos, mmap_infos[MODEL_MAIN_NAME], logger) + rank_expert_list, logical_to_physical_map, expert_count = 
self.experts_manager.get_ep_rank_to_expert_id_list()
+        self.worker.get_model().redundant_table_manger.update_expert_rank_table(
+            rank_expert_list, logical_to_physical_map, expert_count
+        )
+        # TO BE FIXED
+        self.worker.get_model().update_state_dict(state_dicts)
+
     def _broadcast_model_weights_signal(self, src: int, group) -> int:
         model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
         paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
@@ -258,6 +279,63 @@ class PaddleDisWorkerProc:
         """Main event loop for Paddle Distrubuted Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.
        """
+        if self.eplb_config.enable_eplb:
+            self.last_dump_expert_workload_ts = 0
+            self.experts_manager = RedundantExpertManager(
+                rank=self.local_rank,
+                ep_size=self.ranks,
+                fd_config=self.fd_config,
+                ipc_signal_suffix=self.parallel_config.engine_worker_queue_port,
+            )
+            experts_token_stats = np.zeros(
+                (self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.moe_num_experts),
+                dtype=np.int32,
+            )
+            local_experts_token_stats_array = IPCSignal(
+                name="local_experts_token_stats",
+                array=experts_token_stats,
+                dtype=np.int32,
+                suffix=self.parallel_config.engine_worker_queue_port,
+                create=False,
+            )
+
+            clear_experts_token_stats = np.zeros([1], dtype=np.int32)
+            signal_clear_experts_token_stats = IPCSignal(
+                name="signal_clear_experts_token_stats",
+                array=clear_experts_token_stats,
+                dtype=np.int32,
+                suffix=self.parallel_config.engine_worker_queue_port,
+                create=False,
+            )
+
+            if self.local_rank == 0:
+                signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32)
+                signal_update_weight_from_tensor_array = IPCSignal(
+                    name="signal_update_weight_from_tensor",
+                    array=signal_update_weight_from_tensor,
+                    dtype=np.int32,
+                    suffix=self.parallel_config.engine_worker_queue_port,
+                    create=False,
+                )
+
+                rearrange_experts_status = np.zeros([1], dtype=np.int32)
+                rearrange_experts_signal = IPCSignal(
+                    name="rearrange_experts_status",
+                    array=rearrange_experts_status,
+                    dtype=np.int32,
+                    suffix=self.parallel_config.engine_worker_queue_port,
+                    create=False,
+                )
+
+            mmap_infos = create_mmap(
+                [MODEL_MAIN_NAME],
+                self.local_rank,
+                self.ranks,
+                shm_uuid=self.parallel_config.engine_worker_queue_port,
+                eplb_config=self.eplb_config,
+                logger=logger,
+            )
+
         # Currently, only support single node
         self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
         req_ids = []
@@ -267,6 +345,45 @@ class PaddleDisWorkerProc:
             attention_dp_cached_prefill_tasks = []
             attention_dp_wait_prefill_iters = 0
         while True:
+            if self.eplb_config.enable_eplb:
+                rearrange_time = time.time()
+                # Periodically dump the local per-expert token workload into shared memory
+                if local_experts_token_stats_array.value is not None and (
+                    int(rearrange_time) - self.last_dump_expert_workload_ts
+                    > self.eplb_config.redundant_expert_dump_workload_interval
+                ):
+                    self.last_dump_expert_workload_ts = int(rearrange_time)
+                    clear_stat = False
+                    if signal_clear_experts_token_stats.value[0] == 1:
+                        clear_stat = True
+                        signal_clear_experts_token_stats.value[0] = 0
+                    (
+                        new_stats_array,
+                        _,
+                        _,
+                        _,
+                    ) = self.worker.get_model().redundant_table_manger.get_expert_tokens_stats(clear_stat=clear_stat)
+                    local_experts_token_stats_array.value[:] = new_stats_array[:]
+                elif local_experts_token_stats_array.value is None:
+                    logger.warning("redundant_expert: local_experts_token_stats not init")
+
+                # All DP ranks switch to the rearranged expert weights in lockstep
+                broadcast_value = 0
+                if self.local_rank == 0 and 
signal_update_weight_from_tensor_array.value[0] == 1: + logger.info("redundant_expert: update_weight_from_tensor broadcast signal") + signal_update_weight_from_tensor_array.value[0] = 0 + broadcast_value = REARRANGE_EXPERT_MAGIC_NUM + data = paddle.to_tensor([broadcast_value]) + paddle.distributed.broadcast(data, 0) + if data[0] == REARRANGE_EXPERT_MAGIC_NUM: + self.update_weights_from_tensor(mmap_infos) + logger.info( + f"redundant_expert: update_weight_from_tensor success, cost {(time.time() - rearrange_time)*1000}ms" + ) + paddle.distributed.barrier() + if self.local_rank == 0: + rearrange_experts_signal.value[0] = RearrangeExpertStatus.DONE.value + logger.info("redundant_expert: done") if local_rank == 0: if self.model_weights_status.value[0] != 0: self.model_weights_signal[0] = int(self.model_weights_status.value[0]) @@ -706,6 +823,13 @@ def parse_args(): help="max waiting steps to sync all dp for prefill tasks available", ) + parser.add_argument( + "--eplb_config", + type=json.loads, + default=None, + help="EPLB Configuration.", + ) + args = parser.parse_args() return args @@ -764,6 +888,8 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: early_stop_config = EarlyStopConfig(args.early_stop_config) + eplb_config = EPLBConfig(args.eplb_config) + # Note(tangbinhan): used for load_checkpoint model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size @@ -861,6 +987,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: moba_attention_config=moba_attention_config, enable_attention_dp_balance=args.enable_attention_dp_balance, attention_dp_time_out_iters=args.attention_dp_time_out_iters, + eplb_config=eplb_config, ) update_fd_config_for_mm(fd_config) diff --git a/requirements.txt b/requirements.txt index fd0f295fa..d4e9a71f8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -40,3 +40,5 @@ opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi partial_json_parser einops +cuda-python==12.8 +setproctitle
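
For reference, the worker-side --eplb_config flag added above is parsed with json.loads, so the EPLB options travel as a single JSON string. A minimal illustrative payload, using only keys referenced elsewhere in this patch (the values are placeholders, not recommended defaults):

    import json

    # Hypothetical payload for --eplb_config; adjust values to the deployment.
    eplb_payload = json.dumps(
        {
            "enable_eplb": True,
            "redundant_experts_num": 8,
            "redundant_expert_dump_workload_interval": 10,
            "redundant_expert_enable_schedule_cordon": True,
            "moe_quant_type": "w4a8",
        }
    )

IPCSignal now supports two modes: the original typed-array mode and a raw-buffer mode selected by passing shm_size without array/dtype. A minimal sketch of both, assuming the signal names used by this patch and an illustrative suffix value:

    import numpy as np

    from fastdeploy.inter_communicator import IPCSignal

    suffix = 8002  # e.g. the engine_worker_queue_port used as ipc_signal_suffix

    # Typed-array mode: a fixed-shape numpy view over shared memory.
    status = IPCSignal(
        name="rearrange_experts_status",
        array=np.zeros([1], dtype=np.int32),
        dtype=np.int32,
        suffix=suffix,
        create=True,
    )
    status.value[0] = 1  # RearrangeExpertStatus.DOING

    # Raw-buffer mode: a size-only block read and written through .shm.buf.
    ips = IPCSignal(name="rearrange_experts_ips_list", shm_size=1024, suffix=suffix, create=True)
    payload = "10.0.0.1:8188;10.0.0.2:8188".encode("utf-8")  # illustrative addresses
    ips.shm.buf[: len(payload)] = payload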