[Feature] Support eplb for ep (#4786)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

* support eplb for ep

* update code

* update code

* update code

* update code

* update code

* update code

* update code

* update code

* update code
This commit is contained in:
kevin
2025-11-07 15:42:29 +08:00
committed by GitHub
parent bbae094cb9
commit 3dbe5596e6
18 changed files with 2048 additions and 25 deletions

View File

@@ -124,7 +124,6 @@ class ModelConfig:
self.max_model_len = 0
self.dtype = ""
self.enable_logprob = False
self.enable_redundant_experts = False
self.redundant_experts_num = 0
self.seed = 0
self.quantization = None
@@ -247,6 +246,60 @@ class ModelConfig:
logger.info("=============================================================")
class EPLBConfig:
"""
Configuration for EPLB manager.
"""
def __init__(
self,
args,
):
# enable eplb
self.enable_eplb: bool = False
# redundant experts num
self.redundant_experts_num: int = 0
# expert ip shm size
self.redundant_expert_ip_shm_size: int = 1024
# expert meta dir
self.redundant_expert_meta_dir: str = "/tmp/redundant_expert_meta"
# expert api user and password
self.redundant_expert_api_user: str = ""
self.redundant_expert_api_password: str = ""
# expert eplb strategy
self.redundant_expert_eplb_strategy: str = ""
# expert dump workload interval
self.redundant_expert_dump_workload_interval: int = 10
# expert async load model shmem size gb
self.redundant_expert_async_load_model_shmem_size_gb: int = 0
# expert enable schedule cordon
self.redundant_expert_enable_schedule_cordon: bool = True
# model use safetensors
self.model_use_safetensors: bool = True
# model use offline quant
self.model_use_offline_quant: bool = True
# moe quant type
self.moe_quant_type: str = "w4a8"
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
def to_json_string(self):
"""
Convert eplb_config to json string.
"""
return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
def print(self):
"""
Print all configuration information.
"""
logger.info("EPLB Configuration Information :")
for k, v in self.__dict__.items():
logger.info("{:<20}:{:<6}{}".format(k, "", v))
logger.info("=============================================================")
class ParallelConfig:
"""Configuration for the distributed execution."""
@@ -1141,6 +1194,7 @@ class FDConfig:
reasoning_parser: str = None,
guided_decoding_backend: Optional[str] = None,
disable_any_whitespace: bool = False,
eplb_config: EPLBConfig = None,
early_stop_config: Optional[Dict[str, Any]] = None,
tool_parser: str = None,
test_mode=False,
@@ -1159,6 +1213,7 @@ class FDConfig:
self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
self.decoding_config: DecodingConfig = decoding_config # type: ignore
self.cache_config: CacheConfig = cache_config # type: ignore
self.eplb_config: Optional[EPLBConfig] = eplb_config
self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
self.enable_attention_dp_balance = enable_attention_dp_balance
self.attention_dp_time_out_iters = attention_dp_time_out_iters
@@ -1386,6 +1441,7 @@ class FDConfig:
or k == "scheduler_config"
or k == "parallel_config"
or k == "commit_config"
or k == "eplb_config"
):
if v is not None:
v.print()

View File

@@ -26,6 +26,7 @@ from fastdeploy import envs
from fastdeploy.config import (
CacheConfig,
EarlyStopConfig,
EPLBConfig,
FDConfig,
GraphOptimizationConfig,
LoadConfig,
@@ -397,6 +398,15 @@ class EngineArgs:
Max waiting steps to sync all dp for prefill tasks available
"""
enable_eplb: bool = False
"""
Flag to enable eplb
"""
eplb_config: Optional[Dict[str, Any]] = None
"""
Configuration for eplb.
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -693,6 +703,18 @@ class EngineArgs:
default=EngineArgs.enable_expert_parallel,
help="Enable expert parallelism.",
)
parallel_group.add_argument(
"--enable-eplb",
action="store_true",
default=EngineArgs.enable_eplb,
help="Enable eplb.",
)
model_group.add_argument(
"--eplb-config",
type=json.loads,
default=EngineArgs.eplb_config,
help="Config of eplb.",
)
# Load group
load_group = parser.add_argument_group("Load Configuration")
@@ -1022,7 +1044,17 @@ class EngineArgs:
early_stop_args[k] = v
return EarlyStopConfig(early_stop_args)
def create_engine_config(self) -> FDConfig:
def create_eplb_config(self) -> EPLBConfig:
"""
Create and retuan an EPLBConfig object based on the current settings.
"""
eplb_args = asdict(self)
if self.eplb_config is not None:
for k, v in self.eplb_config.items():
eplb_args[k] = v
return EPLBConfig(eplb_args)
def create_engine_config(self, port_availability_check: bool = True) -> FDConfig:
"""
Create and return a Config object based on the current settings.
"""
@@ -1063,6 +1095,7 @@ class EngineArgs:
graph_opt_cfg = self.create_graph_optimization_config()
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
moba_attention_config = self.create_moba_attention_config()
eplb_cfg = self.create_eplb_config()
early_stop_cfg = self.create_early_stop_config()
early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
@@ -1072,9 +1105,10 @@ class EngineArgs:
if isinstance(self.engine_worker_queue_port, str):
self.engine_worker_queue_port = self.engine_worker_queue_port.split(",")
assert is_port_available(
"0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id])
), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
if port_availability_check:
assert is_port_available(
"0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id])
), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
return FDConfig(
model_config=model_cfg,
@@ -1084,6 +1118,7 @@ class EngineArgs:
load_config=load_cfg,
parallel_config=parallel_cfg,
max_model_len=self.max_model_len,
eplb_config=eplb_cfg,
max_num_seqs=self.max_num_seqs,
speculative_config=speculative_cfg,
max_num_batched_tokens=self.max_num_batched_tokens,

View File

@@ -33,6 +33,7 @@ from opentelemetry import trace
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
from fastdeploy.eplb.utils import init_eplb_signals
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (
EngineCacheQueue,
@@ -132,6 +133,12 @@ class EngineSevice:
)
self._init_worker_monitor_signals()
if self.cfg.eplb_config.enable_eplb:
current_suffix = int(
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
)
init_eplb_signals(cfg, current_suffix)
self._finalizer = weakref.finalize(self, self._exit_sub_services)
def start(self):

View File

@@ -461,6 +461,7 @@ class LLMEngine:
f" --load_choices {self.cfg.load_config.load_choices}"
f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
f" --attention_dp_time_out_iters {self.cfg.attention_dp_time_out_iters}"
f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
f" --ips {ips}"
)

View File

@@ -19,6 +19,7 @@ import os
import time
import traceback
import uuid
from http import HTTPStatus
import numpy as np
@@ -26,8 +27,9 @@ from fastdeploy import envs
from fastdeploy.config import ModelConfig
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
from fastdeploy.eplb.utils import RedundantExpertWorkload
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
from fastdeploy.inter_communicator import IPCSignal, RearrangeExpertStatus, ZmqIpcClient
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform
@@ -49,6 +51,7 @@ class EngineClient:
port,
limit_mm_per_prompt,
mm_processor_kwargs,
config,
# enable_mm=False,
reasoning_parser=None,
data_parallel_size=1,
@@ -59,6 +62,7 @@ class EngineClient:
):
import fastdeploy.model_executor.models # noqa: F401
self.config = config
architectures = ModelConfig({"model": model_name_or_path}).architectures[0]
self.enable_prefix_caching = enable_prefix_caching
if MultimodalRegistry.contains_model(architectures):
@@ -92,6 +96,9 @@ class EngineClient:
else:
self.is_master = False
if self.config.eplb_config.enable_eplb and self.config.parallel_config.expert_parallel_rank == 0:
self.init_eplb_signals(ipc_signal_suffix=port)
array_size = min(max_chips_per_node, tensor_parallel_size)
self.worker_healthy_live_recorded_time_array = np.zeros(shape=[array_size], dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal(
@@ -115,6 +122,115 @@ class EngineClient:
)
self.connection_initialized = False
def init_eplb_signals(self, ipc_signal_suffix):
"""
Initialize eplb signals.
"""
self.signal_clear_experts_token_stats_list = []
self.local_experts_token_stats_array_list = []
self.expert_tokens_stats_array_list = []
self.signal_update_weight_from_disk_array_list = []
self.update_weight_from_disk_result_list = []
rearrange_experts_status = np.zeros([1], dtype=np.int32)
self.rearrange_experts_signal = IPCSignal(
name="rearrange_experts_status",
array=rearrange_experts_status,
dtype=np.int32,
suffix=ipc_signal_suffix,
create=False,
)
rearrange_experts_ips_size_array = np.zeros([1], dtype=np.int32)
self.rearrange_experts_ips_size_signal = IPCSignal(
name="rearrange_experts_ips_size",
array=rearrange_experts_ips_size_array,
dtype=np.int32,
suffix=ipc_signal_suffix,
create=False,
)
self.shm_rearrange_experts_ips_list = IPCSignal(
name="rearrange_experts_ips_list",
shm_size=self.config.eplb_config.redundant_expert_ip_shm_size,
suffix=ipc_signal_suffix,
create=False,
)
signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32)
self.signal_update_weight_from_tensor_array = IPCSignal(
name="signal_update_weight_from_tensor",
array=signal_update_weight_from_tensor,
dtype=np.int32,
suffix=ipc_signal_suffix,
create=False,
)
if envs.FD_ENABLE_MULTI_API_SERVER:
engine_worker_suffix = [
self.config.parallel_config.engine_worker_queue_port[
self.config.parallel_config.local_data_parallel_id
]
]
else:
engine_worker_suffix = self.config.parallel_config.engine_worker_queue_port
for suffix_port in engine_worker_suffix:
signal_clear_experts_token_stats = np.zeros([1], dtype=np.int32)
self.signal_clear_experts_token_stats_list.append(
IPCSignal(
name="signal_clear_experts_token_stats",
array=signal_clear_experts_token_stats,
dtype=np.int32,
suffix=suffix_port,
create=False,
)
)
signal_update_weight_from_disk = np.zeros([1], dtype=np.int32)
self.signal_update_weight_from_disk_array_list.append(
IPCSignal(
name="signal_update_weight_from_disk",
array=signal_update_weight_from_disk,
dtype=np.int32,
suffix=suffix_port,
create=False,
)
)
result_update_weight_from_disk = np.zeros([1], dtype=np.int32)
self.update_weight_from_disk_result_list.append(
IPCSignal(
name="result_update_weight_from_disk",
array=result_update_weight_from_disk,
dtype=np.int32,
suffix=suffix_port,
create=False,
)
)
experts_token_stats = np.zeros(
(self.config.model_config.num_hidden_layers, self.config.model_config.moe_num_experts),
dtype=np.int32,
)
self.expert_tokens_stats_array_list.append(
IPCSignal(
name="all_experts_token_stats",
array=experts_token_stats,
dtype=np.int32,
suffix=suffix_port,
create=False,
)
)
self.local_experts_token_stats_array_list.append(
IPCSignal(
name="local_experts_token_stats",
array=experts_token_stats,
dtype=np.int32,
suffix=suffix_port,
create=False,
)
)
def create_zmq_client(self, model, mode):
"""
Create a ZMQ client.
@@ -394,3 +510,209 @@ class EngineClient:
def check_model_weight_status(self):
return self.model_weights_status_signal.value[0] < 0
async def rearrange_experts(self, request_dict: dict):
"""
rearrange experts
Args:
request_dict (dict): request body
Returns:
tuple: response body, status code
"""
content, status_code = None, HTTPStatus.OK
eplb_config = self.config.eplb_config
if not eplb_config.enable_eplb:
content = {"code": 1, "msg": "redundant expert is disabled"}
status_code = HTTPStatus.BAD_REQUEST
return content, status_code
if (
request_dict.get("user", "") != eplb_config.redundant_expert_api_user
or request_dict.get("passwd", "") != eplb_config.redundant_expert_api_password
):
content = {"code": 1, "msg": "user or passwd is invalid"}
status_code = HTTPStatus.UNAUTHORIZED
return content, status_code
if self.config.parallel_config.expert_parallel_rank != 0:
content = {
"code": 1,
"msg": f"actual rank {self.config.parallel_config.expert_parallel_rank}, expect rank 0",
}
status_code = HTTPStatus.BAD_REQUEST
return content, status_code
action = request_dict.get("action", "")
api_server_logger.info(f"redundant_expert: rearrange_experts recv request, action {action}")
if action == "":
# action: start rearrange experts
# params: {'user': 'xxx', 'passwd': 'xxx', 'ips': ['10.54.99.77:8000', '10.54.99.77:8300']}
if self.rearrange_experts_signal.value[0] != RearrangeExpertStatus.FREE.value:
content = {
"code": 1,
"msg": f"rearrange is doing. actual status {self.rearrange_experts_signal.value[0]}, expect status {RearrangeExpertStatus.FREE.value}",
}
status_code = HTTPStatus.BAD_REQUEST
if "ips" not in request_dict and content is None:
content = {"code": 1, "msg": "ips in request is None"}
status_code = HTTPStatus.BAD_REQUEST
if content is not None:
return content, status_code
data_bytes = (";".join(request_dict["ips"])).encode("utf-8")
data_size = len(data_bytes)
if data_size > eplb_config.redundant_expert_ip_shm_size:
content = {
"code": 1,
"msg": f"actual ips size {data_size}, max limit {eplb_config.redundant_expert_ip_shm_size}",
}
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
else:
self.rearrange_experts_ips_size_signal.value[0] = data_size
self.shm_rearrange_experts_ips_list.shm.buf[:data_size] = data_bytes
content = {"code": 0, "msg": "ok"}
status_code = HTTPStatus.OK
return content, status_code
elif action == "recv_expert_weight":
# action: receive global expert workload, and begin update weight from disk
# params: {'user': 'xxx', 'passwd': 'xxx', 'weight': (layers, experts)}
if "data" not in request_dict or not isinstance(request_dict["data"], list):
content = {"code": 1, "msg": "data not in request or data is not a list"}
status_code = HTTPStatus.BAD_REQUEST
elif len(request_dict["data"]) != len(self.expert_tokens_stats_array_list):
content = {
"code": 1,
"msg": f"actual data length {len(request_dict['data'])}, expect length {len(self.expert_tokens_stats_array_list)}",
}
status_code = HTTPStatus.BAD_REQUEST
else:
weight = np.array(request_dict["data"], dtype=np.int32)
for idx in range(len(self.expert_tokens_stats_array_list)):
self.expert_tokens_stats_array_list[idx].value[:] = weight[:]
self.signal_update_weight_from_disk_array_list[idx].value[0] = 1
content = {"code": 0, "msg": "ok"}
status_code = HTTPStatus.OK
return content, status_code
elif action == "update_weight_from_tensor":
if self.cfg.scheduler_config.splitwise_role != "prefill" and content is None:
content = {
"code": 1,
"msg": f"actual role {self.cfg.scheduler_config.splitwise_role}, expect role prefill",
}
status_code = HTTPStatus.BAD_REQUEST
if self.rearrange_experts_signal.value[0] != RearrangeExpertStatus.LOAD_SUCC.value and content is None:
content = {
"code": 1,
"msg": f"actual status {self.rearrange_experts_signal.value[0]}, expect status {RearrangeExpertStatus.LOAD_SUCC.value}",
}
status_code = HTTPStatus.BAD_REQUEST
if content is None:
self.signal_update_weight_from_tensor_array.value[0] = 1
content = {"code": 0, "msg": "ok"}
status_code = HTTPStatus.OK
return content, status_code
else:
content = {"code": 1, "msg": f"invalid action {action}"}
status_code = HTTPStatus.BAD_REQUEST
return content, status_code
async def get_per_expert_tokens_stats(self, request_dict: dict):
"""
get per expert tokens stats
Args:
request_dict (dict): request body
Returns:
tuple: response body, status code
"""
content, status_code = None, HTTPStatus.OK
eplb_config = self.config.eplb_config
if not eplb_config.enable_eplb:
content = {"code": 1, "msg": "redundant expert is disabled"}
status_code = HTTPStatus.BAD_REQUEST
return content, status_code
if (
request_dict.get("user", "") != eplb_config.redundant_expert_api_user
or request_dict.get("passwd", "") != eplb_config.redundant_expert_api_password
):
content = {"code": 1, "msg": "user or passwd is invalid"}
status_code = HTTPStatus.UNAUTHORIZED
return content, status_code
if self.config.parallel_config.expert_parallel_rank != 0:
content = {
"code": 1,
"msg": f"actual rank {self.config.parallel_config.expert_parallel_rank}, expect rank 0",
}
status_code = HTTPStatus.BAD_REQUEST
return content, status_code
if "clear_stat" in request_dict and request_dict["clear_stat"]:
for clear_experts_token_stats in self.signal_clear_experts_token_stats_list:
clear_experts_token_stats.value[0] = 1
local_experts_list = []
for local_experts_token_stats in self.local_experts_token_stats_array_list:
local_experts_list.append(local_experts_token_stats.value.tolist())
content = {"code": 0, "msg": "ok", "data": local_experts_list}
status_code = HTTPStatus.OK
return content, status_code
async def check_redundant(self, request_dict: dict):
"""
check redundant
Args:
request_dict (dict): request body
Returns:
tuple: response body, status code
"""
content, status_code = None, HTTPStatus.OK
eplb_config = self.config.eplb_config
if not eplb_config.enable_eplb:
content = {"code": 1, "msg": "redundant expert is disabled"}
status_code = HTTPStatus.BAD_REQUEST
return content, status_code
if (
request_dict.get("user", "") != eplb_config.redundant_expert_api_user
or request_dict.get("passwd", "") != eplb_config.redundant_expert_api_password
):
content = {"code": 1, "msg": "user or passwd is invalid"}
status_code = HTTPStatus.UNAUTHORIZED
return content, status_code
if self.config.parallel_config.expert_parallel_rank != 0:
content = {
"code": 1,
"msg": f"actual rank {self.config.parallel_config.expert_parallel_rank}, expect rank 0",
}
status_code = HTTPStatus.BAD_REQUEST
return content, status_code
action = request_dict.get("action", "")
if action == "":
status = "unknown"
try:
status = RearrangeExpertStatus(self.rearrange_experts_signal.value[0]).name
except:
pass
content = {"code": 0, "msg": "ok", "status": status}
get_workloads = False if "check_get_workloads" not in request_dict else request_dict["check_get_workloads"]
if get_workloads:
content["data"], content["msg"] = RedundantExpertWorkload(eplb_config.redundant_expert_meta_dir).load()
status_code = HTTPStatus.OK
elif action == "check_load_weight_result":
update_weight_from_disk_list = []
for update_weight_result in self.update_weight_from_disk_result_list:
update_weight_from_disk_list.append(update_weight_result.value[0].tolist())
content = {"code": 0, "msg": "ok", "data": update_weight_from_disk_list}
status_code = HTTPStatus.OK
return content, status_code

View File

@@ -155,6 +155,8 @@ async def lifespan(app: FastAPI):
verification = False
model_paths = [ModelPath(name=served_model_names, model_path=args.model, verification=verification)]
engine_args = EngineArgs.from_cli_args(args)
config = engine_args.create_engine_config(port_availability_check=False)
engine_client = EngineClient(
model_name_or_path=args.model,
tokenizer=args.tokenizer,
@@ -171,6 +173,7 @@ async def lifespan(app: FastAPI):
workers=args.workers,
tool_parser=args.tool_call_parser,
enable_prefix_caching=args.enable_prefix_caching,
config=config,
)
await engine_client.connection_manager.initialize()
app.state.dynamic_load_weight = args.dynamic_load_weight
@@ -408,6 +411,36 @@ def clear_load_weight(request: Request) -> Response:
return Response(content="Dynamic Load Weight Disabled.", status_code=404)
@app.post("/rearrange_experts")
async def rearrange_experts(request: Request):
"""
rearrange experts
"""
request_dict = await request.json()
content, status_code = await app.state.engine_client.rearrange_experts(request_dict=request_dict)
return JSONResponse(content, status_code=status_code)
@app.post("/get_per_expert_tokens_stats")
async def get_per_expert_tokens_stats(request: Request):
"""
get per expert tokens stats
"""
request_dict = await request.json()
content, status_code = await app.state.engine_client.get_per_expert_tokens_stats(request_dict=request_dict)
return JSONResponse(content, status_code=status_code)
@app.post("/check_redundant")
async def check_redundant(request: Request):
"""
check redundant
"""
request_dict = await request.json()
content, status_code = await app.state.engine_client.check_redundant(request_dict=request_dict)
return JSONResponse(content, status_code=status_code)
def launch_api_server() -> None:
"""
启动http服务

View File

@@ -0,0 +1,15 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

View File

@@ -0,0 +1,427 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import ctypes
import os
import time
import traceback
from typing import List, Tuple
import numpy as np
import paddle
from fastdeploy.config import EPLBConfig
REARRANGE_EXPERT_MAGIC_NUM = 147183647
REARRANGE_ORIGINATOR_EP_RANK = 0
CHECK_TIME_INTERNAL = 3
HTTP_RETRY_NUM = 5
CHECK_TIMEOUT = 120
libc = ctypes.CDLL(None)
libc.mmap.argtypes = [
ctypes.c_void_p, # void *addr
ctypes.c_size_t, # size_t length
ctypes.c_int, # int prot
ctypes.c_int, # int flags
ctypes.c_int, # int fd
ctypes.c_size_t, # off_t offset
]
libc.mmap.restype = ctypes.c_void_p
libc.munmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
libc.munmap.restype = ctypes.c_int
PROT_READ = 0x1
PROT_WRITE = 0x2
MAP_SHARED = 0x01
MAP_ANONYMOUS = 0x20
MAP_FAILED = -1
G = 1024**3
TOTAL_MODEL_SIZE = 350
MAIN_MODEL_REDUNDANT_SHM_SIZE = 5
MODEL_MAIN_NAME = "eplb_main"
def create_mmap(model_name: List, ep_rank: int, ep_size: int, shm_uuid: str, eplb_config: EPLBConfig, logger=None):
"""create_mmap"""
flags = MAP_SHARED
prot = PROT_READ | PROT_WRITE
main_size = 0
if eplb_config.redundant_expert_async_load_model_shmem_size_gb == 0:
main_size = TOTAL_MODEL_SIZE // ep_size
else:
main_size = eplb_config.redundant_expert_async_load_model_shmem_size_gb
main_size = main_size * G
mmap_infos = {}
from cuda import cudart
for name in model_name:
expert_weight_file = f"/dev/shm/{name}_rank_{ep_rank}_expert_weight_{shm_uuid}"
shm_size = main_size
if not os.path.isfile(expert_weight_file):
open(expert_weight_file, "wb").close()
shm_fd = os.open(expert_weight_file, os.O_RDWR)
os.ftruncate(shm_fd, shm_size)
if logger is not None:
logger.info(f"redundant_expert: create_mmap file {expert_weight_file}, fd {shm_fd}, size {shm_size}")
shm_ptr = libc.mmap(0, ctypes.c_size_t(shm_size), prot, flags, shm_fd, 0)
if shm_ptr == MAP_FAILED:
raise OSError(f"redundant_expert: mmap {expert_weight_file} failed: {ctypes.get_errno()}")
shm_ptr = ctypes.cast(shm_ptr, ctypes.POINTER(ctypes.c_int8))
addr = ctypes.addressof(shm_ptr.contents)
# Register memory with CUDA
(ret,) = cudart.cudaHostRegister(addr, shm_size, 0)
if ret != cudart.cudaError_t.cudaSuccess:
raise RuntimeError(
f"cudaHostRegister failed: {cudart.cudaGetErrorString(ret)}, "
f" address {hex(addr)} size {shm_size}, ret: {ret}"
)
mmap_infos[name] = shm_ptr
return mmap_infos
def save_tensor_to_shm_mem(cached_weights, file_path, logger=None):
"""save_tensor_to_shm_mem"""
tensor_infos = []
offset = 0
if not os.path.exists(file_path):
raise OSError("File is not exist")
shm_size = os.path.getsize(file_path)
for name, w in cached_weights:
size = w.numel().item() * w.element_size()
# logger.info(f"redundant_expert: save tensor to {name} offset: {offset} size: {size}")
w_ptr = ctypes.string_at(w.data_ptr(), size)
with open(file_path, "r+b") as file:
file.seek(offset)
if offset + size > shm_size:
raise IOError(
f"redundant_expert: Exceeded {file_path} file's size. "
+ "Should set a bigger value using env variable."
)
n = file.write(w_ptr)
assert n == size
tensor_infos.append((name, offset, size, w.shape, w.dtype))
offset += size
sz = offset / 1024 / 1024 / 1024
if logger is not None:
logger.info(f"redundant_expert: save_tensor_to_shm_mem success. file {file_path} size {sz}G")
return tensor_infos
def load_tensor_from_shm_mem(tensor_infos, shm_ptr, logger=None):
"""load_tensor_from_shm_mem"""
# weights_dict = {}
weights_dict = []
for name, offset, size, shape, dtype in tensor_infos:
# 计算共享内存中张量的地址
w_addr = ctypes.cast(shm_ptr, ctypes.c_void_p).value + offset
w_ptr = ctypes.cast(w_addr, ctypes.POINTER(ctypes.c_byte))
# 先读取为字节数组,再通过视图转换成适当类型
np_array = np.ctypeslib.as_array(w_ptr, shape=(size,))
if dtype == paddle.float32:
tmp = np_array.view(np.float32)
tensor = paddle.Tensor(tmp, dtype=paddle.float32, place=paddle.CPUPlace(), zero_copy=True)
elif dtype == paddle.uint8:
tmp = np_array.view(np.uint8)
tensor = paddle.Tensor(tmp, dtype=paddle.uint8, place=paddle.CPUPlace(), zero_copy=True)
elif dtype == paddle.int8:
tmp = np_array.view(np.int8)
tensor = paddle.Tensor(tmp, dtype=paddle.int8, place=paddle.CPUPlace(), zero_copy=True)
elif dtype == paddle.bfloat16:
# NumPy 不支持 bfloat16因此先以 uint16 读取原始数据,再用 Paddle cast 为 bfloat16
tmp = np_array.view(np.uint16)
tensor = paddle.Tensor(tmp, dtype=paddle.bfloat16, place=paddle.CPUPlace(), zero_copy=True)
else:
raise TypeError(f"Unsupported dtype: {dtype}")
assert w_addr == tensor.data_ptr()
# weights_dict[name] = tensor.view(shape)
weights_dict.append((name, tensor.view(shape)))
if logger is not None:
logger.info("redundant_expert: load_tensor_from_shm_mem succ")
return weights_dict
class AsyncEPLoader(object):
"""Aynsc Expert loader"""
def __init__(
self,
model_dir,
eplb_config,
rank=8,
expert_per_rank=8,
moe_layer_start_index=3,
moe_quant_type="",
logger=None,
):
"""
__init__
"""
self.model_path = model_dir
self.eplb_config = eplb_config
self.expert_per_rank = expert_per_rank
self.moe_layer_start_index = moe_layer_start_index
self.ep_rank = rank
self.moe_quant_type = moe_quant_type
self.old_model_ep_rank_to_expert_id_list = None
self.new_model_ep_rank_to_expert_id_list = None
self.cached_weights = []
# self.state_dicts = {}
self.moe_file_names = []
self.logger = logger
def reset(self):
"""
reset
"""
self.old_model_ep_rank_to_expert_id_list = None
self.new_model_ep_rank_to_expert_id_list = None
self.cached_weights = []
self.moe_file_names = []
def load_experts_weight_from_disk(self):
"""
return value: (all_succ whether_load_weight exist_fatal_error message),
exist_fatal_error means all rank need restart
"""
ep_rank = self.ep_rank
start_idx = ep_rank * self.expert_per_rank
end_idx = start_idx + self.expert_per_rank
try:
old_expert_ids_all = self.old_model_ep_rank_to_expert_id_list[:, start_idx:end_idx]
new_expert_ids_all = self.new_model_ep_rank_to_expert_id_list[:, start_idx:end_idx]
need_to_reload = list()
for layer_id in range(len(old_expert_ids_all)):
if layer_id < self.moe_layer_start_index:
continue
new_expert_ids = new_expert_ids_all[layer_id]
old_expert_ids = old_expert_ids_all[layer_id]
if len(new_expert_ids) != len(old_expert_ids):
message = f"redundant_expert: new_expert_ids length not equal to old_expert_ids \
length layer_id: {layer_id}"
# this is very dangerous and unepxpected, should be fixed
return False, message
# TODO: 按需加载,过滤重复专家
self.logger.info(
f"redundant_expert: rank {ep_rank} layer {layer_id} old_experts {old_expert_ids}"
+ f" new_experts {new_expert_ids}"
)
need_to_reload.extend([(layer_id, expert_id) for expert_id in new_expert_ids])
succ = True
message = ""
if len(need_to_reload) > 0:
if self.eplb_config.model_use_safetensors:
succ, message = self.load_safetensor_fp8_from_disk(need_to_reload)
else:
succ, message = self.load_weight_bf16_from_disk(need_to_reload)
if not succ:
self.logger.info(
f"redundant_expert: load_experts_weight_from_disk fail. rank {ep_rank}, error: {message}"
)
new_message = f"redundant_expert: load_experts_weight_from_disk fail. rank {ep_rank}, error: {message}"
return False, new_message
self.logger.info(f"redundant_expert: load_experts_weight_from_disk success. rank {ep_rank}")
return True, "redundant_expert: load_experts_weight_from_disk success"
except Exception as e:
message = f"redundant_expert: Failed to load_experts_weight_from_disk ep_rank {ep_rank} excep: {e}"
error_message = traceback.format_exc()
self.logger.error(f"redundant_expert: message: {message} traceback: {error_message}")
return False, message
def load_weight_bf16_from_disk(self, need_to_reload: List[Tuple[int, int]]):
"""load_weight_bf16_from_disk"""
try:
ckpt_up_gate_proj_name = "up_gate_proj"
ckpt_down_proj_name = "down_proj"
for layer_id, expert_id in need_to_reload:
for weight_name in [ckpt_up_gate_proj_name, ckpt_down_proj_name]:
ckpt_file_name = f"ernie.layers.{layer_id}.mlp.experts.{expert_id}.{weight_name}.weight"
if ckpt_file_name not in self.moe_file_names:
self.logger.info(f"record redundant_expert: {ckpt_file_name}")
self.moe_file_names.append(ckpt_file_name)
last_device = paddle.device.get_device()
paddle.set_device("cpu")
for file_name in self.moe_file_names:
# 判断文件是否存在
if not os.path.exists(self.model_path + "/merged_tp1_state_split/" + file_name):
# self.logger.info(f"redundant_expert: {file_name} not exist.")
continue
# self.logger.info(f"redundant_expert: Loading expert weights: {file_name}.")
self.state_dicts[file_name] = paddle.load(self.model_path + "/merged_tp1_state_split/" + file_name)
paddle.set_device(last_device)
self.logger.info("redundant_expert: Loading expert weights end.")
return True, "redundant_expert: Succeeded to loading expert weights."
except Exception as e:
message = f"redundant_expert: Failed to get weights iterator: {e}."
return False, message
def load_safetensor_fp8_from_disk(self, need_to_reload: List[Tuple[int, int]]):
"""load_safetensor_fp8_from_disk"""
"""
ernie.layers.52.mlp.experts.58.up_gate_proj.quant_weight
ernie.layers.52.mlp.experts.58.up_gate_proj.weight_scale
ernie.layers.52.mlp.experts.58.down_proj.quant_weight
ernie.layers.52.mlp.experts.58.down_proj.weight_scale
"""
up_gate_down = ["up_gate_proj", "down_proj"]
quant_weight_scale = ["quant_weight", "weight_scale"]
if self.moe_quant_type == "w4a8":
quant_weight_scale = ["quant_weight"]
ckpt_name = [
(f"ernie.layers.{layer_id}.mlp.experts.{expert_id}.{proj_name}.{quant_name}")
for layer_id, expert_id in need_to_reload
for proj_name in up_gate_down
for quant_name in quant_weight_scale
]
ckpt_name_to_safetensor_file = load_ep_checkpoint(self.model_path)
hf_weights_files = list(set(ckpt_name_to_safetensor_file.values()))
state_dicts = {}
last_device = paddle.device.get_device()
paddle.set_device("cpu")
from safetensors import safe_open
for st_file in hf_weights_files:
with safe_open(st_file, framework="np", device="cpu") as f:
for name in f.keys():
if name in ckpt_name:
weight = f.get_tensor(name)
state_dicts[name] = paddle.Tensor(weight, zero_copy=True)
weights_list = []
for name in ckpt_name:
weights_list.append((name, state_dicts[name]))
self.cached_weights = weights_list
paddle.set_device(last_device)
return True, "load_expert_weight_from_disk_safetensor success"
def load_ep_checkpoint(model_path):
"""
load ep checkpoint
"""
file_path = os.path.join(model_path, "model.safetensors.index.json")
if not os.path.exists(file_path):
return {}
import json
with open(file_path, "r") as f:
weight_map = json.load(f)["weight_map"]
state_dict = {k: os.path.join(model_path, v) for k, v in weight_map.items()}
return state_dict
def load_model_weights_process(
rank: int,
model_dir: str,
expert_per_rank: int,
moe_layer_start_index: int,
moe_quant_type: str,
shm_uuid: str,
eplb_config: EPLBConfig,
data_conn,
mg_conn,
):
"""
load_model_weights_process
"""
import faulthandler
from setproctitle import setproctitle
setproctitle(f"eplb::async_load_model_{rank}")
faulthandler.enable()
from fastdeploy.utils import get_logger
logger = get_logger("eplb_async_loader", "eplb_{0}.log".format(rank))
logger.info("redundant_expert: load_model_weights_process start")
paddle.set_device("cpu")
ep_loader = AsyncEPLoader(
model_dir=model_dir,
rank=rank,
expert_per_rank=expert_per_rank,
moe_layer_start_index=moe_layer_start_index,
moe_quant_type=moe_quant_type,
logger=logger,
eplb_config=eplb_config,
)
while True:
ep_loader.reset()
data = mg_conn.recv()
result = True
weight_infos = []
try:
ep_loader.old_model_ep_rank_to_expert_id_list = data["old_model_ep_rank_to_expert_id_list"]
ep_loader.new_model_ep_rank_to_expert_id_list = data["new_model_ep_rank_to_expert_id_list"]
begin_time_disk = int(time.time())
success, message = ep_loader.load_experts_weight_from_disk()
begin_time_shm = int(time.time())
logger.info(
"redundant_expert: async load load_weight_from_disk, "
+ f"succ {success}, cost {begin_time_shm-begin_time_disk}s"
)
if success:
model_name = MODEL_MAIN_NAME
file_path = f"/dev/shm/{model_name}_rank_{rank}_expert_weight_{shm_uuid}"
weight_infos = save_tensor_to_shm_mem(ep_loader.cached_weights, file_path, logger)
logger.info(
"redundant_expert: async load save_tensor_to_shm_mem, "
+ f"tensor nums {len(weight_infos)}, cost {int(time.time()-begin_time_shm)}s"
)
else:
logger.error(f"redundant_expert: async load load_weight_from_disk failed, error {message}")
result = False
except Exception as e:
logger.error(f"redundant_expert: async load weights failed, rank {rank} error {e}")
result = False
weight_infos = []
finally:
request_data = {"result": result, "weights": weight_infos}
data_conn.send(request_data)

291
fastdeploy/eplb/eplb.py Normal file
View File

@@ -0,0 +1,291 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Tuple
import numpy as np
def balanced_packing(weight: np.ndarray, num_packs: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs
are as balanced as possible.
Parameters:
weight: [X, n], the weight of each item
num_packs: number of packs
Returns:
pack_index: [X, n], the pack index of each item
rank_in_pack: [X, n], the rank of the item in the pack
"""
num_layers, num_groups = weight.shape
assert num_groups % num_packs == 0
groups_per_pack = num_groups // num_packs
if groups_per_pack == 1:
pack_index = np.arange(weight.shape[-1], dtype=np.int32).reshape(1, -1).repeat(num_layers, axis=0)
rank_in_pack = np.zeros_like(weight, dtype=np.int32)
return pack_index, rank_in_pack
indices = np.argsort(-weight.astype(np.float32), axis=-1)
pack_index = np.full_like(weight, fill_value=-1, dtype=np.int32)
rank_in_pack = np.full_like(pack_index, fill_value=-1)
for i in range(num_layers):
pack_weights = [0] * num_packs
pack_items = [0] * num_packs
for group in indices[i]:
pack = min(
(i for i in range(num_packs) if pack_items[i] < groups_per_pack),
key=pack_weights.__getitem__,
)
assert pack_items[pack] < groups_per_pack
pack_index[i, group] = pack
rank_in_pack[i, group] = pack_items[pack]
pack_weights[pack] += weight[i, group]
pack_items[pack] += 1
return pack_index, rank_in_pack
def replicate_experts(weight: np.ndarray, num_phy: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.
Parameters:
weight: [X, num_log]
num_phy: total number of experts after replication
Returns:
phy2log: [X, num_phy], logical expert id of each physical expert
rank: [X, num_phy], the replica rank
logcnt: [X, num_log], number of replicas for each logical expert
"""
n, num_log = weight.shape
num_redundant = num_phy - num_log
assert num_redundant >= 0
phy2log = np.arange(num_phy, dtype=np.int32).reshape(1, -1).repeat(n, axis=0)
rank = np.zeros((n, num_phy), dtype=np.int32)
logcnt = np.ones((n, num_log), dtype=np.int32)
arangen = np.arange(n, dtype=np.int32)
for i in range(num_log, num_phy):
redundant_indices = np.argmax(weight / logcnt, axis=-1)
phy2log[:, i] = redundant_indices
rank[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1
return phy2log, rank, logcnt
def rebalance_experts_intra_node(
weight: np.ndarray,
num_physical_experts: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
):
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
num_physical_experts: number of physical experts after replication
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map: [num_moe_layers, num_physical_experts]
logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
logical_count: [num_moe_layers, num_logical_experts]
"""
num_layers, num_logical_experts = weight.shape
assert num_logical_experts % num_groups == 0
num_redundant_experts = num_physical_experts - num_logical_experts
assert num_redundant_experts >= 0
assert num_gpus % num_nodes == 0
num_gpus_per_node = num_gpus // num_nodes
assert num_physical_experts % num_gpus == 0
num_physical_experts_per_gpu = num_physical_experts // num_gpus
assert num_physical_experts % num_nodes == 0
num_physical_experts_per_node = num_physical_experts // num_nodes
assert num_logical_experts % num_physical_experts_per_node == 0
# num_logical_nodes = num_logical_experts // num_physical_experts_per_node
assert num_redundant_experts % num_physical_experts_per_node == 0
# num_redundant_nodes = num_redundant_experts // num_physical_experts_per_node
def inverse(perm: np.ndarray) -> np.ndarray:
inv = np.empty_like(perm)
inv[np.arange(perm.shape[0])[:, None], perm] = np.arange(perm.shape[1], dtype=np.int32).reshape(1, -1)
return inv
# Step 1: generate redundant experts by weight.
# shape of tmp2log, tmprank is [num_layers, num_physical_experts]
# shape of logcnt is [num_layers, num_logical_experts]
tmp2log, tmprank, logcnt = replicate_experts(weight, num_physical_experts)
# Step 2: compute num_tokens of physical experts
# shape of tokens_per_tmp is [num_layers * num_nodes, num_physical_experts_per_node]
tokens_per_tmp = np.take_along_axis(weight / logcnt, tmp2log, axis=-1).reshape(-1, num_physical_experts_per_node)
# STEP 3: take load balance of gpu cards in node
# shape of gpu_index, rank_in_gpu, tmp2phy, phy2tmp is [num_layers * num_nodes, num_physical_experts_per_node]
gpu_index, rank_in_gpu = balanced_packing(tokens_per_tmp, num_gpus_per_node)
tmp2phy = gpu_index * num_physical_experts_per_gpu + rank_in_gpu
phy2tmp = inverse(tmp2phy)
# STEP 4: generate final phy2log mapping
tmp2log = tmp2log.reshape(-1, num_physical_experts_per_node)
tmprank = tmprank.reshape(-1, num_physical_experts_per_node)
phy2log = np.take_along_axis(tmp2log, phy2tmp, axis=-1).reshape(-1, num_physical_experts)
phyrank = np.take_along_axis(tmprank, phy2tmp, axis=-1).reshape(-1, num_physical_experts)
return phy2log, phyrank, logcnt
def rebalance_experts_hierarchical(
weight: np.ndarray,
num_physical_experts: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
):
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
num_physical_experts: number of physical experts after replication
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map: [num_moe_layers, num_physical_experts]
logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
logical_count: [num_moe_layers, num_logical_experts]
"""
num_layers, num_logical_experts = weight.shape
assert num_logical_experts % num_groups == 0
group_size = num_logical_experts // num_groups
assert num_groups % num_nodes == 0
groups_per_node = num_groups // num_nodes
assert num_gpus % num_nodes == 0
assert num_physical_experts % num_gpus == 0
phy_experts_per_gpu = num_physical_experts // num_gpus
def inverse(perm: np.ndarray) -> np.ndarray:
inv = np.empty_like(perm)
inv[np.arange(perm.shape[0])[:, None], perm] = np.arange(perm.shape[1], dtype=np.int32).reshape(1, -1)
return inv
# Step 1: pack groups to nodes
tokens_per_group = weight.reshape(num_layers, num_groups, group_size).sum(axis=-1)
group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
log2mlog = (
((group_pack_index * groups_per_node + group_rank_in_pack) * group_size)[:, :, None]
+ np.arange(group_size, dtype=np.int32)
).reshape(num_layers, -1)
mlog2log = inverse(log2mlog)
# Step 2: construct redundant experts within nodes
tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=-1).reshape(-1, num_logical_experts // num_nodes)
phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes)
# Step 3: pack physical_experts to GPUs
tokens_per_phy = np.take_along_axis(tokens_per_mlog / mlogcnt, phy2mlog, axis=-1)
pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
pphy2phy = inverse(phy2pphy)
pphy2mlog = np.take_along_axis(phy2mlog, pphy2phy, axis=-1) # [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog = (
pphy2mlog.reshape(num_layers, num_nodes, -1)
+ np.arange(0, num_logical_experts, num_logical_experts // num_nodes, dtype=np.int32).reshape(1, -1, 1)
).reshape(num_layers, -1)
pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=-1)
pphyrank = np.take_along_axis(phyrank, pphy2phy, axis=-1).reshape(num_layers, -1)
logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=-1)
return pphy2log, pphyrank, logcnt
def rebalance_experts(
weight: np.ndarray,
num_replicas: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
eplb_strategy: str = "",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics for all logical experts
num_replicas: number of physical experts, must be a multiple of `num_gpus`
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map: [layers, num_replicas], the expert index of each replica
logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
"""
num_layers, num_logical_experts = weight.shape
weight = weight.astype(np.float32)
if eplb_strategy == "balance_intra_node":
phy2log, phyrank, logcnt = rebalance_experts_intra_node(weight, num_replicas, num_groups, num_nodes, num_gpus)
else:
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
else:
# use global load-balance policy
phy2log, phyrank, logcnt = replicate_experts(weight, num_replicas)
maxlogcnt = logcnt.max()
log2phy = np.full((num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int32)
np.put_along_axis(
log2phy.reshape(num_layers, -1)[:, :, None],
(phy2log * maxlogcnt + phyrank)[:, :, None],
np.arange(num_replicas, dtype=np.int32).reshape(1, -1).repeat(num_layers, axis=0)[:, :, None],
axis=1,
)
return phy2log, log2phy, logcnt
__all__ = ["rebalance_experts"]
def main():
"""
main
"""
num_hidden_layers = 3
num_expert = 64
num_groups = 8
num_replicas = 64
num_nodes = 4
num_gpus = 4 * 8
model_tokens_per_expert_stats_list = np.random.randint(low=1, high=10, size=(num_hidden_layers, num_expert))
phy2log, phyrank, logcnt = rebalance_experts(
model_tokens_per_expert_stats_list,
num_replicas,
num_groups,
num_nodes,
num_gpus,
)
print(phy2log)
print(phyrank)
print(logcnt)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,503 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import threading
import time
from http import HTTPStatus
from multiprocessing import Pipe, Process
import numpy as np
import requests
from fastdeploy.config import FDConfig
from fastdeploy.eplb.async_expert_loader import load_model_weights_process
from fastdeploy.eplb.eplb import rebalance_experts
from fastdeploy.eplb.utils import RedundantExpertWorkload
from fastdeploy.inter_communicator import IPCSignal, RearrangeExpertStatus
from fastdeploy.utils import get_logger
class RedundantExpertManager:
"""
RedundantExpertManger
"""
def __init__(
self,
rank: int = 0,
ep_size: int = 32,
fd_config: FDConfig = None,
ipc_signal_suffix: int = 0,
):
self.logger = get_logger("eplb_expert_manager", "eplb_{0}.log".format(rank))
self.rank = rank
self.ep_size = ep_size
self.fd_config = fd_config
self.eplb_config = fd_config.eplb_config
self.api_user = self.eplb_config.redundant_expert_api_user
self.api_passwd = self.eplb_config.redundant_expert_api_password
self.num_redundant_experts = self.eplb_config.redundant_experts_num
self.num_hidden_layers = self.fd_config.model_config.num_hidden_layers
self.num_logical_experts = self.fd_config.model_config.moe_num_experts
self.ipc_signal_suffix = ipc_signal_suffix
self.num_replicas = self.num_logical_experts + self.num_redundant_experts
self.num_groups = self.num_logical_experts
self.num_nodes = max(ep_size // 8, 1)
self.num_gpus = ep_size
self.expert_per_rank = self.num_replicas // ep_size
assert (
self.num_replicas % ep_size == 0
), f"num_replicas must be divisible by ep_size, \
but got num_replicas = {self.num_replicas}, ep_size = {ep_size}"
self.model_ep_rank_to_expert_id_list = np.full(
(
self.num_hidden_layers,
self.num_logical_experts + self.num_redundant_experts,
),
-1,
dtype=np.int32,
)
self.model_expert_id_to_ep_rank_array = np.full(
(
self.num_hidden_layers,
self.num_logical_experts,
self.num_redundant_experts + 1,
),
-1,
dtype=np.int32,
)
self.model_expert_in_rank_num_list = np.zeros(
(self.num_hidden_layers, self.num_logical_experts), dtype=np.int32
)
# backup info
self.last_model_ep_rank_to_expert_id_list = np.full(
(
self.num_hidden_layers,
self.num_logical_experts + self.num_redundant_experts,
),
-1,
dtype=np.int32,
)
self.last_model_expert_id_to_ep_rank_array = np.full(
(
self.num_hidden_layers,
self.num_logical_experts,
self.num_redundant_experts + 1,
),
-1,
dtype=np.int32,
)
self.last_model_expert_in_rank_num_list = np.zeros(
(self.num_hidden_layers, self.num_logical_experts), dtype=np.int32
)
self.model_tokens_per_expert_stats_list = np.ones(
(self.num_hidden_layers, self.num_logical_experts), dtype=np.int32
)
self.caculate_expert_rank_table(True)
self.dp_rank_address = None
self.need_allgather_load_weight_result = False
self.load_weight_begin_ts = 0
self.load_weight_timeout = 300 # 5min
self.need_rearrange_expert = False
self.need_update_expert_tokens_stat = True
self.http_timeout = 1
# 重置重排状态: 'done' -> 'free'
self.rearrange_end_ts = 0
self.rearrange_reset_interval = 300
self.tensor_infos = None
self.parent_data_conn, child_data_conn = Pipe()
self.parent_mg_conn, child_mg_conn = Pipe()
Process(
target=load_model_weights_process,
name=f"eplb::async_load_model_{rank}",
args=(
self.rank,
self.fd_config.model_config.model,
self.expert_per_rank,
self.fd_config.model_config.moe_layer_start_index,
self.eplb_config.moe_quant_type,
self.ipc_signal_suffix,
self.eplb_config,
child_data_conn,
child_mg_conn,
),
).start()
child_data_conn.close()
child_mg_conn.close()
listen_signal_thread = threading.Thread(target=self.listen_rearrange_expert_signal, args=(), daemon=True)
listen_signal_thread.start()
self.logger.info(
f"redundant_expert: RedundantExpertManager init success, rank {rank}, \
strategy {self.eplb_config.redundant_expert_eplb_strategy}"
)
# def get_unique_name(self, name):
# return f"{envs.get_unique_name(name + '_dprank_' + str(self.rank))}"
def get_ep_rank_to_expert_id_list(self):
"""
get_ep_rank_to_expert_id_list
"""
return (
self.model_ep_rank_to_expert_id_list,
self.model_expert_id_to_ep_rank_array,
self.model_expert_in_rank_num_list,
)
def listen_rearrange_expert_signal(self):
"""
listen_rearrange_expert_signal
"""
if self.rank == 0:
rearrange_experts_ips_size_array = np.zeros([1], dtype=np.int32)
rearrange_experts_ips_size_signal = IPCSignal(
name="rearrange_experts_ips_size",
array=rearrange_experts_ips_size_array,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=False,
)
shm_rearrange_experts_ips_list = IPCSignal(
name="rearrange_experts_ips_list",
shm_size=self.eplb_config.redundant_expert_ip_shm_size,
suffix=self.ipc_signal_suffix,
create=False,
)
rearrange_experts_status = np.zeros([1], dtype=np.int32)
rearrange_experts_signal = IPCSignal(
name="rearrange_experts_status",
array=rearrange_experts_status,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=False,
)
signal_update_weight_from_disk = np.zeros([1], dtype=np.int32)
signal_update_weight_from_disk_array = IPCSignal(
name="signal_update_weight_from_disk",
array=signal_update_weight_from_disk,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=False,
)
experts_token_stats = np.zeros(
(self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.moe_num_experts),
dtype=np.int32,
)
shm_all_experts_token_stats = IPCSignal(
name="all_experts_token_stats",
array=experts_token_stats,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=False,
)
while True:
if self.rank == 0:
now = int(time.time())
if rearrange_experts_ips_size_signal.value[0] > 0:
# step 1. all reduce experts token stats
address = bytes(
shm_rearrange_experts_ips_list.shm.buf[: rearrange_experts_ips_size_signal.value[0]]
).decode("utf-8")
self.logger.info(f"redundant_expert: all rank ips {address}")
rearrange_experts_ips_size_signal.value[0] = 0
rearrange_experts_signal.value[0] = RearrangeExpertStatus.DOING.value
self.dp_rank_address = address.strip().split(";")
if self.allreduce_experts_stat():
self.need_allgather_load_weight_result = True
self.load_weight_begin_ts = now
self.logger.info("redundant_expert: all-reduce experts stats success")
else:
rearrange_experts_signal.value[0] = RearrangeExpertStatus.FREE.value
self.logger.warning("redundant_expert: all-reduce experts stats fail")
elif self.need_allgather_load_weight_result and self.allreduce_load_weight_result():
# step 3. all reduce the result of load weight from disk
self.need_allgather_load_weight_result = False
rearrange_experts_signal.value[0] = RearrangeExpertStatus.LOAD_SUCC.value
self.rearrange_end_ts = now
if rearrange_experts_signal.value[0] > 1 and (
now - self.rearrange_end_ts > self.rearrange_reset_interval
):
# reset rearrange status
rearrange_experts_signal.value[0] = RearrangeExpertStatus.FREE.value
if signal_update_weight_from_disk_array.value[0] == 1:
# step 2. async load weight: disk -> memory
self.model_tokens_per_expert_stats_list[:] = shm_all_experts_token_stats.value[:]
self.caculate_expert_rank_table()
self.update_weight_from_disk()
signal_update_weight_from_disk_array.value[0] = 0
time.sleep(0.5)
def caculate_expert_rank_table(self, is_init=False):
"""
caculate_expert_rank_table
"""
num_groups = self.num_groups
num_nodes = self.num_nodes
num_gpus = self.num_gpus
eplb_strategy = self.eplb_config.redundant_expert_eplb_strategy
if is_init:
num_groups = 1
num_nodes = 2
num_gpus = 2 * 8
eplb_strategy = ""
# eplb
rank_expert_list, logical_to_physical_map, expert_count = rebalance_experts(
self.model_tokens_per_expert_stats_list,
self.num_replicas,
num_groups,
num_nodes,
num_gpus,
eplb_strategy,
)
# backup info
self.last_model_ep_rank_to_expert_id_list[:] = self.model_ep_rank_to_expert_id_list[:]
self.last_model_expert_id_to_ep_rank_array[:] = self.model_expert_id_to_ep_rank_array[:]
self.last_model_expert_in_rank_num_list[:] = self.model_expert_in_rank_num_list[:]
# update model info
self.model_ep_rank_to_expert_id_list[:] = rank_expert_list[:]
self.model_expert_id_to_ep_rank_array.fill(-1)
self.model_expert_id_to_ep_rank_array[..., : logical_to_physical_map.shape[-1]] = logical_to_physical_map[:]
self.model_expert_in_rank_num_list[:] = expert_count[:]
if self.rank == 0:
workload = RedundantExpertWorkload()
workload.tokens_per_expert_stats_list = self.model_tokens_per_expert_stats_list.tolist()
workload.ep_rank_to_expert_id_list = rank_expert_list.tolist()
workload.expert_id_to_ep_rank_array = logical_to_physical_map.tolist()
workload.expert_in_rank_num_list = expert_count.tolist()
self.logger.info(workload.dump())
def update_weight_from_disk(self):
"""
update_weight_from_disk
"""
begin_time = time.time()
result_update_weight_from_disk = np.zeros([1], dtype=np.int32)
update_weight_from_disk_result = IPCSignal(
name="result_update_weight_from_disk",
array=result_update_weight_from_disk,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=False,
)
update_weight_from_disk_result.value[0] = 0
self.logger.info(f"redundant_expert: update_weight_from_disk send to async process, rank {self.rank}")
self.parent_mg_conn.send(
{
"old_model_ep_rank_to_expert_id_list": self.last_model_ep_rank_to_expert_id_list,
"new_model_ep_rank_to_expert_id_list": self.model_ep_rank_to_expert_id_list,
}
)
self.logger.info(f"redundant_expert: update_weight_from_disk recv from async process, rank {self.rank}")
response = self.parent_data_conn.recv()
self.tensor_infos = response["weights"]
# 更新权重加载结果
update_weight_from_disk_result.value[0] = 1 if response["result"] else -1
self.logger.info(
"redundant_expert: update_weight_from_disk end, rank"
+ f" {self.rank} {response['result']}, cost {int(time.time() - begin_time)}s"
)
def allreduce_experts_stat(self):
"""
专家负载
"""
if not self.allgather_expert_token_stats():
return False
return self.broadcast_expert_token_stats()
def allgather_expert_token_stats(self):
"""
allgather_expert_token_stats
"""
expert_token_stats = np.zeros((self.num_hidden_layers, self.num_logical_experts), dtype=np.int32)
success_count = 0
for addr in self.dp_rank_address:
try:
# TODO: 请求失败重试
params = {"user": self.api_user, "passwd": self.api_passwd}
res = requests.post(
f"http://{addr}/get_per_expert_tokens_stats",
json=params,
timeout=self.http_timeout,
)
if res.status_code != HTTPStatus.OK:
self.logger.warning(
"redundant_expert: allgather_expert_token_stats fail. "
+ f"addr {addr}, res {res.status_code} {res.json()}"
)
break
for meta_data in res.json()["data"]:
expert_token_stats += np.array(meta_data, dtype=np.int32)
success_count += 1
except Exception as e:
self.logger.error(f"redundant_expert: allgather_expert_token_stats fail. addr {addr}, error {e}")
if success_count == len(self.dp_rank_address):
self.need_rearrange_expert = True
self.model_tokens_per_expert_stats_list[:] = expert_token_stats[:]
self.logger.info("redundant_expert: allgather_expert_token_stats success")
return True
self.logger.info(
"redundant_expert: allgather_expert_token_stats fail. "
+ f"succ {success_count} total {len(self.dp_rank_address)}"
)
return False
def broadcast_expert_token_stats(self):
"""
broadcast_expert_token_stats
"""
success_count = 0
for addr in self.dp_rank_address:
try:
params = {
"user": self.api_user,
"passwd": self.api_passwd,
"action": "recv_expert_weight",
"data": self.model_tokens_per_expert_stats_list.tolist(),
}
res = requests.post(
f"http://{addr}/rearrange_experts",
json=params,
timeout=self.http_timeout,
)
if res.status_code != HTTPStatus.OK:
self.logger.warning(
"redundant_expert: broadcast_expert_token_stats fail. "
+ f"addr {addr}, res {res.status_code} {res.json()}"
)
break
success_count += 1
except Exception as e:
self.logger.error(
f"redundant_expert: broadcast_expert_token_stats request fail. addr {addr}, error {e}"
)
if success_count == len(self.dp_rank_address):
self.logger.info("redundant_expert: broadcast_expert_token_stats success")
return True
self.logger.info(
"redundant_expert: broadcast_expert_token_stats failed, "
+ f"succ {success_count} total {len(self.dp_rank_address)}"
)
return False
def allreduce_load_weight_result(self):
"""
权重加载结果
"""
if int(time.time()) - self.load_weight_begin_ts > self.load_weight_timeout:
self.logger.info(f"redundant_expert: allreduce_load_weight_result timeout {self.load_weight_timeout}s")
return True
all_success, exist_fail = self.allgather_load_weight_result()
if exist_fail:
# 如果有DP权重加载异常结束本次重排
self.logger.warning("redundant_expert: allreduce_load_weight_result exist fail, terminate this rearrange")
return True
if not all_success:
self.logger.info("redundant_expert: allreduce_load_weight_result waiting")
return False
# self.broadcast_load_weight_success()
if not exist_fail and all_success:
# prefill需要等待调度屏蔽
if (
self.fd_config.splitwise_role == "decode"
or not self.eplb_config.redundant_expert_enable_schedule_cordon
):
self.logger.info("redundant_expert: allreduce_load_weight_result success, notify infer.py")
signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32)
signal_update_weight_from_tensor_array = IPCSignal(
name="signal_update_weight_from_tensor",
array=signal_update_weight_from_tensor,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=False,
)
signal_update_weight_from_tensor_array.value[0] = 1
return True
def allgather_load_weight_result(self):
"""
allgather_load_weight_result
"""
all_success, exist_fail = False, False
success_count, fail_count = 0, 0
for addr in self.dp_rank_address:
try:
params = {
"user": self.api_user,
"passwd": self.api_passwd,
"action": "check_load_weight_result",
}
res = requests.post(
f"http://{addr}/check_redundant",
json=params,
timeout=self.http_timeout,
)
if res.status_code != HTTPStatus.OK:
self.logger.warning(
"redundant_expert: allgather_load_weight_result fail. "
+ f"addr {addr}, res {res.status_code} {res.json()}"
)
break
result_list = res.json()["data"]
self.logger.info(
f"redundant_expert: allgather_load_weight_result success. addr {addr}, result_list {result_list}"
)
for result in result_list:
if result == 1:
success_count += 1
elif result == -1:
fail_count += 1
self.logger.error(
f"redundant_expert: allgather_load_weight_result fail. addr {addr}, result {result}"
)
exist_fail = True
except Exception as e:
self.logger.error(f"redundant_expert: allgather_load_weight_result error. addr {addr}, error {e}")
if fail_count > 0:
self.logger.info(
"redundant_expert: allgather_load_weight_result not all ready, "
+ f"succ {success_count} fail {fail_count} total {len(self.dp_rank_address)}"
)
else:
self.logger.info("redundant_expert: allgather_load_weight_result all success")
all_success = True
return all_success, exist_fail

160
fastdeploy/eplb/utils.py Normal file
View File

@@ -0,0 +1,160 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
import os
import time
import numpy as np
from fastdeploy.config import FDConfig
from fastdeploy.inter_communicator import IPCSignal
class RedundantExpertWorkload:
"""Redundant Expert Workload"""
def __init__(self, redundant_expert_meta_dir="/tmp/redundant_expert_meta"):
self.update_timestamp = time.time()
self.tokens_per_expert_stats_list = None
self.ep_rank_to_expert_id_list = None
self.expert_id_to_ep_rank_array = None
self.expert_in_rank_num_list = None
self.cost_milliseconds = 0
self.meta_file_name = f"{redundant_expert_meta_dir}/rearrange-experts.json"
if not os.path.exists(redundant_expert_meta_dir):
os.makedirs(redundant_expert_meta_dir, exist_ok=True)
def __json__(self):
return self.__dict__
def dump(self):
"""Dump the object to a JSON file."""
begin = time.time()
try:
with open(self.meta_file_name, "w") as fout:
json.dump(self.__dict__, fout)
except Exception as e:
return f"redundant_expert: dump expert workload failed, {e}"
cost_time = int((time.time() - begin) * 1000 * 1000)
return f"redundant_expert: dump expert workload result in {cost_time} us"
def load(self):
"""Load the object from a JSON file."""
if not os.path.exists(self.meta_file_name):
return {}, f"redundant_expert: file {self.meta_file_name} is not exists"
try:
with open(self.meta_file_name, "r") as fin:
meta = json.load(fin)
self.__dict__.update(meta)
return self.__json__(), "ok"
except Exception as e:
return {}, f"redundant_expert: load file {self.meta_file_name} failed, {e}"
def init_eplb_signals(config: FDConfig, ipc_signal_suffix):
"""
Initialize shared memory to indicate eplb status
"""
if config.parallel_config.local_data_parallel_id == 0:
# rearrange_experts_status Record the expert's rearrangement status
rearrange_experts_array = np.zeros([1], dtype=np.int32)
_ = IPCSignal(
name="rearrange_experts_status",
array=rearrange_experts_array,
dtype=np.int32,
suffix=ipc_signal_suffix,
create=True,
)
# Record all DP rank IPs when receiving expert rearrangement requests
rearrange_experts_ips_size_array = np.zeros([1], dtype=np.int32)
_ = IPCSignal(
name="rearrange_experts_ips_size",
array=rearrange_experts_ips_size_array,
dtype=np.int32,
suffix=ipc_signal_suffix,
create=True,
)
_ = IPCSignal(
name="rearrange_experts_ips_list",
shm_size=config.eplb_config.redundant_expert_ip_shm_size,
suffix=ipc_signal_suffix,
create=True,
)
# Receive signals for updating weights
signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32)
_ = IPCSignal(
name="signal_update_weight_from_tensor",
array=signal_update_weight_from_tensor,
dtype=np.int32,
suffix=ipc_signal_suffix,
create=True,
)
# Record expert workload
experts_token_stats = np.zeros(
(config.model_config.num_hidden_layers, config.model_config.moe_num_experts),
dtype=np.int32,
)
_ = IPCSignal(
name="all_experts_token_stats",
array=experts_token_stats,
dtype=np.int32,
suffix=ipc_signal_suffix,
create=True,
)
_ = IPCSignal(
name="local_experts_token_stats",
array=experts_token_stats,
dtype=np.int32,
suffix=ipc_signal_suffix,
create=True,
)
# Receive signals for loading weights
signal_update_weight_from_disk = np.zeros([1], dtype=np.int32)
_ = IPCSignal(
name="signal_update_weight_from_disk",
array=signal_update_weight_from_disk,
dtype=np.int32,
suffix=ipc_signal_suffix,
create=True,
)
# Receive signals for clearing expert loads
clear_experts_token_stats = np.zeros([1], dtype=np.int32)
_ = IPCSignal(
name="signal_clear_experts_token_stats",
array=clear_experts_token_stats,
dtype=np.int32,
suffix=ipc_signal_suffix,
create=True,
)
result_update_weight_from_disk = np.zeros([1], dtype=np.int32)
_ = IPCSignal(
name="result_update_weight_from_disk",
array=result_update_weight_from_disk,
dtype=np.int32,
suffix=ipc_signal_suffix,
create=True,
)
if __name__ == "__main__":
print(RedundantExpertWorkload("/tmp").load())

View File

@@ -17,6 +17,7 @@
from .engine_cache_queue import EngineCacheQueue
from .engine_worker_queue import EngineWorkerQueue
from .ipc_signal import IPCSignal, shared_memory_exists
from .ipc_signal_const import RearrangeExpertStatus
from .zmq_client import ZmqIpcClient
from .zmq_server import ZmqIpcServer, ZmqTcpServer
@@ -28,4 +29,5 @@ __all__ = [
"ZmqTcpServer",
"ZmqIpcServer",
"shared_memory_exists",
"RearrangeExpertStatus",
]

View File

@@ -55,10 +55,11 @@ class IPCSignal:
def __init__(
self,
name: str,
array: np.ndarray,
dtype: np.dtype,
array: np.ndarray = None,
dtype: np.dtype = None,
suffix: int = None,
create: bool = True,
shm_size: int = None,
) -> None:
"""Initialize or connect to a shared memory block.
@@ -72,23 +73,36 @@ class IPCSignal:
Raises:
AssertionError: If create=True but memory already exists, or dtype mismatch.
"""
assert isinstance(array, np.ndarray), "Input must be a numpy array"
assert dtype == array.dtype, "Specified dtype must match array dtype"
if dtype is None or array is None:
assert shm_size is not None, "shm_size must be specified if array and dtype are None"
# Set a suffix for name to avoid name conflict while there are multiple engine launched
if suffix is not None:
name = name + f".{suffix}"
if create:
if shared_memory_exists(name):
llm_logger.warning(f"ShareMemory: {name} already exists, delete it")
SharedMemory(name=name, create=False).unlink()
self.shm = SharedMemory(create=True, size=array.nbytes, name=name)
self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf)
self.value[:] = array # Initialize with input array data
if create:
llm_logger.debug(f"creating ipc signal: {name}")
if shared_memory_exists(name):
llm_logger.warning(f"ShareMemory: {name} already exists, delete it")
SharedMemory(name=name, create=False).unlink()
self.shm = SharedMemory(create=True, size=shm_size, name=name)
else:
llm_logger.debug(f"attaching ipc signal: {name}")
self.shm = SharedMemory(name=name)
else:
self.shm = SharedMemory(name=name)
self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf)
assert isinstance(array, np.ndarray), "Input must be a numpy array"
assert dtype == array.dtype, "Specified dtype must match array dtype"
# Set a suffix for name to avoid name conflict while there are multiple engine launched
if suffix is not None:
name = name + f".{suffix}"
if create:
if shared_memory_exists(name):
llm_logger.warning(f"ShareMemory: {name} already exists, delete it")
SharedMemory(name=name, create=False).unlink()
self.shm = SharedMemory(create=True, size=array.nbytes, name=name)
self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf)
self.value[:] = array # Initialize with input array data
else:
self.shm = SharedMemory(name=name)
self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf)
def clear(self) -> None:
"""Release system resources and unlink the shared memory block."""

View File

@@ -0,0 +1,26 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from dataclasses import dataclass
from enum import Enum
@dataclass
class RearrangeExpertStatus(Enum):
FREE = 0
DOING = 1
LOAD_SUCC = 2 # load weight from disk success
DONE = 3

View File

@@ -388,7 +388,7 @@ class Ernie4_5_Model(nn.Layer):
fd_config.model_config.pretrained_config.prefix_name = "ernie"
self.fd_config = fd_config
self.redundant_table_manger = None
if fd_config.model_config.enable_redundant_experts is True:
if fd_config.eplb_config.enable_eplb is True:
self.redundant_table_manger = RedundantExpertManger(
n_routed_experts=fd_config.model_config.moe_num_experts,
num_hidden_layers=fd_config.model_config.num_hidden_layers,

View File

@@ -66,6 +66,7 @@ class RolloutModelConfig:
num_nextn_predict_layers: int = 0,
enable_attention_dp_balance: bool = False,
attention_dp_time_out_iters: int = 0,
eplb_config: str = {},
):
# Required parameters
self.model = model_name_or_path
@@ -115,6 +116,7 @@ class RolloutModelConfig:
self.num_nextn_predict_layers = num_nextn_predict_layers
self.enable_attention_dp_balance = enable_attention_dp_balance
self.attention_dp_time_out_iters = attention_dp_time_out_iters
self.eplb_config = eplb_config
def __str__(self):
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())

View File

@@ -30,6 +30,7 @@ from fastdeploy.config import (
DecodingConfig,
DeviceConfig,
EarlyStopConfig,
EPLBConfig,
ErnieArchitectures,
FDConfig,
GraphOptimizationConfig,
@@ -40,9 +41,16 @@ from fastdeploy.config import (
SpeculativeConfig,
)
from fastdeploy.engine.request import RequestType
from fastdeploy.eplb.async_expert_loader import (
MODEL_MAIN_NAME,
REARRANGE_EXPERT_MAGIC_NUM,
create_mmap,
load_tensor_from_shm_mem,
)
from fastdeploy.eplb.experts_manager import RedundantExpertManager
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.inter_communicator import IPCSignal, RearrangeExpertStatus
from fastdeploy.model_executor.layers.quantization import get_quantization_config
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger, parse_quantization
@@ -151,6 +159,7 @@ class PaddleDisWorkerProc:
self.fd_config = fd_config
self.parallel_config = fd_config.parallel_config
self.cache_config = fd_config.cache_config
self.eplb_config = fd_config.eplb_config
# TODO(gongshaotian): Use worker factory to get worker
self.worker = get_worker(fd_config=fd_config, local_rank=self.local_rank, rank=self.ranks)
@@ -249,6 +258,18 @@ class PaddleDisWorkerProc:
create=False,
)
def update_weights_from_tensor(self, mmap_infos):
"""
update_weights_from_tensor
"""
state_dicts = load_tensor_from_shm_mem(self.experts_manager.tensor_infos, mmap_infos[MODEL_MAIN_NAME], logger)
rank_expert_list, logical_to_physical_map, expert_count = self.experts_manager.get_ep_rank_to_expert_id_list()
self.worker.get_model().redundant_table_manger.update_expert_rank_table(
rank_expert_list, logical_to_physical_map, expert_count
)
# TO BE FIXED
self.worker.get_model().update_state_dict(state_dicts)
def _broadcast_model_weights_signal(self, src: int, group) -> int:
model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
@@ -258,6 +279,63 @@ class PaddleDisWorkerProc:
"""Main event loop for Paddle Distrubuted Workers.
TODO(gongshaotian): support remote calling of functions that control worker.
"""
if self.eplb_config.enable_eplb:
self.last_dump_expert_workload_ts = 0
self.experts_manager = RedundantExpertManager(
rank=self.local_rank,
ep_size=self.ranks,
fd_config=self.fd_config,
ipc_signal_suffix=self.parallel_config.engine_worker_queue_port,
)
experts_token_stats = np.zeros(
(self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.moe_num_experts),
dtype=np.int32,
)
local_experts_token_stats_array = IPCSignal(
name="local_experts_token_stats",
array=experts_token_stats,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
clear_experts_token_stats = np.zeros([1], dtype=np.int32)
signal_clear_experts_token_stats = IPCSignal(
name="signal_clear_experts_token_stats",
array=clear_experts_token_stats,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
if self.local_rank == 0:
signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32)
signal_update_weight_from_tensor_array = IPCSignal(
name="signal_update_weight_from_tensor",
array=signal_update_weight_from_tensor,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
rearrange_experts_status = np.zeros([1], dtype=np.int32)
rearrange_experts_signal = IPCSignal(
name="rearrange_experts_status",
array=rearrange_experts_status,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
mmap_infos = create_mmap(
[MODEL_MAIN_NAME],
self.local_rank,
self.ranks,
shm_uuid=self.parallel_config.engine_worker_queue_port,
eplb_config=self.eplb_config,
logger=logger,
)
# Currently, only support single node
self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
req_ids = []
@@ -267,6 +345,45 @@ class PaddleDisWorkerProc:
attention_dp_cached_prefill_tasks = []
attention_dp_wait_prefill_iters = 0
while True:
if self.eplb_config.enable_eplb:
rearrange_time = time.time()
# 获取专家负载
if local_experts_token_stats_array.value is not None and (
int(rearrange_time) - self.last_dump_expert_workload_ts
> self.eplb_config.redundant_expert_dump_workload_interval
):
self.last_dump_expert_workload_ts = int(rearrange_time)
clear_stat = False
if signal_clear_experts_token_stats.value[0] == 1:
clear_stat = True
signal_clear_experts_token_stats.value[0] = 0
(
new_stats_array,
_,
_,
_,
) = self.worker.get_model().redundant_table_manger.get_expert_tokens_stats(clear_stat=clear_stat)
local_experts_token_stats_array.value[:] = new_stats_array[:]
elif local_experts_token_stats_array.value is None:
logger.warning("redundant_expert: local_experts_token_stats not init")
# 所有DP同步更新权重
broadcast_value = 0
if self.local_rank == 0 and signal_update_weight_from_tensor_array.value[0] == 1:
logger.info("redundant_expert: update_weight_from_tensor broadcast signal")
signal_update_weight_from_tensor_array.value[0] = 0
broadcast_value = REARRANGE_EXPERT_MAGIC_NUM
data = paddle.to_tensor([broadcast_value])
paddle.distributed.broadcast(data, 0)
if data[0] == REARRANGE_EXPERT_MAGIC_NUM:
self.update_weights_from_tensor(mmap_infos)
logger.info(
f"redundant_expert: update_weight_from_tensor success, cost {(time.time() - rearrange_time)*1000}ms"
)
paddle.distributed.barrier()
if self.local_rank == 0:
rearrange_experts_signal.value[0] = RearrangeExpertStatus.DONE.value
logger.info("redundant_expert: done")
if local_rank == 0:
if self.model_weights_status.value[0] != 0:
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
@@ -706,6 +823,13 @@ def parse_args():
help="max waiting steps to sync all dp for prefill tasks available",
)
parser.add_argument(
"--eplb_config",
type=json.loads,
default=None,
help="EPLB Configuration.",
)
args = parser.parse_args()
return args
@@ -764,6 +888,8 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
early_stop_config = EarlyStopConfig(args.early_stop_config)
eplb_config = EPLBConfig(args.eplb_config)
# Note(tangbinhan): used for load_checkpoint
model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size
@@ -861,6 +987,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
moba_attention_config=moba_attention_config,
enable_attention_dp_balance=args.enable_attention_dp_balance,
attention_dp_time_out_iters=args.attention_dp_time_out_iters,
eplb_config=eplb_config,
)
update_fd_config_for_mm(fd_config)

View File

@@ -40,3 +40,5 @@ opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
partial_json_parser
einops
cuda-python==12.8
setproctitle