mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] Support eplb for ep (#4786)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* support eplb for ep * update code * update code * update code * update code * update code * update code * update code * update code * update code
@@ -124,7 +124,6 @@ class ModelConfig:
        self.max_model_len = 0
        self.dtype = ""
        self.enable_logprob = False
        self.enable_redundant_experts = False
        self.redundant_experts_num = 0
        self.seed = 0
        self.quantization = None
@@ -247,6 +246,60 @@ class ModelConfig:
        logger.info("=============================================================")


class EPLBConfig:
    """
    Configuration for the EPLB (expert-parallelism load balancer) manager.
    """

    def __init__(
        self,
        args,
    ):
        # whether EPLB is enabled
        self.enable_eplb: bool = False
        # number of redundant experts
        self.redundant_experts_num: int = 0
        # size in bytes of the shared memory holding the rearrange-request ip list
        self.redundant_expert_ip_shm_size: int = 1024
        # directory for redundant-expert metadata
        self.redundant_expert_meta_dir: str = "/tmp/redundant_expert_meta"
        # user and password for the redundant-expert HTTP APIs
        self.redundant_expert_api_user: str = ""
        self.redundant_expert_api_password: str = ""
        # expert rebalancing strategy (e.g. "balance_intra_node")
        self.redundant_expert_eplb_strategy: str = ""
        # interval for dumping expert workload
        self.redundant_expert_dump_workload_interval: int = 10
        # shared-memory size (GB) reserved for asynchronous model loading
        self.redundant_expert_async_load_model_shmem_size_gb: int = 0
        # whether to cordon (pause) scheduling during expert rearrangement
        self.redundant_expert_enable_schedule_cordon: bool = True
        # whether the model is stored as safetensors
        self.model_use_safetensors: bool = True
        # whether the model uses offline quantization
        self.model_use_offline_quant: bool = True
        # MoE quantization type
        self.moe_quant_type: str = "w4a8"
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)

    def to_json_string(self):
        """
        Convert eplb_config to a JSON string.
        """
        return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})

    def print(self):
        """
        Print all configuration information.
        """
        logger.info("EPLB Configuration Information :")
        for k, v in self.__dict__.items():
            logger.info("{:<20}:{:<6}{}".format(k, "", v))
        logger.info("=============================================================")


class ParallelConfig:
    """Configuration for the distributed execution."""

@@ -1141,6 +1194,7 @@ class FDConfig:
|
||||
reasoning_parser: str = None,
|
||||
guided_decoding_backend: Optional[str] = None,
|
||||
disable_any_whitespace: bool = False,
|
||||
eplb_config: EPLBConfig = None,
|
||||
early_stop_config: Optional[Dict[str, Any]] = None,
|
||||
tool_parser: str = None,
|
||||
test_mode=False,
|
||||
@@ -1159,6 +1213,7 @@ class FDConfig:
|
||||
self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
|
||||
self.decoding_config: DecodingConfig = decoding_config # type: ignore
|
||||
self.cache_config: CacheConfig = cache_config # type: ignore
|
||||
self.eplb_config: Optional[EPLBConfig] = eplb_config
|
||||
self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
|
||||
self.enable_attention_dp_balance = enable_attention_dp_balance
|
||||
self.attention_dp_time_out_iters = attention_dp_time_out_iters
|
||||
@@ -1386,6 +1441,7 @@ class FDConfig:
|
||||
or k == "scheduler_config"
|
||||
or k == "parallel_config"
|
||||
or k == "commit_config"
|
||||
or k == "eplb_config"
|
||||
):
|
||||
if v is not None:
|
||||
v.print()
|
||||
|
||||
@@ -26,6 +26,7 @@ from fastdeploy import envs
|
||||
from fastdeploy.config import (
|
||||
CacheConfig,
|
||||
EarlyStopConfig,
|
||||
EPLBConfig,
|
||||
FDConfig,
|
||||
GraphOptimizationConfig,
|
||||
LoadConfig,
|
||||
@@ -397,6 +398,15 @@ class EngineArgs:
|
||||
Max waiting steps to sync all dp for prefill tasks available
|
||||
"""
|
||||
|
||||
enable_eplb: bool = False
|
||||
"""
|
||||
Flag to enable eplb
|
||||
"""
|
||||
eplb_config: Optional[Dict[str, Any]] = None
|
||||
"""
|
||||
Configuration for eplb.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Post-initialization processing to set default tokenizer if not provided.
|
||||
@@ -693,6 +703,18 @@ class EngineArgs:
|
||||
default=EngineArgs.enable_expert_parallel,
|
||||
help="Enable expert parallelism.",
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--enable-eplb",
|
||||
action="store_true",
|
||||
default=EngineArgs.enable_eplb,
|
||||
help="Enable eplb.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--eplb-config",
|
||||
type=json.loads,
|
||||
default=EngineArgs.eplb_config,
|
||||
help="Config of eplb.",
|
||||
)
|
||||
|
||||
# Load group
|
||||
load_group = parser.add_argument_group("Load Configuration")
|
||||
@@ -1022,7 +1044,17 @@ class EngineArgs:
|
||||
early_stop_args[k] = v
|
||||
return EarlyStopConfig(early_stop_args)
|
||||
|
||||
def create_engine_config(self) -> FDConfig:
|
||||
def create_eplb_config(self) -> EPLBConfig:
|
||||
"""
|
||||
Create and return an EPLBConfig object based on the current settings.
|
||||
"""
|
||||
eplb_args = asdict(self)
|
||||
if self.eplb_config is not None:
|
||||
for k, v in self.eplb_config.items():
|
||||
eplb_args[k] = v
|
||||
return EPLBConfig(eplb_args)
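# Hedged illustration of the merge order inside create_eplb_config(), using plain
# dicts instead of the real EngineArgs: asdict(self) supplies the flat CLI fields
# (e.g. --enable-eplb), and keys from --eplb-config are applied afterwards, so the
# JSON value wins whenever both specify the same field.
import json

flat_args = {"enable_eplb": True, "redundant_experts_num": 0}  # stand-in for asdict(self)
cli_json = json.loads('{"redundant_experts_num": 8}')          # value of --eplb-config
eplb_args = dict(flat_args)
for k, v in cli_json.items():
    eplb_args[k] = v
assert eplb_args == {"enable_eplb": True, "redundant_experts_num": 8}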
|
||||
|
||||
def create_engine_config(self, port_availability_check: bool = True) -> FDConfig:
|
||||
"""
|
||||
Create and return an FDConfig object based on the current settings.
|
||||
"""
|
||||
@@ -1063,6 +1095,7 @@ class EngineArgs:
|
||||
graph_opt_cfg = self.create_graph_optimization_config()
|
||||
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
|
||||
moba_attention_config = self.create_moba_attention_config()
|
||||
eplb_cfg = self.create_eplb_config()
|
||||
|
||||
early_stop_cfg = self.create_early_stop_config()
|
||||
early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
|
||||
@@ -1072,9 +1105,10 @@ class EngineArgs:
|
||||
if isinstance(self.engine_worker_queue_port, str):
|
||||
self.engine_worker_queue_port = self.engine_worker_queue_port.split(",")
|
||||
|
||||
assert is_port_available(
|
||||
"0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id])
|
||||
), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
|
||||
if port_availability_check:
|
||||
assert is_port_available(
|
||||
"0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id])
|
||||
), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
|
||||
|
||||
return FDConfig(
|
||||
model_config=model_cfg,
|
||||
@@ -1084,6 +1118,7 @@ class EngineArgs:
|
||||
load_config=load_cfg,
|
||||
parallel_config=parallel_cfg,
|
||||
max_model_len=self.max_model_len,
|
||||
eplb_config=eplb_cfg,
|
||||
max_num_seqs=self.max_num_seqs,
|
||||
speculative_config=speculative_cfg,
|
||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||
|
||||
@@ -33,6 +33,7 @@ from opentelemetry import trace
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
from fastdeploy.engine.resource_manager import ResourceManager
|
||||
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
|
||||
from fastdeploy.eplb.utils import init_eplb_signals
|
||||
from fastdeploy.input.preprocess import InputPreprocessor
|
||||
from fastdeploy.inter_communicator import (
|
||||
EngineCacheQueue,
|
||||
@@ -132,6 +133,12 @@ class EngineSevice:
|
||||
)
|
||||
self._init_worker_monitor_signals()
|
||||
|
||||
if self.cfg.eplb_config.enable_eplb:
|
||||
current_suffix = int(
|
||||
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
|
||||
)
|
||||
init_eplb_signals(cfg, current_suffix)
|
||||
|
||||
self._finalizer = weakref.finalize(self, self._exit_sub_services)
|
||||
|
||||
def start(self):
|
||||
|
||||
@@ -461,6 +461,7 @@ class LLMEngine:
|
||||
f" --load_choices {self.cfg.load_config.load_choices}"
|
||||
f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
|
||||
f" --attention_dp_time_out_iters {self.cfg.attention_dp_time_out_iters}"
|
||||
f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
|
||||
f" --ips {ips}"
|
||||
)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import os
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from http import HTTPStatus
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -26,8 +27,9 @@ from fastdeploy import envs
|
||||
from fastdeploy.config import ModelConfig
|
||||
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
|
||||
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
|
||||
from fastdeploy.eplb.utils import RedundantExpertWorkload
|
||||
from fastdeploy.input.preprocess import InputPreprocessor
|
||||
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
|
||||
from fastdeploy.inter_communicator import IPCSignal, RearrangeExpertStatus, ZmqIpcClient
|
||||
from fastdeploy.metrics.work_metrics import work_process_metrics
|
||||
from fastdeploy.multimodal.registry import MultimodalRegistry
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -49,6 +51,7 @@ class EngineClient:
|
||||
port,
|
||||
limit_mm_per_prompt,
|
||||
mm_processor_kwargs,
|
||||
config,
|
||||
# enable_mm=False,
|
||||
reasoning_parser=None,
|
||||
data_parallel_size=1,
|
||||
@@ -59,6 +62,7 @@ class EngineClient:
|
||||
):
|
||||
import fastdeploy.model_executor.models # noqa: F401
|
||||
|
||||
self.config = config
|
||||
architectures = ModelConfig({"model": model_name_or_path}).architectures[0]
|
||||
self.enable_prefix_caching = enable_prefix_caching
|
||||
if MultimodalRegistry.contains_model(architectures):
|
||||
@@ -92,6 +96,9 @@ class EngineClient:
|
||||
else:
|
||||
self.is_master = False
|
||||
|
||||
if self.config.eplb_config.enable_eplb and self.config.parallel_config.expert_parallel_rank == 0:
|
||||
self.init_eplb_signals(ipc_signal_suffix=port)
|
||||
|
||||
array_size = min(max_chips_per_node, tensor_parallel_size)
|
||||
self.worker_healthy_live_recorded_time_array = np.zeros(shape=[array_size], dtype=np.int32)
|
||||
self.worker_healthy_live_signal = IPCSignal(
|
||||
@@ -115,6 +122,115 @@ class EngineClient:
|
||||
)
|
||||
self.connection_initialized = False
|
||||
|
||||
def init_eplb_signals(self, ipc_signal_suffix):
|
||||
"""
|
||||
Initialize eplb signals.
|
||||
"""
|
||||
self.signal_clear_experts_token_stats_list = []
|
||||
self.local_experts_token_stats_array_list = []
|
||||
self.expert_tokens_stats_array_list = []
|
||||
self.signal_update_weight_from_disk_array_list = []
|
||||
self.update_weight_from_disk_result_list = []
|
||||
rearrange_experts_status = np.zeros([1], dtype=np.int32)
|
||||
self.rearrange_experts_signal = IPCSignal(
|
||||
name="rearrange_experts_status",
|
||||
array=rearrange_experts_status,
|
||||
dtype=np.int32,
|
||||
suffix=ipc_signal_suffix,
|
||||
create=False,
|
||||
)
|
||||
|
||||
rearrange_experts_ips_size_array = np.zeros([1], dtype=np.int32)
|
||||
self.rearrange_experts_ips_size_signal = IPCSignal(
|
||||
name="rearrange_experts_ips_size",
|
||||
array=rearrange_experts_ips_size_array,
|
||||
dtype=np.int32,
|
||||
suffix=ipc_signal_suffix,
|
||||
create=False,
|
||||
)
|
||||
|
||||
self.shm_rearrange_experts_ips_list = IPCSignal(
|
||||
name="rearrange_experts_ips_list",
|
||||
shm_size=self.config.eplb_config.redundant_expert_ip_shm_size,
|
||||
suffix=ipc_signal_suffix,
|
||||
create=False,
|
||||
)
|
||||
|
||||
signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32)
|
||||
self.signal_update_weight_from_tensor_array = IPCSignal(
|
||||
name="signal_update_weight_from_tensor",
|
||||
array=signal_update_weight_from_tensor,
|
||||
dtype=np.int32,
|
||||
suffix=ipc_signal_suffix,
|
||||
create=False,
|
||||
)
|
||||
|
||||
if envs.FD_ENABLE_MULTI_API_SERVER:
|
||||
engine_worker_suffix = [
|
||||
self.config.parallel_config.engine_worker_queue_port[
|
||||
self.config.parallel_config.local_data_parallel_id
|
||||
]
|
||||
]
|
||||
else:
|
||||
engine_worker_suffix = self.config.parallel_config.engine_worker_queue_port
|
||||
|
||||
for suffix_port in engine_worker_suffix:
|
||||
signal_clear_experts_token_stats = np.zeros([1], dtype=np.int32)
|
||||
self.signal_clear_experts_token_stats_list.append(
|
||||
IPCSignal(
|
||||
name="signal_clear_experts_token_stats",
|
||||
array=signal_clear_experts_token_stats,
|
||||
dtype=np.int32,
|
||||
suffix=suffix_port,
|
||||
create=False,
|
||||
)
|
||||
)
|
||||
|
||||
signal_update_weight_from_disk = np.zeros([1], dtype=np.int32)
|
||||
self.signal_update_weight_from_disk_array_list.append(
|
||||
IPCSignal(
|
||||
name="signal_update_weight_from_disk",
|
||||
array=signal_update_weight_from_disk,
|
||||
dtype=np.int32,
|
||||
suffix=suffix_port,
|
||||
create=False,
|
||||
)
|
||||
)
|
||||
|
||||
result_update_weight_from_disk = np.zeros([1], dtype=np.int32)
|
||||
self.update_weight_from_disk_result_list.append(
|
||||
IPCSignal(
|
||||
name="result_update_weight_from_disk",
|
||||
array=result_update_weight_from_disk,
|
||||
dtype=np.int32,
|
||||
suffix=suffix_port,
|
||||
create=False,
|
||||
)
|
||||
)
|
||||
|
||||
experts_token_stats = np.zeros(
|
||||
(self.config.model_config.num_hidden_layers, self.config.model_config.moe_num_experts),
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.expert_tokens_stats_array_list.append(
|
||||
IPCSignal(
|
||||
name="all_experts_token_stats",
|
||||
array=experts_token_stats,
|
||||
dtype=np.int32,
|
||||
suffix=suffix_port,
|
||||
create=False,
|
||||
)
|
||||
)
|
||||
self.local_experts_token_stats_array_list.append(
|
||||
IPCSignal(
|
||||
name="local_experts_token_stats",
|
||||
array=experts_token_stats,
|
||||
dtype=np.int32,
|
||||
suffix=suffix_port,
|
||||
create=False,
|
||||
)
|
||||
)
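# Quick reference, assembled from the attaches above, of the IPC channels this client
# opens when EPLB is enabled; the first group exists once per api-server, the second
# once per engine-worker queue port. Shapes refer to the numpy arrays backing each signal.
EPLB_CLIENT_SIGNALS = {  # the name of this dict is illustrative only
    "rearrange_experts_status": (1,),              # holds a RearrangeExpertStatus value
    "rearrange_experts_ips_size": (1,),            # byte length of the ip list written below
    "rearrange_experts_ips_list": "raw shm buffer of redundant_expert_ip_shm_size bytes",
    "signal_update_weight_from_tensor": (1,),
    # one of each of the following per engine-worker queue port:
    "signal_clear_experts_token_stats": (1,),
    "signal_update_weight_from_disk": (1,),
    "result_update_weight_from_disk": (1,),
    "all_experts_token_stats": ("num_hidden_layers", "moe_num_experts"),
    "local_experts_token_stats": ("num_hidden_layers", "moe_num_experts"),
}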
|
||||
|
||||
def create_zmq_client(self, model, mode):
|
||||
"""
|
||||
Create a ZMQ client.
|
||||
@@ -394,3 +510,209 @@ class EngineClient:
|
||||
|
||||
def check_model_weight_status(self):
|
||||
return self.model_weights_status_signal.value[0] < 0
|
||||
|
||||
async def rearrange_experts(self, request_dict: dict):
|
||||
"""
|
||||
rearrange experts
|
||||
Args:
|
||||
request_dict (dict): request body
|
||||
Returns:
|
||||
tuple: response body, status code
|
||||
"""
|
||||
content, status_code = None, HTTPStatus.OK
|
||||
eplb_config = self.config.eplb_config
|
||||
|
||||
if not eplb_config.enable_eplb:
|
||||
content = {"code": 1, "msg": "redundant expert is disabled"}
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
return content, status_code
|
||||
|
||||
if (
|
||||
request_dict.get("user", "") != eplb_config.redundant_expert_api_user
|
||||
or request_dict.get("passwd", "") != eplb_config.redundant_expert_api_password
|
||||
):
|
||||
content = {"code": 1, "msg": "user or passwd is invalid"}
|
||||
status_code = HTTPStatus.UNAUTHORIZED
|
||||
return content, status_code
|
||||
|
||||
if self.config.parallel_config.expert_parallel_rank != 0:
|
||||
content = {
|
||||
"code": 1,
|
||||
"msg": f"actual rank {self.config.parallel_config.expert_parallel_rank}, expect rank 0",
|
||||
}
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
return content, status_code
|
||||
|
||||
action = request_dict.get("action", "")
|
||||
api_server_logger.info(f"redundant_expert: rearrange_experts recv request, action {action}")
|
||||
if action == "":
|
||||
# action: start rearrange experts
|
||||
# params: {'user': 'xxx', 'passwd': 'xxx', 'ips': ['10.54.99.77:8000', '10.54.99.77:8300']}
|
||||
if self.rearrange_experts_signal.value[0] != RearrangeExpertStatus.FREE.value:
|
||||
content = {
|
||||
"code": 1,
|
||||
"msg": f"rearrange is doing. actual status {self.rearrange_experts_signal.value[0]}, expect status {RearrangeExpertStatus.FREE.value}",
|
||||
}
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
if "ips" not in request_dict and content is None:
|
||||
content = {"code": 1, "msg": "ips in request is None"}
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
|
||||
if content is not None:
|
||||
return content, status_code
|
||||
|
||||
data_bytes = (";".join(request_dict["ips"])).encode("utf-8")
|
||||
data_size = len(data_bytes)
|
||||
if data_size > eplb_config.redundant_expert_ip_shm_size:
|
||||
content = {
|
||||
"code": 1,
|
||||
"msg": f"actual ips size {data_size}, max limit {eplb_config.redundant_expert_ip_shm_size}",
|
||||
}
|
||||
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
else:
|
||||
self.rearrange_experts_ips_size_signal.value[0] = data_size
|
||||
self.shm_rearrange_experts_ips_list.shm.buf[:data_size] = data_bytes
|
||||
content = {"code": 0, "msg": "ok"}
|
||||
status_code = HTTPStatus.OK
|
||||
return content, status_code
|
||||
elif action == "recv_expert_weight":
|
||||
# action: receive global expert workload, and begin update weight from disk
|
||||
# params: {'user': 'xxx', 'passwd': 'xxx', 'weight': (layers, experts)}
|
||||
if "data" not in request_dict or not isinstance(request_dict["data"], list):
|
||||
content = {"code": 1, "msg": "data not in request or data is not a list"}
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
|
||||
elif len(request_dict["data"]) != len(self.expert_tokens_stats_array_list):
|
||||
content = {
|
||||
"code": 1,
|
||||
"msg": f"actual data length {len(request_dict['data'])}, expect length {len(self.expert_tokens_stats_array_list)}",
|
||||
}
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
else:
|
||||
weight = np.array(request_dict["data"], dtype=np.int32)
|
||||
for idx in range(len(self.expert_tokens_stats_array_list)):
|
||||
self.expert_tokens_stats_array_list[idx].value[:] = weight[:]
|
||||
self.signal_update_weight_from_disk_array_list[idx].value[0] = 1
|
||||
|
||||
content = {"code": 0, "msg": "ok"}
|
||||
status_code = HTTPStatus.OK
|
||||
return content, status_code
|
||||
elif action == "update_weight_from_tensor":
|
||||
if self.config.scheduler_config.splitwise_role != "prefill" and content is None:
|
||||
content = {
|
||||
"code": 1,
|
||||
"msg": f"actual role {self.cfg.scheduler_config.splitwise_role}, expect role prefill",
|
||||
}
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
if self.rearrange_experts_signal.value[0] != RearrangeExpertStatus.LOAD_SUCC.value and content is None:
|
||||
content = {
|
||||
"code": 1,
|
||||
"msg": f"actual status {self.rearrange_experts_signal.value[0]}, expect status {RearrangeExpertStatus.LOAD_SUCC.value}",
|
||||
}
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
|
||||
if content is None:
|
||||
self.signal_update_weight_from_tensor_array.value[0] = 1
|
||||
content = {"code": 0, "msg": "ok"}
|
||||
status_code = HTTPStatus.OK
|
||||
return content, status_code
|
||||
else:
|
||||
content = {"code": 1, "msg": f"invalid action {action}"}
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
return content, status_code
|
||||
|
||||
async def get_per_expert_tokens_stats(self, request_dict: dict):
|
||||
"""
|
||||
get per expert tokens stats
|
||||
|
||||
Args:
|
||||
request_dict (dict): request body
|
||||
Returns:
|
||||
tuple: response body, status code
|
||||
"""
|
||||
content, status_code = None, HTTPStatus.OK
|
||||
eplb_config = self.config.eplb_config
|
||||
|
||||
if not eplb_config.enable_eplb:
|
||||
content = {"code": 1, "msg": "redundant expert is disabled"}
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
return content, status_code
|
||||
|
||||
if (
|
||||
request_dict.get("user", "") != eplb_config.redundant_expert_api_user
|
||||
or request_dict.get("passwd", "") != eplb_config.redundant_expert_api_password
|
||||
):
|
||||
content = {"code": 1, "msg": "user or passwd is invalid"}
|
||||
status_code = HTTPStatus.UNAUTHORIZED
|
||||
return content, status_code
|
||||
|
||||
if self.config.parallel_config.expert_parallel_rank != 0:
|
||||
content = {
|
||||
"code": 1,
|
||||
"msg": f"actual rank {self.config.parallel_config.expert_parallel_rank}, expect rank 0",
|
||||
}
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
return content, status_code
|
||||
|
||||
if "clear_stat" in request_dict and request_dict["clear_stat"]:
|
||||
for clear_experts_token_stats in self.signal_clear_experts_token_stats_list:
|
||||
clear_experts_token_stats.value[0] = 1
|
||||
|
||||
local_experts_list = []
|
||||
for local_experts_token_stats in self.local_experts_token_stats_array_list:
|
||||
local_experts_list.append(local_experts_token_stats.value.tolist())
|
||||
content = {"code": 0, "msg": "ok", "data": local_experts_list}
|
||||
status_code = HTTPStatus.OK
|
||||
return content, status_code
|
||||
|
||||
async def check_redundant(self, request_dict: dict):
|
||||
"""
|
||||
check redundant
|
||||
Args:
|
||||
request_dict (dict): request body
|
||||
Returns:
|
||||
tuple: response body, status code
|
||||
"""
|
||||
content, status_code = None, HTTPStatus.OK
|
||||
eplb_config = self.config.eplb_config
|
||||
|
||||
if not eplb_config.enable_eplb:
|
||||
content = {"code": 1, "msg": "redundant expert is disabled"}
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
return content, status_code
|
||||
|
||||
if (
|
||||
request_dict.get("user", "") != eplb_config.redundant_expert_api_user
|
||||
or request_dict.get("passwd", "") != eplb_config.redundant_expert_api_password
|
||||
):
|
||||
content = {"code": 1, "msg": "user or passwd is invalid"}
|
||||
status_code = HTTPStatus.UNAUTHORIZED
|
||||
return content, status_code
|
||||
|
||||
if self.config.parallel_config.expert_parallel_rank != 0:
|
||||
content = {
|
||||
"code": 1,
|
||||
"msg": f"actual rank {self.config.parallel_config.expert_parallel_rank}, expect rank 0",
|
||||
}
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
return content, status_code
|
||||
|
||||
action = request_dict.get("action", "")
|
||||
if action == "":
|
||||
status = "unknown"
|
||||
try:
|
||||
status = RearrangeExpertStatus(self.rearrange_experts_signal.value[0]).name
|
||||
except:
|
||||
pass
|
||||
content = {"code": 0, "msg": "ok", "status": status}
|
||||
get_workloads = False if "check_get_workloads" not in request_dict else request_dict["check_get_workloads"]
|
||||
if get_workloads:
|
||||
content["data"], content["msg"] = RedundantExpertWorkload(eplb_config.redundant_expert_meta_dir).load()
|
||||
status_code = HTTPStatus.OK
|
||||
elif action == "check_load_weight_result":
|
||||
update_weight_from_disk_list = []
|
||||
for update_weight_result in self.update_weight_from_disk_result_list:
|
||||
update_weight_from_disk_list.append(update_weight_result.value[0].tolist())
|
||||
content = {"code": 0, "msg": "ok", "data": update_weight_from_disk_list}
|
||||
status_code = HTTPStatus.OK
|
||||
return content, status_code
|
||||
|
||||
@@ -155,6 +155,8 @@ async def lifespan(app: FastAPI):
|
||||
verification = False
|
||||
model_paths = [ModelPath(name=served_model_names, model_path=args.model, verification=verification)]
|
||||
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
config = engine_args.create_engine_config(port_availability_check=False)
|
||||
engine_client = EngineClient(
|
||||
model_name_or_path=args.model,
|
||||
tokenizer=args.tokenizer,
|
||||
@@ -171,6 +173,7 @@ async def lifespan(app: FastAPI):
|
||||
workers=args.workers,
|
||||
tool_parser=args.tool_call_parser,
|
||||
enable_prefix_caching=args.enable_prefix_caching,
|
||||
config=config,
|
||||
)
|
||||
await engine_client.connection_manager.initialize()
|
||||
app.state.dynamic_load_weight = args.dynamic_load_weight
|
||||
@@ -408,6 +411,36 @@ def clear_load_weight(request: Request) -> Response:
|
||||
return Response(content="Dynamic Load Weight Disabled.", status_code=404)
|
||||
|
||||
|
||||
@app.post("/rearrange_experts")
|
||||
async def rearrange_experts(request: Request):
|
||||
"""
|
||||
rearrange experts
|
||||
"""
|
||||
request_dict = await request.json()
|
||||
content, status_code = await app.state.engine_client.rearrange_experts(request_dict=request_dict)
|
||||
return JSONResponse(content, status_code=status_code)
|
||||
|
||||
|
||||
@app.post("/get_per_expert_tokens_stats")
|
||||
async def get_per_expert_tokens_stats(request: Request):
|
||||
"""
|
||||
get per expert tokens stats
|
||||
"""
|
||||
request_dict = await request.json()
|
||||
content, status_code = await app.state.engine_client.get_per_expert_tokens_stats(request_dict=request_dict)
|
||||
return JSONResponse(content, status_code=status_code)
|
||||
|
||||
|
||||
@app.post("/check_redundant")
|
||||
async def check_redundant(request: Request):
|
||||
"""
|
||||
check redundant
|
||||
"""
|
||||
request_dict = await request.json()
|
||||
content, status_code = await app.state.engine_client.check_redundant(request_dict=request_dict)
|
||||
return JSONResponse(content, status_code=status_code)
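# Hedged client-side sketch of exercising the three endpoints above; the host, port
# and credentials are placeholders and must match redundant_expert_api_user/password
# in the server's --eplb-config.
import requests

BASE = "http://127.0.0.1:8000"                # assumed api-server address
AUTH = {"user": "admin", "passwd": "secret"}  # placeholder credentials

# 1. start a rearrangement across two instances (empty "action" means "start")
resp = requests.post(
    f"{BASE}/rearrange_experts",
    json={**AUTH, "ips": ["10.54.99.77:8000", "10.54.99.77:8300"]},
)
print(resp.status_code, resp.json())

# 2. read (and optionally clear) the per-expert token statistics
resp = requests.post(f"{BASE}/get_per_expert_tokens_stats", json={**AUTH, "clear_stat": False})

# 3. poll the rearrange status, or the per-worker load-weight results
resp = requests.post(f"{BASE}/check_redundant", json={**AUTH, "action": ""})
resp = requests.post(f"{BASE}/check_redundant", json={**AUTH, "action": "check_load_weight_result"})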
|
||||
|
||||
|
||||
def launch_api_server() -> None:
|
||||
"""
|
||||
Start the HTTP server
|
||||
|
||||
15
fastdeploy/eplb/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
427
fastdeploy/eplb/async_expert_loader.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.config import EPLBConfig
|
||||
|
||||
REARRANGE_EXPERT_MAGIC_NUM = 147183647
|
||||
REARRANGE_ORIGINATOR_EP_RANK = 0
|
||||
CHECK_TIME_INTERNAL = 3
|
||||
HTTP_RETRY_NUM = 5
|
||||
CHECK_TIMEOUT = 120
|
||||
|
||||
libc = ctypes.CDLL(None)
|
||||
|
||||
libc.mmap.argtypes = [
|
||||
ctypes.c_void_p, # void *addr
|
||||
ctypes.c_size_t, # size_t length
|
||||
ctypes.c_int, # int prot
|
||||
ctypes.c_int, # int flags
|
||||
ctypes.c_int, # int fd
|
||||
ctypes.c_size_t, # off_t offset
|
||||
]
|
||||
libc.mmap.restype = ctypes.c_void_p
|
||||
libc.munmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
|
||||
libc.munmap.restype = ctypes.c_int
|
||||
|
||||
PROT_READ = 0x1
|
||||
PROT_WRITE = 0x2
|
||||
MAP_SHARED = 0x01
|
||||
MAP_ANONYMOUS = 0x20
|
||||
MAP_FAILED = -1
|
||||
|
||||
G = 1024**3
|
||||
TOTAL_MODEL_SIZE = 350
|
||||
MAIN_MODEL_REDUNDANT_SHM_SIZE = 5
|
||||
|
||||
MODEL_MAIN_NAME = "eplb_main"
|
||||
|
||||
|
||||
def create_mmap(model_name: List, ep_rank: int, ep_size: int, shm_uuid: str, eplb_config: EPLBConfig, logger=None):
|
||||
"""create_mmap"""
|
||||
flags = MAP_SHARED
|
||||
prot = PROT_READ | PROT_WRITE
|
||||
|
||||
main_size = 0
|
||||
if eplb_config.redundant_expert_async_load_model_shmem_size_gb == 0:
|
||||
main_size = TOTAL_MODEL_SIZE // ep_size
|
||||
else:
|
||||
main_size = eplb_config.redundant_expert_async_load_model_shmem_size_gb
|
||||
main_size = main_size * G
|
||||
|
||||
mmap_infos = {}
|
||||
|
||||
from cuda import cudart
|
||||
|
||||
for name in model_name:
|
||||
expert_weight_file = f"/dev/shm/{name}_rank_{ep_rank}_expert_weight_{shm_uuid}"
|
||||
shm_size = main_size
|
||||
|
||||
if not os.path.isfile(expert_weight_file):
|
||||
open(expert_weight_file, "wb").close()
|
||||
shm_fd = os.open(expert_weight_file, os.O_RDWR)
|
||||
os.ftruncate(shm_fd, shm_size)
|
||||
if logger is not None:
|
||||
logger.info(f"redundant_expert: create_mmap file {expert_weight_file}, fd {shm_fd}, size {shm_size}")
|
||||
|
||||
shm_ptr = libc.mmap(0, ctypes.c_size_t(shm_size), prot, flags, shm_fd, 0)
|
||||
if shm_ptr == MAP_FAILED:
|
||||
raise OSError(f"redundant_expert: mmap {expert_weight_file} failed: {ctypes.get_errno()}")
|
||||
|
||||
shm_ptr = ctypes.cast(shm_ptr, ctypes.POINTER(ctypes.c_int8))
|
||||
addr = ctypes.addressof(shm_ptr.contents)
|
||||
|
||||
# Register memory with CUDA
|
||||
(ret,) = cudart.cudaHostRegister(addr, shm_size, 0)
|
||||
if ret != cudart.cudaError_t.cudaSuccess:
|
||||
raise RuntimeError(
|
||||
f"cudaHostRegister failed: {cudart.cudaGetErrorString(ret)}, "
|
||||
f" address {hex(addr)} size {shm_size}, ret: {ret}"
|
||||
)
|
||||
|
||||
mmap_infos[name] = shm_ptr
|
||||
|
||||
return mmap_infos
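# Hedged usage sketch for create_mmap(); the rank, uuid and sizes are placeholders.
# It needs a CUDA-capable environment because the mapped pages are pinned through
# cudaHostRegister before being handed to the loader.
from fastdeploy.config import EPLBConfig

eplb_cfg = EPLBConfig({"redundant_expert_async_load_model_shmem_size_gb": 8})
mmap_infos = create_mmap(
    model_name=[MODEL_MAIN_NAME],  # backs /dev/shm/eplb_main_rank_<rank>_expert_weight_<uuid>
    ep_rank=0,
    ep_size=32,
    shm_uuid="demo",
    eplb_config=eplb_cfg,
)
shm_ptr = mmap_infos[MODEL_MAIN_NAME]  # ctypes pointer, later consumed by load_tensor_from_shm_mem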
|
||||
|
||||
|
||||
def save_tensor_to_shm_mem(cached_weights, file_path, logger=None):
|
||||
"""save_tensor_to_shm_mem"""
|
||||
tensor_infos = []
|
||||
offset = 0
|
||||
if not os.path.exists(file_path):
|
||||
raise OSError("File is not exist")
|
||||
|
||||
shm_size = os.path.getsize(file_path)
|
||||
|
||||
for name, w in cached_weights:
|
||||
size = w.numel().item() * w.element_size()
|
||||
# logger.info(f"redundant_expert: save tensor to {name} offset: {offset} size: {size}")
|
||||
w_ptr = ctypes.string_at(w.data_ptr(), size)
|
||||
with open(file_path, "r+b") as file:
|
||||
file.seek(offset)
|
||||
if offset + size > shm_size:
|
||||
raise IOError(
|
||||
f"redundant_expert: Exceeded {file_path} file's size. "
|
||||
+ "Should set a bigger value using env variable."
|
||||
)
|
||||
n = file.write(w_ptr)
|
||||
assert n == size
|
||||
tensor_infos.append((name, offset, size, w.shape, w.dtype))
|
||||
|
||||
offset += size
|
||||
|
||||
sz = offset / 1024 / 1024 / 1024
|
||||
if logger is not None:
|
||||
logger.info(f"redundant_expert: save_tensor_to_shm_mem success. file {file_path} size {sz}G")
|
||||
|
||||
return tensor_infos
|
||||
|
||||
|
||||
def load_tensor_from_shm_mem(tensor_infos, shm_ptr, logger=None):
|
||||
"""load_tensor_from_shm_mem"""
|
||||
# weights_dict = {}
|
||||
weights_dict = []
|
||||
for name, offset, size, shape, dtype in tensor_infos:
|
||||
# Compute the address of this tensor inside the shared-memory block
|
||||
w_addr = ctypes.cast(shm_ptr, ctypes.c_void_p).value + offset
|
||||
w_ptr = ctypes.cast(w_addr, ctypes.POINTER(ctypes.c_byte))
|
||||
# Read the raw bytes first, then reinterpret them (via a view) as the target dtype
|
||||
np_array = np.ctypeslib.as_array(w_ptr, shape=(size,))
|
||||
|
||||
if dtype == paddle.float32:
|
||||
tmp = np_array.view(np.float32)
|
||||
tensor = paddle.Tensor(tmp, dtype=paddle.float32, place=paddle.CPUPlace(), zero_copy=True)
|
||||
elif dtype == paddle.uint8:
|
||||
tmp = np_array.view(np.uint8)
|
||||
tensor = paddle.Tensor(tmp, dtype=paddle.uint8, place=paddle.CPUPlace(), zero_copy=True)
|
||||
elif dtype == paddle.int8:
|
||||
tmp = np_array.view(np.int8)
|
||||
tensor = paddle.Tensor(tmp, dtype=paddle.int8, place=paddle.CPUPlace(), zero_copy=True)
|
||||
elif dtype == paddle.bfloat16:
|
||||
# NumPy has no bfloat16, so read the raw data as uint16 and let Paddle interpret it as bfloat16
|
||||
tmp = np_array.view(np.uint16)
|
||||
tensor = paddle.Tensor(tmp, dtype=paddle.bfloat16, place=paddle.CPUPlace(), zero_copy=True)
|
||||
else:
|
||||
raise TypeError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
assert w_addr == tensor.data_ptr()
|
||||
# weights_dict[name] = tensor.view(shape)
|
||||
weights_dict.append((name, tensor.view(shape)))
|
||||
|
||||
if logger is not None:
|
||||
logger.info("redundant_expert: load_tensor_from_shm_mem succ")
|
||||
return weights_dict
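# Minimal round trip through the two helpers above; the /dev/shm path, file size and
# tensor are illustrative. save_tensor_to_shm_mem() writes into an existing, pre-sized
# file and returns (name, offset, size, shape, dtype) records that the reader side
# feeds back into load_tensor_from_shm_mem() together with an mmap'ed pointer.
import os

import paddle

path = "/dev/shm/eplb_demo_expert_weight"
with open(path, "wb") as f:
    f.truncate(64 * 1024 * 1024)  # pre-size the backing file (64 MiB)

paddle.set_device("cpu")
cached = [("layer0.up_gate_proj.weight", paddle.randn([128, 256], dtype="float32"))]
infos = save_tensor_to_shm_mem(cached, path)
# The real flow mmaps the same file via create_mmap() and passes that pointer here:
# restored = load_tensor_from_shm_mem(infos, shm_ptr)
os.remove(path)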
|
||||
|
||||
|
||||
class AsyncEPLoader(object):
|
||||
"""Aynsc Expert loader"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_dir,
|
||||
eplb_config,
|
||||
rank=8,
|
||||
expert_per_rank=8,
|
||||
moe_layer_start_index=3,
|
||||
moe_quant_type="",
|
||||
logger=None,
|
||||
):
|
||||
"""
|
||||
__init__
|
||||
"""
|
||||
self.model_path = model_dir
|
||||
self.eplb_config = eplb_config
|
||||
|
||||
self.expert_per_rank = expert_per_rank
|
||||
self.moe_layer_start_index = moe_layer_start_index
|
||||
self.ep_rank = rank
|
||||
self.moe_quant_type = moe_quant_type
|
||||
|
||||
self.old_model_ep_rank_to_expert_id_list = None
|
||||
self.new_model_ep_rank_to_expert_id_list = None
|
||||
|
||||
self.cached_weights = []
|
||||
self.state_dicts = {}  # needed by load_weight_bf16_from_disk()
|
||||
self.moe_file_names = []
|
||||
|
||||
self.logger = logger
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
reset
|
||||
"""
|
||||
self.old_model_ep_rank_to_expert_id_list = None
|
||||
self.new_model_ep_rank_to_expert_id_list = None
|
||||
self.cached_weights = []
|
||||
self.moe_file_names = []
|
||||
|
||||
def load_experts_weight_from_disk(self):
|
||||
"""
|
||||
Returns (success, message); a fatal error here means that every rank needs to restart.
|
||||
"""
|
||||
ep_rank = self.ep_rank
|
||||
start_idx = ep_rank * self.expert_per_rank
|
||||
end_idx = start_idx + self.expert_per_rank
|
||||
try:
|
||||
old_expert_ids_all = self.old_model_ep_rank_to_expert_id_list[:, start_idx:end_idx]
|
||||
new_expert_ids_all = self.new_model_ep_rank_to_expert_id_list[:, start_idx:end_idx]
|
||||
need_to_reload = list()
|
||||
for layer_id in range(len(old_expert_ids_all)):
|
||||
if layer_id < self.moe_layer_start_index:
|
||||
continue
|
||||
new_expert_ids = new_expert_ids_all[layer_id]
|
||||
old_expert_ids = old_expert_ids_all[layer_id]
|
||||
if len(new_expert_ids) != len(old_expert_ids):
|
||||
message = f"redundant_expert: new_expert_ids length not equal to old_expert_ids \
|
||||
length layer_id: {layer_id}"
|
||||
# this is very dangerous and unexpected, should be fixed
|
||||
return False, message
|
||||
# TODO: load on demand and filter out duplicated experts
|
||||
self.logger.info(
|
||||
f"redundant_expert: rank {ep_rank} layer {layer_id} old_experts {old_expert_ids}"
|
||||
+ f" new_experts {new_expert_ids}"
|
||||
)
|
||||
need_to_reload.extend([(layer_id, expert_id) for expert_id in new_expert_ids])
|
||||
|
||||
succ = True
|
||||
message = ""
|
||||
if len(need_to_reload) > 0:
|
||||
if self.eplb_config.model_use_safetensors:
|
||||
succ, message = self.load_safetensor_fp8_from_disk(need_to_reload)
|
||||
else:
|
||||
succ, message = self.load_weight_bf16_from_disk(need_to_reload)
|
||||
if not succ:
|
||||
self.logger.info(
|
||||
f"redundant_expert: load_experts_weight_from_disk fail. rank {ep_rank}, error: {message}"
|
||||
)
|
||||
new_message = f"redundant_expert: load_experts_weight_from_disk fail. rank {ep_rank}, error: {message}"
|
||||
return False, new_message
|
||||
self.logger.info(f"redundant_expert: load_experts_weight_from_disk success. rank {ep_rank}")
|
||||
return True, "redundant_expert: load_experts_weight_from_disk success"
|
||||
except Exception as e:
|
||||
message = f"redundant_expert: Failed to load_experts_weight_from_disk ep_rank {ep_rank} excep: {e}"
|
||||
error_message = traceback.format_exc()
|
||||
self.logger.error(f"redundant_expert: message: {message} traceback: {error_message}")
|
||||
return False, message
|
||||
|
||||
def load_weight_bf16_from_disk(self, need_to_reload: List[Tuple[int, int]]):
|
||||
"""load_weight_bf16_from_disk"""
|
||||
try:
|
||||
ckpt_up_gate_proj_name = "up_gate_proj"
|
||||
ckpt_down_proj_name = "down_proj"
|
||||
for layer_id, expert_id in need_to_reload:
|
||||
for weight_name in [ckpt_up_gate_proj_name, ckpt_down_proj_name]:
|
||||
ckpt_file_name = f"ernie.layers.{layer_id}.mlp.experts.{expert_id}.{weight_name}.weight"
|
||||
if ckpt_file_name not in self.moe_file_names:
|
||||
self.logger.info(f"record redundant_expert: {ckpt_file_name}")
|
||||
self.moe_file_names.append(ckpt_file_name)
|
||||
|
||||
last_device = paddle.device.get_device()
|
||||
paddle.set_device("cpu")
|
||||
|
||||
for file_name in self.moe_file_names:
|
||||
# Skip weight files that do not exist on disk
|
||||
if not os.path.exists(self.model_path + "/merged_tp1_state_split/" + file_name):
|
||||
# self.logger.info(f"redundant_expert: {file_name} not exist.")
|
||||
continue
|
||||
# self.logger.info(f"redundant_expert: Loading expert weights: {file_name}.")
|
||||
self.state_dicts[file_name] = paddle.load(self.model_path + "/merged_tp1_state_split/" + file_name)
|
||||
|
||||
paddle.set_device(last_device)
|
||||
self.logger.info("redundant_expert: Loading expert weights end.")
|
||||
return True, "redundant_expert: Succeeded to loading expert weights."
|
||||
except Exception as e:
|
||||
message = f"redundant_expert: Failed to get weights iterator: {e}."
|
||||
return False, message
|
||||
|
||||
def load_safetensor_fp8_from_disk(self, need_to_reload: List[Tuple[int, int]]):
|
||||
"""load_safetensor_fp8_from_disk"""
|
||||
"""
|
||||
ernie.layers.52.mlp.experts.58.up_gate_proj.quant_weight
|
||||
ernie.layers.52.mlp.experts.58.up_gate_proj.weight_scale
|
||||
ernie.layers.52.mlp.experts.58.down_proj.quant_weight
|
||||
ernie.layers.52.mlp.experts.58.down_proj.weight_scale
|
||||
"""
|
||||
up_gate_down = ["up_gate_proj", "down_proj"]
|
||||
quant_weight_scale = ["quant_weight", "weight_scale"]
|
||||
if self.moe_quant_type == "w4a8":
|
||||
quant_weight_scale = ["quant_weight"]
|
||||
ckpt_name = [
|
||||
(f"ernie.layers.{layer_id}.mlp.experts.{expert_id}.{proj_name}.{quant_name}")
|
||||
for layer_id, expert_id in need_to_reload
|
||||
for proj_name in up_gate_down
|
||||
for quant_name in quant_weight_scale
|
||||
]
|
||||
ckpt_name_to_safetensor_file = load_ep_checkpoint(self.model_path)
|
||||
hf_weights_files = list(set(ckpt_name_to_safetensor_file.values()))
|
||||
state_dicts = {}
|
||||
|
||||
last_device = paddle.device.get_device()
|
||||
paddle.set_device("cpu")
|
||||
|
||||
from safetensors import safe_open
|
||||
|
||||
for st_file in hf_weights_files:
|
||||
with safe_open(st_file, framework="np", device="cpu") as f:
|
||||
for name in f.keys():
|
||||
if name in ckpt_name:
|
||||
weight = f.get_tensor(name)
|
||||
state_dicts[name] = paddle.Tensor(weight, zero_copy=True)
|
||||
weights_list = []
|
||||
for name in ckpt_name:
|
||||
weights_list.append((name, state_dicts[name]))
|
||||
self.cached_weights = weights_list
|
||||
|
||||
paddle.set_device(last_device)
|
||||
return True, "load_expert_weight_from_disk_safetensor success"
|
||||
|
||||
|
||||
def load_ep_checkpoint(model_path):
|
||||
"""
|
||||
load ep checkpoint
|
||||
"""
|
||||
file_path = os.path.join(model_path, "model.safetensors.index.json")
|
||||
if not os.path.exists(file_path):
|
||||
return {}
|
||||
import json
|
||||
|
||||
with open(file_path, "r") as f:
|
||||
weight_map = json.load(f)["weight_map"]
|
||||
state_dict = {k: os.path.join(model_path, v) for k, v in weight_map.items()}
|
||||
return state_dict
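# Illustrative call: load_ep_checkpoint() simply expands model.safetensors.index.json
# into {weight_key: absolute path of its safetensors shard}. The directory and shard
# file name below are placeholders.
mapping = load_ep_checkpoint("/path/to/model")
# mapping["ernie.layers.52.mlp.experts.58.up_gate_proj.quant_weight"]
#   -> "/path/to/model/model-00042-of-00100.safetensors"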
|
||||
|
||||
|
||||
def load_model_weights_process(
|
||||
rank: int,
|
||||
model_dir: str,
|
||||
expert_per_rank: int,
|
||||
moe_layer_start_index: int,
|
||||
moe_quant_type: str,
|
||||
shm_uuid: str,
|
||||
eplb_config: EPLBConfig,
|
||||
data_conn,
|
||||
mg_conn,
|
||||
):
|
||||
"""
|
||||
load_model_weights_process
|
||||
"""
|
||||
import faulthandler
|
||||
|
||||
from setproctitle import setproctitle
|
||||
|
||||
setproctitle(f"eplb::async_load_model_{rank}")
|
||||
faulthandler.enable()
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("eplb_async_loader", "eplb_{0}.log".format(rank))
|
||||
logger.info("redundant_expert: load_model_weights_process start")
|
||||
|
||||
paddle.set_device("cpu")
|
||||
ep_loader = AsyncEPLoader(
|
||||
model_dir=model_dir,
|
||||
rank=rank,
|
||||
expert_per_rank=expert_per_rank,
|
||||
moe_layer_start_index=moe_layer_start_index,
|
||||
moe_quant_type=moe_quant_type,
|
||||
logger=logger,
|
||||
eplb_config=eplb_config,
|
||||
)
|
||||
|
||||
while True:
|
||||
ep_loader.reset()
|
||||
data = mg_conn.recv()
|
||||
|
||||
result = True
|
||||
weight_infos = []
|
||||
try:
|
||||
ep_loader.old_model_ep_rank_to_expert_id_list = data["old_model_ep_rank_to_expert_id_list"]
|
||||
ep_loader.new_model_ep_rank_to_expert_id_list = data["new_model_ep_rank_to_expert_id_list"]
|
||||
|
||||
begin_time_disk = int(time.time())
|
||||
success, message = ep_loader.load_experts_weight_from_disk()
|
||||
begin_time_shm = int(time.time())
|
||||
logger.info(
|
||||
"redundant_expert: async load load_weight_from_disk, "
|
||||
+ f"succ {success}, cost {begin_time_shm-begin_time_disk}s"
|
||||
)
|
||||
if success:
|
||||
model_name = MODEL_MAIN_NAME
|
||||
file_path = f"/dev/shm/{model_name}_rank_{rank}_expert_weight_{shm_uuid}"
|
||||
weight_infos = save_tensor_to_shm_mem(ep_loader.cached_weights, file_path, logger)
|
||||
logger.info(
|
||||
"redundant_expert: async load save_tensor_to_shm_mem, "
|
||||
+ f"tensor nums {len(weight_infos)}, cost {int(time.time()-begin_time_shm)}s"
|
||||
)
|
||||
else:
|
||||
logger.error(f"redundant_expert: async load load_weight_from_disk failed, error {message}")
|
||||
result = False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"redundant_expert: async load weights failed, rank {rank} error {e}")
|
||||
result = False
|
||||
weight_infos = []
|
||||
finally:
|
||||
request_data = {"result": result, "weights": weight_infos}
|
||||
data_conn.send(request_data)
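# Hedged sketch of driving the loader process from a parent (experts_manager.py imports
# Pipe/Process for this); the model path, table shapes and uuid below are placeholders.
from multiprocessing import Pipe, Process

import numpy as np

from fastdeploy.config import EPLBConfig

eplb_cfg = EPLBConfig({"model_use_safetensors": True})
data_parent, data_child = Pipe()  # loader -> parent: {"result": bool, "weights": [...]}
mg_parent, mg_child = Pipe()      # parent -> loader: old/new expert placement tables

p = Process(
    target=load_model_weights_process,
    args=(0, "/path/to/model", 8, 3, "w4a8", "demo-uuid", eplb_cfg, data_child, mg_child),
)
p.start()
mg_parent.send(
    {
        "old_model_ep_rank_to_expert_id_list": np.zeros((54, 64), dtype=np.int32),
        "new_model_ep_rank_to_expert_id_list": np.zeros((54, 64), dtype=np.int32),
    }
)
reply = data_parent.recv()  # {"result": True/False, "weights": tensor_infos}
p.terminate()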
|
||||
291
fastdeploy/eplb/eplb.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def balanced_packing(weight: np.ndarray, num_packs: int) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs
|
||||
are as balanced as possible.
|
||||
Parameters:
|
||||
weight: [X, n], the weight of each item
|
||||
num_packs: number of packs
|
||||
Returns:
|
||||
pack_index: [X, n], the pack index of each item
|
||||
rank_in_pack: [X, n], the rank of the item in the pack
|
||||
"""
|
||||
num_layers, num_groups = weight.shape
|
||||
assert num_groups % num_packs == 0
|
||||
groups_per_pack = num_groups // num_packs
|
||||
|
||||
if groups_per_pack == 1:
|
||||
pack_index = np.arange(weight.shape[-1], dtype=np.int32).reshape(1, -1).repeat(num_layers, axis=0)
|
||||
rank_in_pack = np.zeros_like(weight, dtype=np.int32)
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
indices = np.argsort(-weight.astype(np.float32), axis=-1)
|
||||
pack_index = np.full_like(weight, fill_value=-1, dtype=np.int32)
|
||||
rank_in_pack = np.full_like(pack_index, fill_value=-1)
|
||||
for i in range(num_layers):
|
||||
pack_weights = [0] * num_packs
|
||||
pack_items = [0] * num_packs
|
||||
for group in indices[i]:
|
||||
pack = min(
|
||||
(i for i in range(num_packs) if pack_items[i] < groups_per_pack),
|
||||
key=pack_weights.__getitem__,
|
||||
)
|
||||
assert pack_items[pack] < groups_per_pack
|
||||
pack_index[i, group] = pack
|
||||
rank_in_pack[i, group] = pack_items[pack]
|
||||
pack_weights[pack] += weight[i, group]
|
||||
pack_items[pack] += 1
|
||||
return pack_index, rank_in_pack
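# Tiny worked example of balanced_packing(): one layer with four groups packed onto
# two packs; each pack must take exactly two groups, and the greedy pass pairs the
# heaviest group with the lightest one.
import numpy as np

weight = np.array([[9, 1, 5, 5]], dtype=np.int32)
pack_index, rank_in_pack = balanced_packing(weight, num_packs=2)
# pack_index   -> [[0, 0, 1, 1]]  (groups 0 and 1 on pack 0, groups 2 and 3 on pack 1)
# rank_in_pack -> [[0, 1, 0, 1]]  (per-pack load is 10 vs 10)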
|
||||
|
||||
|
||||
def replicate_experts(weight: np.ndarray, num_phy: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.
|
||||
Parameters:
|
||||
weight: [X, num_log]
|
||||
num_phy: total number of experts after replication
|
||||
Returns:
|
||||
phy2log: [X, num_phy], logical expert id of each physical expert
|
||||
rank: [X, num_phy], the replica rank
|
||||
logcnt: [X, num_log], number of replicas for each logical expert
|
||||
"""
|
||||
n, num_log = weight.shape
|
||||
num_redundant = num_phy - num_log
|
||||
assert num_redundant >= 0
|
||||
phy2log = np.arange(num_phy, dtype=np.int32).reshape(1, -1).repeat(n, axis=0)
|
||||
rank = np.zeros((n, num_phy), dtype=np.int32)
|
||||
logcnt = np.ones((n, num_log), dtype=np.int32)
|
||||
arangen = np.arange(n, dtype=np.int32)
|
||||
for i in range(num_log, num_phy):
|
||||
redundant_indices = np.argmax(weight / logcnt, axis=-1)
|
||||
phy2log[:, i] = redundant_indices
|
||||
rank[:, i] = logcnt[arangen, redundant_indices]
|
||||
logcnt[arangen, redundant_indices] += 1
|
||||
return phy2log, rank, logcnt
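# Small example of replicate_experts(): four logical experts spread over six physical
# slots; the two extra replicas both go to expert 0 because it dominates the load.
import numpy as np

weight = np.array([[8, 4, 2, 2]], dtype=np.int32)
phy2log, rank, logcnt = replicate_experts(weight, num_phy=6)
# phy2log -> [[0, 1, 2, 3, 0, 0]]   (slots 4 and 5 are extra copies of expert 0)
# rank    -> [[0, 0, 0, 0, 1, 2]]
# logcnt  -> [[3, 1, 1, 1]]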
|
||||
|
||||
|
||||
def rebalance_experts_intra_node(
|
||||
weight: np.ndarray,
|
||||
num_physical_experts: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_gpus: int,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
weight: [num_moe_layers, num_logical_experts]
|
||||
num_physical_experts: number of physical experts after replication
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
|
||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||
Returns:
|
||||
physical_to_logical_map: [num_moe_layers, num_physical_experts]
|
||||
logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
|
||||
logical_count: [num_moe_layers, num_logical_experts]
|
||||
"""
|
||||
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
assert num_logical_experts % num_groups == 0
|
||||
num_redundant_experts = num_physical_experts - num_logical_experts
|
||||
assert num_redundant_experts >= 0
|
||||
|
||||
assert num_gpus % num_nodes == 0
|
||||
num_gpus_per_node = num_gpus // num_nodes
|
||||
|
||||
assert num_physical_experts % num_gpus == 0
|
||||
num_physical_experts_per_gpu = num_physical_experts // num_gpus
|
||||
assert num_physical_experts % num_nodes == 0
|
||||
num_physical_experts_per_node = num_physical_experts // num_nodes
|
||||
|
||||
assert num_logical_experts % num_physical_experts_per_node == 0
|
||||
# num_logical_nodes = num_logical_experts // num_physical_experts_per_node
|
||||
assert num_redundant_experts % num_physical_experts_per_node == 0
|
||||
# num_redundant_nodes = num_redundant_experts // num_physical_experts_per_node
|
||||
|
||||
def inverse(perm: np.ndarray) -> np.ndarray:
|
||||
inv = np.empty_like(perm)
|
||||
inv[np.arange(perm.shape[0])[:, None], perm] = np.arange(perm.shape[1], dtype=np.int32).reshape(1, -1)
|
||||
return inv
|
||||
|
||||
# Step 1: generate redundant experts by weight.
|
||||
# shape of tmp2log, tmprank is [num_layers, num_physical_experts]
|
||||
# shape of logcnt is [num_layers, num_logical_experts]
|
||||
tmp2log, tmprank, logcnt = replicate_experts(weight, num_physical_experts)
|
||||
|
||||
# Step 2: compute num_tokens of physical experts
|
||||
# shape of tokens_per_tmp is [num_layers * num_nodes, num_physical_experts_per_node]
|
||||
tokens_per_tmp = np.take_along_axis(weight / logcnt, tmp2log, axis=-1).reshape(-1, num_physical_experts_per_node)
|
||||
|
||||
# STEP 3: take load balance of gpu cards in node
|
||||
# shape of gpu_index, rank_in_gpu, tmp2phy, phy2tmp is [num_layers * num_nodes, num_physical_experts_per_node]
|
||||
gpu_index, rank_in_gpu = balanced_packing(tokens_per_tmp, num_gpus_per_node)
|
||||
tmp2phy = gpu_index * num_physical_experts_per_gpu + rank_in_gpu
|
||||
phy2tmp = inverse(tmp2phy)
|
||||
|
||||
# STEP 4: generate final phy2log mapping
|
||||
tmp2log = tmp2log.reshape(-1, num_physical_experts_per_node)
|
||||
tmprank = tmprank.reshape(-1, num_physical_experts_per_node)
|
||||
phy2log = np.take_along_axis(tmp2log, phy2tmp, axis=-1).reshape(-1, num_physical_experts)
|
||||
phyrank = np.take_along_axis(tmprank, phy2tmp, axis=-1).reshape(-1, num_physical_experts)
|
||||
return phy2log, phyrank, logcnt
|
||||
|
||||
|
||||
def rebalance_experts_hierarchical(
|
||||
weight: np.ndarray,
|
||||
num_physical_experts: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_gpus: int,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
weight: [num_moe_layers, num_logical_experts]
|
||||
num_physical_experts: number of physical experts after replication
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
|
||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||
Returns:
|
||||
physical_to_logical_map: [num_moe_layers, num_physical_experts]
|
||||
logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
|
||||
logical_count: [num_moe_layers, num_logical_experts]
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
assert num_logical_experts % num_groups == 0
|
||||
group_size = num_logical_experts // num_groups
|
||||
assert num_groups % num_nodes == 0
|
||||
groups_per_node = num_groups // num_nodes
|
||||
assert num_gpus % num_nodes == 0
|
||||
assert num_physical_experts % num_gpus == 0
|
||||
phy_experts_per_gpu = num_physical_experts // num_gpus
|
||||
|
||||
def inverse(perm: np.ndarray) -> np.ndarray:
|
||||
inv = np.empty_like(perm)
|
||||
inv[np.arange(perm.shape[0])[:, None], perm] = np.arange(perm.shape[1], dtype=np.int32).reshape(1, -1)
|
||||
return inv
|
||||
|
||||
# Step 1: pack groups to nodes
|
||||
tokens_per_group = weight.reshape(num_layers, num_groups, group_size).sum(axis=-1)
|
||||
group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
|
||||
log2mlog = (
|
||||
((group_pack_index * groups_per_node + group_rank_in_pack) * group_size)[:, :, None]
|
||||
+ np.arange(group_size, dtype=np.int32)
|
||||
).reshape(num_layers, -1)
|
||||
mlog2log = inverse(log2mlog)
|
||||
|
||||
# Step 2: construct redundant experts within nodes
|
||||
tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=-1).reshape(-1, num_logical_experts // num_nodes)
|
||||
phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes)
|
||||
|
||||
# Step 3: pack physical_experts to GPUs
|
||||
tokens_per_phy = np.take_along_axis(tokens_per_mlog / mlogcnt, phy2mlog, axis=-1)
|
||||
pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
|
||||
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
|
||||
pphy2phy = inverse(phy2pphy)
|
||||
|
||||
pphy2mlog = np.take_along_axis(phy2mlog, pphy2phy, axis=-1) # [num_layers * num_nodes, num_log_per_nodes]
|
||||
pphy2mlog = (
|
||||
pphy2mlog.reshape(num_layers, num_nodes, -1)
|
||||
+ np.arange(0, num_logical_experts, num_logical_experts // num_nodes, dtype=np.int32).reshape(1, -1, 1)
|
||||
).reshape(num_layers, -1)
|
||||
pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=-1)
|
||||
pphyrank = np.take_along_axis(phyrank, pphy2phy, axis=-1).reshape(num_layers, -1)
|
||||
logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=-1)
|
||||
return pphy2log, pphyrank, logcnt
|
||||
|
||||
|
||||
def rebalance_experts(
|
||||
weight: np.ndarray,
|
||||
num_replicas: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_gpus: int,
|
||||
eplb_strategy: str = "",
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Entry point for expert-parallelism load balancer.
|
||||
Parameters:
|
||||
weight: [layers, num_logical_experts], the load statistics for all logical experts
|
||||
num_replicas: number of physical experts, must be a multiple of `num_gpus`
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
|
||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||
Returns:
|
||||
physical_to_logical_map: [layers, num_replicas], the expert index of each replica
|
||||
logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
|
||||
expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
weight = weight.astype(np.float32)
|
||||
if eplb_strategy == "balance_intra_node":
|
||||
phy2log, phyrank, logcnt = rebalance_experts_intra_node(weight, num_replicas, num_groups, num_nodes, num_gpus)
|
||||
else:
|
||||
if num_groups % num_nodes == 0:
|
||||
# use hierarchical load-balance policy
|
||||
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
else:
|
||||
# use global load-balance policy
|
||||
phy2log, phyrank, logcnt = replicate_experts(weight, num_replicas)
|
||||
maxlogcnt = logcnt.max()
|
||||
log2phy = np.full((num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int32)
|
||||
np.put_along_axis(
|
||||
log2phy.reshape(num_layers, -1)[:, :, None],
|
||||
(phy2log * maxlogcnt + phyrank)[:, :, None],
|
||||
np.arange(num_replicas, dtype=np.int32).reshape(1, -1).repeat(num_layers, axis=0)[:, :, None],
|
||||
axis=1,
|
||||
)
|
||||
return phy2log, log2phy, logcnt
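# Tiny check of the phy2log -> log2phy scatter performed just above: with two physical
# replicas of logical expert 0 and one of expert 1 (maxlogcnt == 2), each replica's
# physical slot index lands at [layer, logical_expert, replica_rank].
import numpy as np

phy2log = np.array([[0, 1, 0]])   # physical slots 0 and 2 hold logical expert 0
phyrank = np.array([[0, 0, 1]])   # slot 2 is expert 0's second replica
maxlogcnt = 2
log2phy = np.full((1, 2, maxlogcnt), -1, dtype=np.int32)
np.put_along_axis(
    log2phy.reshape(1, -1)[:, :, None],
    (phy2log * maxlogcnt + phyrank)[:, :, None],
    np.arange(3, dtype=np.int32).reshape(1, -1)[:, :, None],
    axis=1,
)
# log2phy -> [[[0, 2], [1, -1]]]   (-1 marks an unused replica slot)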
|
||||
|
||||
|
||||
__all__ = ["rebalance_experts"]
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
main
|
||||
"""
|
||||
num_hidden_layers = 3
|
||||
num_expert = 64
|
||||
num_groups = 8
|
||||
|
||||
num_replicas = 64
|
||||
num_nodes = 4
|
||||
num_gpus = 4 * 8
|
||||
|
||||
model_tokens_per_expert_stats_list = np.random.randint(low=1, high=10, size=(num_hidden_layers, num_expert))
|
||||
|
||||
phy2log, phyrank, logcnt = rebalance_experts(
|
||||
model_tokens_per_expert_stats_list,
|
||||
num_replicas,
|
||||
num_groups,
|
||||
num_nodes,
|
||||
num_gpus,
|
||||
)
|
||||
print(phy2log)
|
||||
print(phyrank)
|
||||
print(logcnt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
503
fastdeploy/eplb/experts_manager.py
Normal file
@@ -0,0 +1,503 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from multiprocessing import Pipe, Process
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.eplb.async_expert_loader import load_model_weights_process
|
||||
from fastdeploy.eplb.eplb import rebalance_experts
|
||||
from fastdeploy.eplb.utils import RedundantExpertWorkload
|
||||
from fastdeploy.inter_communicator import IPCSignal, RearrangeExpertStatus
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
|
||||
class RedundantExpertManager:
|
||||
"""
|
||||
RedundantExpertManger
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int = 0,
|
||||
ep_size: int = 32,
|
||||
fd_config: FDConfig = None,
|
||||
ipc_signal_suffix: int = 0,
|
||||
):
|
||||
self.logger = get_logger("eplb_expert_manager", "eplb_{0}.log".format(rank))
|
||||
|
||||
self.rank = rank
|
||||
self.ep_size = ep_size
|
||||
self.fd_config = fd_config
|
||||
self.eplb_config = fd_config.eplb_config
|
||||
self.api_user = self.eplb_config.redundant_expert_api_user
|
||||
self.api_passwd = self.eplb_config.redundant_expert_api_password
|
||||
self.num_redundant_experts = self.eplb_config.redundant_experts_num
|
||||
self.num_hidden_layers = self.fd_config.model_config.num_hidden_layers
|
||||
self.num_logical_experts = self.fd_config.model_config.moe_num_experts
|
||||
self.ipc_signal_suffix = ipc_signal_suffix
|
||||
|
||||
self.num_replicas = self.num_logical_experts + self.num_redundant_experts
|
||||
self.num_groups = self.num_logical_experts
|
||||
self.num_nodes = max(ep_size // 8, 1)
|
||||
self.num_gpus = ep_size
|
||||
self.expert_per_rank = self.num_replicas // ep_size
|
||||
assert (
|
||||
self.num_replicas % ep_size == 0
|
||||
), f"num_replicas must be divisible by ep_size, \
|
||||
but got num_replicas = {self.num_replicas}, ep_size = {ep_size}"
|
||||
|
||||
self.model_ep_rank_to_expert_id_list = np.full(
|
||||
(
|
||||
self.num_hidden_layers,
|
||||
self.num_logical_experts + self.num_redundant_experts,
|
||||
),
|
||||
-1,
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.model_expert_id_to_ep_rank_array = np.full(
|
||||
(
|
||||
self.num_hidden_layers,
|
||||
self.num_logical_experts,
|
||||
self.num_redundant_experts + 1,
|
||||
),
|
||||
-1,
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.model_expert_in_rank_num_list = np.zeros(
|
||||
(self.num_hidden_layers, self.num_logical_experts), dtype=np.int32
|
||||
)
|
||||
|
||||
# backup info
|
||||
self.last_model_ep_rank_to_expert_id_list = np.full(
|
||||
(
|
||||
self.num_hidden_layers,
|
||||
self.num_logical_experts + self.num_redundant_experts,
|
||||
),
|
||||
-1,
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.last_model_expert_id_to_ep_rank_array = np.full(
|
||||
(
|
||||
self.num_hidden_layers,
|
||||
self.num_logical_experts,
|
||||
self.num_redundant_experts + 1,
|
||||
),
|
||||
-1,
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.last_model_expert_in_rank_num_list = np.zeros(
|
||||
(self.num_hidden_layers, self.num_logical_experts), dtype=np.int32
|
||||
)
|
||||
|
||||
self.model_tokens_per_expert_stats_list = np.ones(
|
||||
(self.num_hidden_layers, self.num_logical_experts), dtype=np.int32
|
||||
)
|
||||
self.caculate_expert_rank_table(True)
|
||||
|
||||
self.dp_rank_address = None
|
||||
self.need_allgather_load_weight_result = False
|
||||
self.load_weight_begin_ts = 0
|
||||
self.load_weight_timeout = 300 # 5min
|
||||
self.need_rearrange_expert = False
|
||||
self.need_update_expert_tokens_stat = True
|
||||
self.http_timeout = 1
|
||||
# 重置重排状态: 'done' -> 'free'
|
||||
self.rearrange_end_ts = 0
|
||||
self.rearrange_reset_interval = 300
|
||||
|
||||
self.tensor_infos = None
|
||||
|
||||
self.parent_data_conn, child_data_conn = Pipe()
|
||||
self.parent_mg_conn, child_mg_conn = Pipe()
|
||||
Process(
|
||||
target=load_model_weights_process,
|
||||
name=f"eplb::async_load_model_{rank}",
|
||||
args=(
|
||||
self.rank,
|
||||
self.fd_config.model_config.model,
|
||||
self.expert_per_rank,
|
||||
self.fd_config.model_config.moe_layer_start_index,
|
||||
self.eplb_config.moe_quant_type,
|
||||
self.ipc_signal_suffix,
|
||||
self.eplb_config,
|
||||
child_data_conn,
|
||||
child_mg_conn,
|
||||
),
|
||||
).start()
|
||||
child_data_conn.close()
|
||||
child_mg_conn.close()
|
||||
|
||||
listen_signal_thread = threading.Thread(target=self.listen_rearrange_expert_signal, args=(), daemon=True)
|
||||
listen_signal_thread.start()
|
||||
|
||||
self.logger.info(
|
||||
f"redundant_expert: RedundantExpertManager init success, rank {rank}, \
|
||||
strategy {self.eplb_config.redundant_expert_eplb_strategy}"
|
||||
)
|
||||
|
||||
# def get_unique_name(self, name):
|
||||
# return f"{envs.get_unique_name(name + '_dprank_' + str(self.rank))}"
|
||||
|
||||
def get_ep_rank_to_expert_id_list(self):
|
||||
"""
|
||||
get_ep_rank_to_expert_id_list
|
||||
"""
|
||||
return (
|
||||
self.model_ep_rank_to_expert_id_list,
|
||||
self.model_expert_id_to_ep_rank_array,
|
||||
self.model_expert_in_rank_num_list,
|
||||
)
|
||||
|
||||
def listen_rearrange_expert_signal(self):
|
||||
"""
|
||||
listen_rearrange_expert_signal
|
||||
"""
|
||||
if self.rank == 0:
|
||||
rearrange_experts_ips_size_array = np.zeros([1], dtype=np.int32)
|
||||
rearrange_experts_ips_size_signal = IPCSignal(
|
||||
name="rearrange_experts_ips_size",
|
||||
array=rearrange_experts_ips_size_array,
|
||||
dtype=np.int32,
|
||||
suffix=self.ipc_signal_suffix,
|
||||
create=False,
|
||||
)
|
||||
|
||||
shm_rearrange_experts_ips_list = IPCSignal(
|
||||
name="rearrange_experts_ips_list",
|
||||
shm_size=self.eplb_config.redundant_expert_ip_shm_size,
|
||||
suffix=self.ipc_signal_suffix,
|
||||
create=False,
|
||||
)
|
||||
|
||||
rearrange_experts_status = np.zeros([1], dtype=np.int32)
|
||||
rearrange_experts_signal = IPCSignal(
|
||||
name="rearrange_experts_status",
|
||||
array=rearrange_experts_status,
|
||||
dtype=np.int32,
|
||||
suffix=self.ipc_signal_suffix,
|
||||
create=False,
|
||||
)
|
||||
|
||||
signal_update_weight_from_disk = np.zeros([1], dtype=np.int32)
|
||||
signal_update_weight_from_disk_array = IPCSignal(
|
||||
name="signal_update_weight_from_disk",
|
||||
array=signal_update_weight_from_disk,
|
||||
dtype=np.int32,
|
||||
suffix=self.ipc_signal_suffix,
|
||||
create=False,
|
||||
)
|
||||
|
||||
experts_token_stats = np.zeros(
|
||||
(self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.moe_num_experts),
|
||||
dtype=np.int32,
|
||||
)
|
||||
shm_all_experts_token_stats = IPCSignal(
|
||||
name="all_experts_token_stats",
|
||||
array=experts_token_stats,
|
||||
dtype=np.int32,
|
||||
suffix=self.ipc_signal_suffix,
|
||||
create=False,
|
||||
)
|
||||
|
||||
while True:
|
||||
if self.rank == 0:
|
||||
now = int(time.time())
|
||||
if rearrange_experts_ips_size_signal.value[0] > 0:
|
||||
# step 1. all reduce experts token stats
|
||||
address = bytes(
|
||||
shm_rearrange_experts_ips_list.shm.buf[: rearrange_experts_ips_size_signal.value[0]]
|
||||
).decode("utf-8")
|
||||
self.logger.info(f"redundant_expert: all rank ips {address}")
|
||||
rearrange_experts_ips_size_signal.value[0] = 0
|
||||
rearrange_experts_signal.value[0] = RearrangeExpertStatus.DOING.value
|
||||
|
||||
self.dp_rank_address = address.strip().split(";")
|
||||
if self.allreduce_experts_stat():
|
||||
self.need_allgather_load_weight_result = True
|
||||
self.load_weight_begin_ts = now
|
||||
self.logger.info("redundant_expert: all-reduce experts stats success")
|
||||
else:
|
||||
rearrange_experts_signal.value[0] = RearrangeExpertStatus.FREE.value
|
||||
self.logger.warning("redundant_expert: all-reduce experts stats fail")
|
||||
elif self.need_allgather_load_weight_result and self.allreduce_load_weight_result():
|
||||
# step 3. all reduce the result of load weight from disk
|
||||
self.need_allgather_load_weight_result = False
|
||||
rearrange_experts_signal.value[0] = RearrangeExpertStatus.LOAD_SUCC.value
|
||||
self.rearrange_end_ts = now
|
||||
if rearrange_experts_signal.value[0] > 1 and (
|
||||
now - self.rearrange_end_ts > self.rearrange_reset_interval
|
||||
):
|
||||
# reset rearrange status
|
||||
rearrange_experts_signal.value[0] = RearrangeExpertStatus.FREE.value
|
||||
|
||||
if signal_update_weight_from_disk_array.value[0] == 1:
|
||||
# step 2. async load weight: disk -> memory
|
||||
self.model_tokens_per_expert_stats_list[:] = shm_all_experts_token_stats.value[:]
|
||||
self.caculate_expert_rank_table()
|
||||
self.update_weight_from_disk()
|
||||
signal_update_weight_from_disk_array.value[0] = 0
|
||||
time.sleep(0.5)
|
||||
|
||||
def caculate_expert_rank_table(self, is_init=False):
|
||||
"""
|
||||
caculate_expert_rank_table
|
||||
"""
|
||||
num_groups = self.num_groups
|
||||
num_nodes = self.num_nodes
|
||||
num_gpus = self.num_gpus
|
||||
eplb_strategy = self.eplb_config.redundant_expert_eplb_strategy
|
||||
if is_init:
|
||||
num_groups = 1
|
||||
num_nodes = 2
|
||||
num_gpus = 2 * 8
|
||||
eplb_strategy = ""
|
||||
# eplb
|
||||
rank_expert_list, logical_to_physical_map, expert_count = rebalance_experts(
|
||||
self.model_tokens_per_expert_stats_list,
|
||||
self.num_replicas,
|
||||
num_groups,
|
||||
num_nodes,
|
||||
num_gpus,
|
||||
eplb_strategy,
|
||||
)
|
||||
|
||||
# backup info
|
||||
self.last_model_ep_rank_to_expert_id_list[:] = self.model_ep_rank_to_expert_id_list[:]
|
||||
self.last_model_expert_id_to_ep_rank_array[:] = self.model_expert_id_to_ep_rank_array[:]
|
||||
self.last_model_expert_in_rank_num_list[:] = self.model_expert_in_rank_num_list[:]
|
||||
|
||||
# update model info
|
||||
self.model_ep_rank_to_expert_id_list[:] = rank_expert_list[:]
|
||||
self.model_expert_id_to_ep_rank_array.fill(-1)
|
||||
self.model_expert_id_to_ep_rank_array[..., : logical_to_physical_map.shape[-1]] = logical_to_physical_map[:]
|
||||
self.model_expert_in_rank_num_list[:] = expert_count[:]
|
||||
|
||||
if self.rank == 0:
|
||||
workload = RedundantExpertWorkload()
|
||||
workload.tokens_per_expert_stats_list = self.model_tokens_per_expert_stats_list.tolist()
|
||||
workload.ep_rank_to_expert_id_list = rank_expert_list.tolist()
|
||||
workload.expert_id_to_ep_rank_array = logical_to_physical_map.tolist()
|
||||
workload.expert_in_rank_num_list = expert_count.tolist()
|
||||
self.logger.info(workload.dump())
|
||||
|
||||
def update_weight_from_disk(self):
|
||||
"""
|
||||
update_weight_from_disk
|
||||
"""
|
||||
begin_time = time.time()
|
||||
result_update_weight_from_disk = np.zeros([1], dtype=np.int32)
|
||||
update_weight_from_disk_result = IPCSignal(
|
||||
name="result_update_weight_from_disk",
|
||||
array=result_update_weight_from_disk,
|
||||
dtype=np.int32,
|
||||
suffix=self.ipc_signal_suffix,
|
||||
create=False,
|
||||
)
|
||||
update_weight_from_disk_result.value[0] = 0
|
||||
|
||||
self.logger.info(f"redundant_expert: update_weight_from_disk send to async process, rank {self.rank}")
|
||||
self.parent_mg_conn.send(
|
||||
{
|
||||
"old_model_ep_rank_to_expert_id_list": self.last_model_ep_rank_to_expert_id_list,
|
||||
"new_model_ep_rank_to_expert_id_list": self.model_ep_rank_to_expert_id_list,
|
||||
}
|
||||
)
|
||||
self.logger.info(f"redundant_expert: update_weight_from_disk recv from async process, rank {self.rank}")
|
||||
response = self.parent_data_conn.recv()
|
||||
self.tensor_infos = response["weights"]
|
||||
|
||||
# 更新权重加载结果
|
||||
update_weight_from_disk_result.value[0] = 1 if response["result"] else -1
|
||||
self.logger.info(
|
||||
"redundant_expert: update_weight_from_disk end, rank"
|
||||
+ f" {self.rank} {response['result']}, cost {int(time.time() - begin_time)}s"
|
||||
)
|
||||
|
||||
def allreduce_experts_stat(self):
|
||||
"""
|
||||
专家负载
|
||||
"""
|
||||
if not self.allgather_expert_token_stats():
|
||||
return False
|
||||
return self.broadcast_expert_token_stats()
|
||||
|
||||
def allgather_expert_token_stats(self):
|
||||
"""
|
||||
allgather_expert_token_stats
|
||||
"""
|
||||
expert_token_stats = np.zeros((self.num_hidden_layers, self.num_logical_experts), dtype=np.int32)
|
||||
success_count = 0
|
||||
for addr in self.dp_rank_address:
|
||||
try:
|
||||
# TODO: 请求失败重试
|
||||
params = {"user": self.api_user, "passwd": self.api_passwd}
|
||||
res = requests.post(
|
||||
f"http://{addr}/get_per_expert_tokens_stats",
|
||||
json=params,
|
||||
timeout=self.http_timeout,
|
||||
)
|
||||
if res.status_code != HTTPStatus.OK:
|
||||
self.logger.warning(
|
||||
"redundant_expert: allgather_expert_token_stats fail. "
|
||||
+ f"addr {addr}, res {res.status_code} {res.json()}"
|
||||
)
|
||||
break
|
||||
|
||||
for meta_data in res.json()["data"]:
|
||||
expert_token_stats += np.array(meta_data, dtype=np.int32)
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
self.logger.error(f"redundant_expert: allgather_expert_token_stats fail. addr {addr}, error {e}")
|
||||
if success_count == len(self.dp_rank_address):
|
||||
self.need_rearrange_expert = True
|
||||
self.model_tokens_per_expert_stats_list[:] = expert_token_stats[:]
|
||||
self.logger.info("redundant_expert: allgather_expert_token_stats success")
|
||||
return True
|
||||
self.logger.info(
|
||||
"redundant_expert: allgather_expert_token_stats fail. "
|
||||
+ f"succ {success_count} total {len(self.dp_rank_address)}"
|
||||
)
|
||||
return False
|
||||
|
||||
def broadcast_expert_token_stats(self):
|
||||
"""
|
||||
broadcast_expert_token_stats
|
||||
"""
|
||||
success_count = 0
|
||||
for addr in self.dp_rank_address:
|
||||
try:
|
||||
params = {
|
||||
"user": self.api_user,
|
||||
"passwd": self.api_passwd,
|
||||
"action": "recv_expert_weight",
|
||||
"data": self.model_tokens_per_expert_stats_list.tolist(),
|
||||
}
|
||||
res = requests.post(
|
||||
f"http://{addr}/rearrange_experts",
|
||||
json=params,
|
||||
timeout=self.http_timeout,
|
||||
)
|
||||
if res.status_code != HTTPStatus.OK:
|
||||
self.logger.warning(
|
||||
"redundant_expert: broadcast_expert_token_stats fail. "
|
||||
+ f"addr {addr}, res {res.status_code} {res.json()}"
|
||||
)
|
||||
break
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"redundant_expert: broadcast_expert_token_stats request fail. addr {addr}, error {e}"
|
||||
)
|
||||
if success_count == len(self.dp_rank_address):
|
||||
self.logger.info("redundant_expert: broadcast_expert_token_stats success")
|
||||
return True
|
||||
self.logger.info(
|
||||
"redundant_expert: broadcast_expert_token_stats failed, "
|
||||
+ f"succ {success_count} total {len(self.dp_rank_address)}"
|
||||
)
|
||||
return False
|
||||
|
||||
def allreduce_load_weight_result(self):
|
||||
"""
|
||||
权重加载结果
|
||||
"""
|
||||
if int(time.time()) - self.load_weight_begin_ts > self.load_weight_timeout:
|
||||
self.logger.info(f"redundant_expert: allreduce_load_weight_result timeout {self.load_weight_timeout}s")
|
||||
return True
|
||||
|
||||
all_success, exist_fail = self.allgather_load_weight_result()
|
||||
if exist_fail:
|
||||
# 如果有DP权重加载异常,结束本次重排
|
||||
self.logger.warning("redundant_expert: allreduce_load_weight_result exist fail, terminate this rearrange")
|
||||
return True
|
||||
if not all_success:
|
||||
self.logger.info("redundant_expert: allreduce_load_weight_result waiting")
|
||||
return False
|
||||
# self.broadcast_load_weight_success()
|
||||
if not exist_fail and all_success:
|
||||
# prefill需要等待调度屏蔽
|
||||
if (
|
||||
self.fd_config.splitwise_role == "decode"
|
||||
or not self.eplb_config.redundant_expert_enable_schedule_cordon
|
||||
):
|
||||
self.logger.info("redundant_expert: allreduce_load_weight_result success, notify infer.py")
|
||||
signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32)
|
||||
signal_update_weight_from_tensor_array = IPCSignal(
|
||||
name="signal_update_weight_from_tensor",
|
||||
array=signal_update_weight_from_tensor,
|
||||
dtype=np.int32,
|
||||
suffix=self.ipc_signal_suffix,
|
||||
create=False,
|
||||
)
|
||||
signal_update_weight_from_tensor_array.value[0] = 1
|
||||
return True
|
||||
|
||||
def allgather_load_weight_result(self):
|
||||
"""
|
||||
allgather_load_weight_result
|
||||
"""
|
||||
all_success, exist_fail = False, False
|
||||
|
||||
success_count, fail_count = 0, 0
|
||||
for addr in self.dp_rank_address:
|
||||
try:
|
||||
params = {
|
||||
"user": self.api_user,
|
||||
"passwd": self.api_passwd,
|
||||
"action": "check_load_weight_result",
|
||||
}
|
||||
res = requests.post(
|
||||
f"http://{addr}/check_redundant",
|
||||
json=params,
|
||||
timeout=self.http_timeout,
|
||||
)
|
||||
if res.status_code != HTTPStatus.OK:
|
||||
self.logger.warning(
|
||||
"redundant_expert: allgather_load_weight_result fail. "
|
||||
+ f"addr {addr}, res {res.status_code} {res.json()}"
|
||||
)
|
||||
break
|
||||
result_list = res.json()["data"]
|
||||
self.logger.info(
|
||||
f"redundant_expert: allgather_load_weight_result success. addr {addr}, result_list {result_list}"
|
||||
)
|
||||
for result in result_list:
|
||||
if result == 1:
|
||||
success_count += 1
|
||||
elif result == -1:
|
||||
fail_count += 1
|
||||
self.logger.error(
|
||||
f"redundant_expert: allgather_load_weight_result fail. addr {addr}, result {result}"
|
||||
)
|
||||
exist_fail = True
|
||||
except Exception as e:
|
||||
self.logger.error(f"redundant_expert: allgather_load_weight_result error. addr {addr}, error {e}")
|
||||
|
||||
if fail_count > 0:
|
||||
self.logger.info(
|
||||
"redundant_expert: allgather_load_weight_result not all ready, "
|
||||
+ f"succ {success_count} fail {fail_count} total {len(self.dp_rank_address)}"
|
||||
)
|
||||
else:
|
||||
self.logger.info("redundant_expert: allgather_load_weight_result all success")
|
||||
all_success = True
|
||||
return all_success, exist_fail
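
# A rough sketch of one rearrangement round as wired in this file (added here as an
# illustrative summary, not part of the original change; endpoint names are the ones
# used above, and the surrounding API server is assumed):
#   1. an external trigger writes the DP rank address list into the
#      "rearrange_experts_ips_list" / "rearrange_experts_ips_size" shared memory;
#   2. rank 0 POSTs /get_per_expert_tokens_stats to every address, sums the stats, and
#      POSTs the merged table back to every address via /rearrange_experts;
#   3. each manager instance then sees "signal_update_weight_from_disk", recomputes its
#      expert rank table and asks the async loader process to stage the new weights;
#   4. rank 0 polls /check_redundant until all ranks report success, then raises
#      "signal_update_weight_from_tensor" so the worker swaps the staged weights in.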

fastdeploy/eplb/utils.py (new file, 160 lines)
@@ -0,0 +1,160 @@
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.inter_communicator import IPCSignal
|
||||
|
||||
|
||||
class RedundantExpertWorkload:
|
||||
"""Redundant Expert Workload"""
|
||||
|
||||
def __init__(self, redundant_expert_meta_dir="/tmp/redundant_expert_meta"):
|
||||
self.update_timestamp = time.time()
|
||||
self.tokens_per_expert_stats_list = None
|
||||
self.ep_rank_to_expert_id_list = None
|
||||
self.expert_id_to_ep_rank_array = None
|
||||
self.expert_in_rank_num_list = None
|
||||
self.cost_milliseconds = 0
|
||||
self.meta_file_name = f"{redundant_expert_meta_dir}/rearrange-experts.json"
|
||||
if not os.path.exists(redundant_expert_meta_dir):
|
||||
os.makedirs(redundant_expert_meta_dir, exist_ok=True)
|
||||
|
||||
def __json__(self):
|
||||
return self.__dict__
|
||||
|
||||
def dump(self):
|
||||
"""Dump the object to a JSON file."""
|
||||
begin = time.time()
|
||||
try:
|
||||
with open(self.meta_file_name, "w") as fout:
|
||||
json.dump(self.__dict__, fout)
|
||||
except Exception as e:
|
||||
return f"redundant_expert: dump expert workload failed, {e}"
|
||||
cost_time = int((time.time() - begin) * 1000 * 1000)
|
||||
return f"redundant_expert: dump expert workload result in {cost_time} us"
|
||||
|
||||
def load(self):
|
||||
"""Load the object from a JSON file."""
|
||||
if not os.path.exists(self.meta_file_name):
|
||||
return {}, f"redundant_expert: file {self.meta_file_name} is not exists"
|
||||
try:
|
||||
with open(self.meta_file_name, "r") as fin:
|
||||
meta = json.load(fin)
|
||||
self.__dict__.update(meta)
|
||||
return self.__json__(), "ok"
|
||||
except Exception as e:
|
||||
return {}, f"redundant_expert: load file {self.meta_file_name} failed, {e}"
|
||||
|
||||
|
||||
def init_eplb_signals(config: FDConfig, ipc_signal_suffix):
|
||||
"""
|
||||
Initialize shared memory to indicate eplb status
|
||||
"""
|
||||
if config.parallel_config.local_data_parallel_id == 0:
|
||||
# rearrange_experts_status Record the expert's rearrangement status
|
||||
rearrange_experts_array = np.zeros([1], dtype=np.int32)
|
||||
_ = IPCSignal(
|
||||
name="rearrange_experts_status",
|
||||
array=rearrange_experts_array,
|
||||
dtype=np.int32,
|
||||
suffix=ipc_signal_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
# Record all DP rank IPs when receiving expert rearrangement requests
|
||||
rearrange_experts_ips_size_array = np.zeros([1], dtype=np.int32)
|
||||
_ = IPCSignal(
|
||||
name="rearrange_experts_ips_size",
|
||||
array=rearrange_experts_ips_size_array,
|
||||
dtype=np.int32,
|
||||
suffix=ipc_signal_suffix,
|
||||
create=True,
|
||||
)
|
||||
_ = IPCSignal(
|
||||
name="rearrange_experts_ips_list",
|
||||
shm_size=config.eplb_config.redundant_expert_ip_shm_size,
|
||||
suffix=ipc_signal_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
# Receive signals for updating weights
|
||||
signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32)
|
||||
_ = IPCSignal(
|
||||
name="signal_update_weight_from_tensor",
|
||||
array=signal_update_weight_from_tensor,
|
||||
dtype=np.int32,
|
||||
suffix=ipc_signal_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
# Record expert workload
|
||||
experts_token_stats = np.zeros(
|
||||
(config.model_config.num_hidden_layers, config.model_config.moe_num_experts),
|
||||
dtype=np.int32,
|
||||
)
|
||||
_ = IPCSignal(
|
||||
name="all_experts_token_stats",
|
||||
array=experts_token_stats,
|
||||
dtype=np.int32,
|
||||
suffix=ipc_signal_suffix,
|
||||
create=True,
|
||||
)
|
||||
_ = IPCSignal(
|
||||
name="local_experts_token_stats",
|
||||
array=experts_token_stats,
|
||||
dtype=np.int32,
|
||||
suffix=ipc_signal_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
# Receive signals for loading weights
|
||||
signal_update_weight_from_disk = np.zeros([1], dtype=np.int32)
|
||||
_ = IPCSignal(
|
||||
name="signal_update_weight_from_disk",
|
||||
array=signal_update_weight_from_disk,
|
||||
dtype=np.int32,
|
||||
suffix=ipc_signal_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
# Receive signals for clearing expert loads
|
||||
clear_experts_token_stats = np.zeros([1], dtype=np.int32)
|
||||
_ = IPCSignal(
|
||||
name="signal_clear_experts_token_stats",
|
||||
array=clear_experts_token_stats,
|
||||
dtype=np.int32,
|
||||
suffix=ipc_signal_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
result_update_weight_from_disk = np.zeros([1], dtype=np.int32)
|
||||
_ = IPCSignal(
|
||||
name="result_update_weight_from_disk",
|
||||
array=result_update_weight_from_disk,
|
||||
dtype=np.int32,
|
||||
suffix=ipc_signal_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(RedundantExpertWorkload("/tmp").load())
|
||||
@@ -17,6 +17,7 @@
from .engine_cache_queue import EngineCacheQueue
from .engine_worker_queue import EngineWorkerQueue
from .ipc_signal import IPCSignal, shared_memory_exists
from .ipc_signal_const import RearrangeExpertStatus
from .zmq_client import ZmqIpcClient
from .zmq_server import ZmqIpcServer, ZmqTcpServer

@@ -28,4 +29,5 @@ __all__ = [
    "ZmqTcpServer",
    "ZmqIpcServer",
    "shared_memory_exists",
    "RearrangeExpertStatus",
]

@@ -55,10 +55,11 @@ class IPCSignal:
    def __init__(
        self,
        name: str,
        array: np.ndarray,
        dtype: np.dtype,
        array: np.ndarray = None,
        dtype: np.dtype = None,
        suffix: int = None,
        create: bool = True,
        shm_size: int = None,
    ) -> None:
        """Initialize or connect to a shared memory block.

@@ -72,23 +73,36 @@
        Raises:
            AssertionError: If create=True but memory already exists, or dtype mismatch.
        """
        assert isinstance(array, np.ndarray), "Input must be a numpy array"
        assert dtype == array.dtype, "Specified dtype must match array dtype"
        if dtype is None or array is None:
            assert shm_size is not None, "shm_size must be specified if array and dtype are None"

        # Set a suffix for name to avoid name conflict while there are multiple engine launched
        if suffix is not None:
            name = name + f".{suffix}"

        if create:
            if shared_memory_exists(name):
                llm_logger.warning(f"ShareMemory: {name} already exists, delete it")
                SharedMemory(name=name, create=False).unlink()
            self.shm = SharedMemory(create=True, size=array.nbytes, name=name)
            self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf)
            self.value[:] = array  # Initialize with input array data
            if create:
                llm_logger.debug(f"creating ipc signal: {name}")
                if shared_memory_exists(name):
                    llm_logger.warning(f"ShareMemory: {name} already exists, delete it")
                    SharedMemory(name=name, create=False).unlink()
                self.shm = SharedMemory(create=True, size=shm_size, name=name)
            else:
                llm_logger.debug(f"attaching ipc signal: {name}")
                self.shm = SharedMemory(name=name)
        else:
            self.shm = SharedMemory(name=name)
            self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf)
            assert isinstance(array, np.ndarray), "Input must be a numpy array"
            assert dtype == array.dtype, "Specified dtype must match array dtype"

            # Set a suffix for name to avoid name conflict while there are multiple engine launched
            if suffix is not None:
                name = name + f".{suffix}"

            if create:
                if shared_memory_exists(name):
                    llm_logger.warning(f"ShareMemory: {name} already exists, delete it")
                    SharedMemory(name=name, create=False).unlink()
                self.shm = SharedMemory(create=True, size=array.nbytes, name=name)
                self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf)
                self.value[:] = array  # Initialize with input array data
            else:
                self.shm = SharedMemory(name=name)
                self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf)
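
    # Illustrative usage of the two modes this constructor now supports, mirroring the
    # call sites added elsewhere in this change (a sketch, not part of the original diff;
    # the suffix value is a placeholder):
    #
    #   # fixed-shape numpy-backed signal
    #   status = IPCSignal(name="rearrange_experts_status", array=np.zeros([1], dtype=np.int32),
    #                      dtype=np.int32, suffix=9000, create=True)
    #   status.value[0] = 1
    #
    #   # raw shared-memory buffer of a given size, read and written through .shm.buf
    #   ips = IPCSignal(name="rearrange_experts_ips_list", shm_size=1024, suffix=9000, create=True)
    #   ips.shm.buf[:4] = b"addr"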

    def clear(self) -> None:
        """Release system resources and unlink the shared memory block."""

fastdeploy/inter_communicator/ipc_signal_const.py (new file, 26 lines)
@@ -0,0 +1,26 @@
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
@dataclass
|
||||
class RearrangeExpertStatus(Enum):
|
||||
FREE = 0
|
||||
DOING = 1
|
||||
LOAD_SUCC = 2 # load weight from disk success
|
||||
DONE = 3
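
# Rough lifecycle of this status as wired elsewhere in this change (an added summary,
# not part of the original file): the manager flips the "rearrange_experts_status"
# signal FREE -> DOING when a rearrangement request (the DP address list) arrives,
# DOING -> LOAD_SUCC once the ranks report their weights staged from disk, and the
# worker sets LOAD_SUCC -> DONE after swapping the staged tensors in; states past
# DOING fall back to FREE once the reset interval expires.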

@@ -388,7 +388,7 @@ class Ernie4_5_Model(nn.Layer):
        fd_config.model_config.pretrained_config.prefix_name = "ernie"
        self.fd_config = fd_config
        self.redundant_table_manger = None
        if fd_config.model_config.enable_redundant_experts is True:
        if fd_config.eplb_config.enable_eplb is True:
            self.redundant_table_manger = RedundantExpertManger(
                n_routed_experts=fd_config.model_config.moe_num_experts,
                num_hidden_layers=fd_config.model_config.num_hidden_layers,

@@ -66,6 +66,7 @@ class RolloutModelConfig:
        num_nextn_predict_layers: int = 0,
        enable_attention_dp_balance: bool = False,
        attention_dp_time_out_iters: int = 0,
        eplb_config: str = {},
    ):
        # Required parameters
        self.model = model_name_or_path
@@ -115,6 +116,7 @@ class RolloutModelConfig:
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.enable_attention_dp_balance = enable_attention_dp_balance
        self.attention_dp_time_out_iters = attention_dp_time_out_iters
        self.eplb_config = eplb_config

    def __str__(self):
        return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())

@@ -30,6 +30,7 @@ from fastdeploy.config import (
    DecodingConfig,
    DeviceConfig,
    EarlyStopConfig,
    EPLBConfig,
    ErnieArchitectures,
    FDConfig,
    GraphOptimizationConfig,
@@ -40,9 +41,16 @@ from fastdeploy.config import (
    SpeculativeConfig,
)
from fastdeploy.engine.request import RequestType
from fastdeploy.eplb.async_expert_loader import (
    MODEL_MAIN_NAME,
    REARRANGE_EXPERT_MAGIC_NUM,
    create_mmap,
    load_tensor_from_shm_mem,
)
from fastdeploy.eplb.experts_manager import RedundantExpertManager
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.inter_communicator import IPCSignal, RearrangeExpertStatus
from fastdeploy.model_executor.layers.quantization import get_quantization_config
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger, parse_quantization
@@ -151,6 +159,7 @@ class PaddleDisWorkerProc:
        self.fd_config = fd_config
        self.parallel_config = fd_config.parallel_config
        self.cache_config = fd_config.cache_config
        self.eplb_config = fd_config.eplb_config

        # TODO(gongshaotian): Use worker factory to get worker
        self.worker = get_worker(fd_config=fd_config, local_rank=self.local_rank, rank=self.ranks)
@@ -249,6 +258,18 @@
            create=False,
        )

    def update_weights_from_tensor(self, mmap_infos):
        """
        update_weights_from_tensor
        """
        state_dicts = load_tensor_from_shm_mem(self.experts_manager.tensor_infos, mmap_infos[MODEL_MAIN_NAME], logger)
        rank_expert_list, logical_to_physical_map, expert_count = self.experts_manager.get_ep_rank_to_expert_id_list()
        self.worker.get_model().redundant_table_manger.update_expert_rank_table(
            rank_expert_list, logical_to_physical_map, expert_count
        )
        # TO BE FIXED
        self.worker.get_model().update_state_dict(state_dicts)

    def _broadcast_model_weights_signal(self, src: int, group) -> int:
        model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
        paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
@@ -258,6 +279,63 @@
        """Main event loop for Paddle Distributed Workers.
        TODO(gongshaotian): support remote calling of functions that control worker.
        """
        if self.eplb_config.enable_eplb:
            self.last_dump_expert_workload_ts = 0
            self.experts_manager = RedundantExpertManager(
                rank=self.local_rank,
                ep_size=self.ranks,
                fd_config=self.fd_config,
                ipc_signal_suffix=self.parallel_config.engine_worker_queue_port,
            )
            experts_token_stats = np.zeros(
                (self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.moe_num_experts),
                dtype=np.int32,
            )
            local_experts_token_stats_array = IPCSignal(
                name="local_experts_token_stats",
                array=experts_token_stats,
                dtype=np.int32,
                suffix=self.parallel_config.engine_worker_queue_port,
                create=False,
            )

            clear_experts_token_stats = np.zeros([1], dtype=np.int32)
            signal_clear_experts_token_stats = IPCSignal(
                name="signal_clear_experts_token_stats",
                array=clear_experts_token_stats,
                dtype=np.int32,
                suffix=self.parallel_config.engine_worker_queue_port,
                create=False,
            )

            if self.local_rank == 0:
                signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32)
                signal_update_weight_from_tensor_array = IPCSignal(
                    name="signal_update_weight_from_tensor",
                    array=signal_update_weight_from_tensor,
                    dtype=np.int32,
                    suffix=self.parallel_config.engine_worker_queue_port,
                    create=False,
                )

                rearrange_experts_status = np.zeros([1], dtype=np.int32)
                rearrange_experts_signal = IPCSignal(
                    name="rearrange_experts_status",
                    array=rearrange_experts_status,
                    dtype=np.int32,
                    suffix=self.parallel_config.engine_worker_queue_port,
                    create=False,
                )

            mmap_infos = create_mmap(
                [MODEL_MAIN_NAME],
                self.local_rank,
                self.ranks,
                shm_uuid=self.parallel_config.engine_worker_queue_port,
                eplb_config=self.eplb_config,
                logger=logger,
            )

        # Currently, only support single node
        self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
        req_ids = []
@@ -267,6 +345,45 @@
        attention_dp_cached_prefill_tasks = []
        attention_dp_wait_prefill_iters = 0
        while True:
            if self.eplb_config.enable_eplb:
                rearrange_time = time.time()
                # collect expert workload
                if local_experts_token_stats_array.value is not None and (
                    int(rearrange_time) - self.last_dump_expert_workload_ts
                    > self.eplb_config.redundant_expert_dump_workload_interval
                ):
                    self.last_dump_expert_workload_ts = int(rearrange_time)
                    clear_stat = False
                    if signal_clear_experts_token_stats.value[0] == 1:
                        clear_stat = True
                        signal_clear_experts_token_stats.value[0] = 0
                    (
                        new_stats_array,
                        _,
                        _,
                        _,
                    ) = self.worker.get_model().redundant_table_manger.get_expert_tokens_stats(clear_stat=clear_stat)
                    local_experts_token_stats_array.value[:] = new_stats_array[:]
                elif local_experts_token_stats_array.value is None:
                    logger.warning("redundant_expert: local_experts_token_stats not init")

                # all DP ranks update weights in sync
                broadcast_value = 0
                if self.local_rank == 0 and signal_update_weight_from_tensor_array.value[0] == 1:
                    logger.info("redundant_expert: update_weight_from_tensor broadcast signal")
                    signal_update_weight_from_tensor_array.value[0] = 0
                    broadcast_value = REARRANGE_EXPERT_MAGIC_NUM
                data = paddle.to_tensor([broadcast_value])
                paddle.distributed.broadcast(data, 0)
                if data[0] == REARRANGE_EXPERT_MAGIC_NUM:
                    self.update_weights_from_tensor(mmap_infos)
                    logger.info(
                        f"redundant_expert: update_weight_from_tensor success, cost {(time.time() - rearrange_time)*1000}ms"
                    )
                    paddle.distributed.barrier()
                    if self.local_rank == 0:
                        rearrange_experts_signal.value[0] = RearrangeExpertStatus.DONE.value
                        logger.info("redundant_expert: done")
            if local_rank == 0:
                if self.model_weights_status.value[0] != 0:
                    self.model_weights_signal[0] = int(self.model_weights_status.value[0])
@@ -706,6 +823,13 @@ def parse_args():
        help="max waiting steps to sync all dp for prefill tasks available",
    )

    parser.add_argument(
        "--eplb_config",
        type=json.loads,
        default=None,
        help="EPLB Configuration.",
    )
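
    # Illustrative command-line usage of the flag above (an added example; the key names
    # follow the EPLBConfig fields referenced in this change, the values are placeholders):
    #   --eplb_config '{"enable_eplb": true, "redundant_experts_num": 8,
    #                   "redundant_expert_eplb_strategy": "balance_intra_node"}'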

    args = parser.parse_args()
    return args

@@ -764,6 +888,8 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:

    early_stop_config = EarlyStopConfig(args.early_stop_config)

    eplb_config = EPLBConfig(args.eplb_config)

    # Note(tangbinhan): used for load_checkpoint
    model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
    model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size
@@ -861,6 +987,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
        moba_attention_config=moba_attention_config,
        enable_attention_dp_balance=args.enable_attention_dp_balance,
        attention_dp_time_out_iters=args.attention_dp_time_out_iters,
        eplb_config=eplb_config,
    )
    update_fd_config_for_mm(fd_config)

@@ -40,3 +40,5 @@ opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
partial_json_parser
einops
cuda-python==12.8
setproctitle