[Feature] Support_eplb (#2997)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

* [Feature] support_eplb

* [Feature] support_eplb

* [Fix] fix mm ep
This commit is contained in:
xiaoxiaohehe001
2025-07-24 20:22:45 +08:00
committed by GitHub
parent f37d00e856
commit 2970b00dfa
11 changed files with 118 additions and 50 deletions

View File

@@ -765,7 +765,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
* moe/fused_moe/moe_redundant_topk_select.cu
* moe_redundant_topk_select
*/
m.def("f_moe_redundant_topk_select", &MoERedundantTopKSelectKernel,
m.def("moe_redundant_topk_select", &MoERedundantTopKSelectKernel,
py::arg("gating_logits"), py::arg("expert_id_to_ep_rank_array"),
py::arg("expert_in_rank_num_list"),
py::arg("tokens_per_expert_stats_list"), py::arg("bias"),

View File

@@ -254,7 +254,7 @@ std::vector<paddle::DataType> MoERedundantTopKSelectKernelInferDtype(
}
PD_BUILD_OP(moe_redundant_topk_select)
PD_BUILD_STATIC_OP(moe_redundant_topk_select)
.Inputs({"gating_logits", "expert_id_to_ep_rank_array", "expert_in_rank_num_list", "tokens_per_expert_stats_list", paddle::Optional("bias")})
.Outputs({"topk_ids",
"topk_weights",

View File

@@ -106,6 +106,8 @@ class ModelConfig:
self.dtype = ""
self.enable_logprob = False
self.enable_mm = False
self.enable_redundant_experts = False
self.redundant_experts_num = 0
for key, value in args.items():
if hasattr(self, key):

View File

@@ -276,7 +276,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,

View File

@@ -77,7 +77,7 @@ class DeepEPEngine:
elif moe_phase == MoEPhase.PREFILL:
self.deepep_engine = deep_ep.Buffer(
self.group,
int(1e9),
int(5e8),
0,
low_latency_mode=False,
num_qps_per_rank=1,
@@ -214,13 +214,15 @@ class EPRunner:
num_max_dispatch_tokens_per_rank: int = 1,
ep_size: int = 1,
ep_rank: int = 0,
redundant_experts_num: int = 0,
):
self.top_k = top_k
self.num_experts = num_experts
self.redundant_experts_num = redundant_experts_num
self.ep_engine = DeepEPEngine(
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
hidden=hidden,
num_experts=num_experts,
num_experts=num_experts + redundant_experts_num,
moe_phase=moe_phase,
ep_size=ep_size,
ep_rank=ep_rank,
@@ -230,13 +232,33 @@ class EPRunner:
"""
moe_select
"""
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
self.top_k,
True, # apply_norm_weight,
False,
)
if layer.redundant_table_manger is not None:
(
ep_rank_to_expert_id_list,
expert_id_to_ep_rank_array,
expert_in_rank_num_list,
tokens_per_expert_stats_list,
) = layer.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(layer.layer_idx)
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
gating_logits=gate_out,
expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
expert_in_rank_num_list=expert_in_rank_num_list,
tokens_per_expert_stats_list=tokens_per_expert_stats_list,
bias=layer.gate_correction_bias,
moe_topk=self.top_k,
apply_norm_weight=True, # apply_norm_weight
enable_softmax_top_k_fused=False,
redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
)
else:
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
self.top_k,
True, # apply_norm_weight,
False,
)
return topk_idx, topk_weights
@abstractmethod
@@ -266,6 +288,7 @@ class EPPrefillRunner(EPRunner):
num_experts: int,
ep_size: int = 1,
ep_rank: int = 0,
redundant_experts_num: int = 0,
):
super().__init__(
top_k,
@@ -274,6 +297,7 @@ class EPPrefillRunner(EPRunner):
MoEPhase.PREFILL,
ep_size=ep_size,
ep_rank=ep_rank,
redundant_experts_num=redundant_experts_num,
)
def dispatch(
@@ -336,6 +360,7 @@ class EPDecoderRunner(EPRunner):
num_max_dispatch_tokens_per_rank: int,
ep_size: int = 1,
ep_rank: int = 0,
redundant_experts_num: int = 0,
):
super().__init__(
top_k,
@@ -345,6 +370,7 @@ class EPDecoderRunner(EPRunner):
num_max_dispatch_tokens_per_rank,
ep_size=ep_size,
ep_rank=ep_rank,
redundant_experts_num=redundant_experts_num,
)
def dispatch(

View File

@@ -55,6 +55,7 @@ class MoEMethodBase(QuantMethodBase):
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
layer.ep_size,
layer.ep_rank,
layer.fd_config.model_config.redundant_experts_num,
)
else:
from .ep import EPPrefillRunner
@@ -65,6 +66,7 @@ class MoEMethodBase(QuantMethodBase):
layer.num_experts,
layer.ep_size,
layer.ep_rank,
layer.fd_config.model_config.redundant_experts_num,
)
def process_loaded_weights(self, layer, weights) -> None:

View File

@@ -436,7 +436,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
@@ -444,8 +444,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
# self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
for i in range(layer.num_local_experts):
expert_idx = layer.expert_id_offset + i
for expert_idx in logical_expert_ids:
up_gate_proj_weight_scale.append(
get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
)

View File

@@ -71,7 +71,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
@@ -79,13 +79,25 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
# self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
for i in range(layer.num_local_experts):
expert_idx = layer.expert_id_offset + i
for expert_idx in logical_expert_ids:
up_gate_proj_expert_weight_scale_key_name = up_gate_proj_expert_weight_scale_key.format(expert_idx)
down_proj_expert_weight_scale_key_name = down_proj_expert_weight_scale_key.format(expert_idx)
up_gate_proj_weight_scale.append(
get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
get_tensor(
state_dict.pop(up_gate_proj_expert_weight_scale_key_name)
if up_gate_proj_expert_weight_scale_key_name in state_dict
else up_gate_proj_expert_weight_scale_key_name,
layer.fd_config.parallel_config.model_name_or_path,
)
)
down_proj_weight_scale.append(
get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
get_tensor(
state_dict.pop(down_proj_expert_weight_scale_key_name)
if down_proj_expert_weight_scale_key_name in state_dict
else down_proj_expert_weight_scale_key_name,
layer.fd_config.parallel_config.model_name_or_path,
)
)
up_gate_proj_weight = (

View File

@@ -88,7 +88,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
up_gate_proj_expert_code_zp_key = layer.weight_key_map.get("up_gate_proj_expert_code_zp_key", None)
down_proj_expert_code_zp_key = layer.weight_key_map.get("down_proj_expert_code_zp_key", None)
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,

View File

@@ -20,6 +20,7 @@ from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.worker.experts_manager import RedundantExpertManger
def get_moe_method():
@@ -117,7 +118,15 @@ class FusedMoE(nn.Layer):
# now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
self.quant_method = get_moe_method()
self.redundant_table_manger = None
if self.ep_size > 1:
if fd_config.model_config.enable_redundant_experts is True:
self.redundant_table_manger = RedundantExpertManger(
n_routed_experts=fd_config.model_config.moe_num_experts,
num_hidden_layers=fd_config.model_config.num_hidden_layers,
redundant_experts_num=fd_config.model_config.redundant_experts_num,
ep_size=self.ep_size,
)
self.quant_method.init_ep(self)
if fd_config.load_config.dynamic_load_weight:
@@ -222,12 +231,28 @@ class FusedMoE(nn.Layer):
up_gate_proj_expert_weight_key (str): The key of up_gate_proj expert weight.
down_proj_expert_weight_key (str): The key of down_proj expert weight.
"""
logical_expert_ids = [
i
for i in range(
self.expert_id_offset,
self.expert_id_offset + self.num_local_experts,
)
]
if self.redundant_table_manger is not None:
(
ep_rank_to_expert_id_list,
expert_id_to_ep_rank_array,
expert_in_rank_num_list,
tokens_per_expert_stats_list,
) = self.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(self.layer_idx)
logical_expert_ids = ep_rank_to_expert_id_list[
self.expert_id_offset : self.expert_id_offset + self.num_local_experts
]
up_gate_proj_weights = []
down_proj_weights = []
is_ffn_merged = up_gate_proj_expert_weight_key.format(self.expert_id_offset) in state_dict
if is_ffn_merged:
for i in range(self.num_local_experts):
expert_idx = self.expert_id_offset + i
for expert_idx in logical_expert_ids:
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx)
up_gate_proj_weights.append(
@@ -253,8 +278,7 @@ class FusedMoE(nn.Layer):
else:
gate_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "gate_proj")
up_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "up_proj")
for j in range(self.num_local_experts):
expert_idx = self.expert_id_offset + j
for expert_idx in logical_expert_ids:
gate_expert_weight_key_name = gate_expert_weight_key.format(expert_idx)
up_expert_weight_key_name = up_expert_weight_key.format(expert_idx)
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
@@ -285,7 +309,7 @@ class FusedMoE(nn.Layer):
self.fd_config.parallel_config.model_name_or_path,
)
)
return up_gate_proj_weights, down_proj_weights
return up_gate_proj_weights, down_proj_weights, logical_expert_ids
def extract_moe_ffn_weights(self, state_dict: dict):
"""
@@ -308,7 +332,7 @@ class FusedMoE(nn.Layer):
assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
up_gate_proj_weights, down_proj_weights = self.load_experts_weight(
up_gate_proj_weights, down_proj_weights, logical_expert_ids = self.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
@@ -329,33 +353,36 @@ class FusedMoE(nn.Layer):
gate_correction_bias_tensor = get_tensor(state_dict.pop(gate_correction_bias_key)).astype("float32")
return gate_correction_bias_tensor
def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict, is_rearrange: bool = False):
"""
load_state_dict function.
"""
self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
self.moe_use_gate_correction_bias = True
else:
self.moe_use_gate_correction_bias = False
if self.moe_use_gate_correction_bias:
gate_correction_bias_tensor = self.extract_gate_correction_bias(self.gate_correction_bias_key, state_dict)
self.gate_correction_bias = self.create_parameter(
shape=gate_correction_bias_tensor.shape,
if not is_rearrange:
self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
self.moe_use_gate_correction_bias = True
else:
self.moe_use_gate_correction_bias = False
if self.moe_use_gate_correction_bias:
gate_correction_bias_tensor = self.extract_gate_correction_bias(
self.gate_correction_bias_key, state_dict
)
self.gate_correction_bias = self.create_parameter(
shape=gate_correction_bias_tensor.shape,
dtype="float32",
)
self.gate_correction_bias.set_value(gate_correction_bias_tensor)
gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"
gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))
self.gate_weight = self.create_parameter(
shape=gate_weight_tensor.shape,
dtype="float32",
)
self.gate_correction_bias.set_value(gate_correction_bias_tensor)
gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"
gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))
self.gate_weight = self.create_parameter(
shape=gate_weight_tensor.shape,
dtype="float32",
)
self.gate_weight.set_value(gate_weight_tensor.astype("float32"))
self.gate_weight.set_value(gate_weight_tensor.astype("float32"))
if self.fd_config.model_config.is_quantized:
self.quant_method.process_prequanted_weights(self, state_dict)

View File

@@ -37,7 +37,7 @@ class RedundantExpertManger:
ep_size: int,
) -> None:
"""Initialize a redundant expert manager"""
self.num_expert = n_routed_experts
self.num_expert = n_routed_experts if isinstance(n_routed_experts, int) else n_routed_experts[0]
self.redundant_experts_num = redundant_experts_num
self.num_hidden_layers = num_hidden_layers