diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index d74a77121..d344fe9ee 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -765,7 +765,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
    * moe/fused_moe/moe_redundant_topk_select.cu
    * moe_redundant_topk_select
    */
-  m.def("f_moe_redundant_topk_select", &MoERedundantTopKSelectKernel,
+  m.def("moe_redundant_topk_select", &MoERedundantTopKSelectKernel,
        py::arg("gating_logits"), py::arg("expert_id_to_ep_rank_array"),
        py::arg("expert_in_rank_num_list"),
        py::arg("tokens_per_expert_stats_list"), py::arg("bias"),
diff --git a/custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu b/custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu
index a53cb0a95..0a7b5ac6a 100644
--- a/custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu
+++ b/custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu
@@ -254,7 +254,7 @@ std::vector MoERedundantTopKSelectKernelInferDtype(
 }
 
-PD_BUILD_OP(moe_redundant_topk_select)
+PD_BUILD_STATIC_OP(moe_redundant_topk_select)
     .Inputs({"gating_logits", "expert_id_to_ep_rank_array",
              "expert_in_rank_num_list", "tokens_per_expert_stats_list",
              paddle::Optional("bias")})
     .Outputs({"topk_ids", "topk_weights",
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 869cea33f..89efeee6f 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -106,6 +106,8 @@ class ModelConfig:
         self.dtype = ""
         self.enable_logprob = False
         self.enable_mm = False
+        self.enable_redundant_experts = False
+        self.redundant_experts_num = 0
 
         for key, value in args.items():
             if hasattr(self, key):
diff --git a/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py
index 89c0efc37..c20064c69 100644
--- a/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py
+++ b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py
@@ -276,7 +276,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
         up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
         down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
 
-        up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
+        up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
             state_dict,
             up_gate_proj_expert_weight_key,
             down_proj_expert_weight_key,
diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py
index acc070309..d7463a0f9 100644
--- a/fastdeploy/model_executor/layers/moe/ep.py
+++ b/fastdeploy/model_executor/layers/moe/ep.py
@@ -77,7 +77,7 @@ class DeepEPEngine:
         elif moe_phase == MoEPhase.PREFILL:
             self.deepep_engine = deep_ep.Buffer(
                 self.group,
-                int(1e9),
+                int(5e8),
                 0,
                 low_latency_mode=False,
                 num_qps_per_rank=1,
@@ -214,13 +214,15 @@ class EPRunner:
         num_max_dispatch_tokens_per_rank: int = 1,
         ep_size: int = 1,
         ep_rank: int = 0,
+        redundant_experts_num: int = 0,
     ):
         self.top_k = top_k
         self.num_experts = num_experts
+        self.redundant_experts_num = redundant_experts_num
         self.ep_engine = DeepEPEngine(
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             hidden=hidden,
-            num_experts=num_experts,
+            num_experts=num_experts + redundant_experts_num,
             moe_phase=moe_phase,
             ep_size=ep_size,
             ep_rank=ep_rank,
@@ -230,13 +232,33 @@ class EPRunner:
         """
         moe_select
         """
-        topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
-            gate_out,
-            layer.gate_correction_bias,
-            self.top_k,
-            True,  # apply_norm_weight,
-            False,
-        )
+        if layer.redundant_table_manger is not None:
+            (
+                ep_rank_to_expert_id_list,
+                expert_id_to_ep_rank_array,
+                expert_in_rank_num_list,
+                tokens_per_expert_stats_list,
+            ) = layer.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(layer.layer_idx)
+
+            topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
+                gating_logits=gate_out,
+                expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
+                expert_in_rank_num_list=expert_in_rank_num_list,
+                tokens_per_expert_stats_list=tokens_per_expert_stats_list,
+                bias=layer.gate_correction_bias,
+                moe_topk=self.top_k,
+                apply_norm_weight=True,  # apply_norm_weight
+                enable_softmax_top_k_fused=False,
+                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
+            )
+        else:
+            topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                gate_out,
+                layer.gate_correction_bias,
+                self.top_k,
+                True,  # apply_norm_weight,
+                False,
+            )
         return topk_idx, topk_weights
 
     @abstractmethod
@@ -266,6 +288,7 @@ class EPPrefillRunner(EPRunner):
         num_experts: int,
         ep_size: int = 1,
         ep_rank: int = 0,
+        redundant_experts_num: int = 0,
     ):
         super().__init__(
             top_k,
@@ -274,6 +297,7 @@ class EPPrefillRunner(EPRunner):
             MoEPhase.PREFILL,
             ep_size=ep_size,
             ep_rank=ep_rank,
+            redundant_experts_num=redundant_experts_num,
         )
 
     def dispatch(
@@ -336,6 +360,7 @@ class EPDecoderRunner(EPRunner):
         num_max_dispatch_tokens_per_rank: int,
         ep_size: int = 1,
         ep_rank: int = 0,
+        redundant_experts_num: int = 0,
     ):
         super().__init__(
             top_k,
@@ -345,6 +370,7 @@ class EPDecoderRunner(EPRunner):
             num_max_dispatch_tokens_per_rank,
             ep_size=ep_size,
             ep_rank=ep_rank,
+            redundant_experts_num=redundant_experts_num,
         )
 
     def dispatch(
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
index 0f65f45d8..ad46d00c0 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
@@ -55,6 +55,7 @@ class MoEMethodBase(QuantMethodBase):
                 layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                 layer.ep_size,
                 layer.ep_rank,
+                layer.fd_config.model_config.redundant_experts_num,
             )
         else:
             from .ep import EPPrefillRunner
@@ -65,6 +66,7 @@ class MoEMethodBase(QuantMethodBase):
                 layer.num_experts,
                 layer.ep_size,
                 layer.ep_rank,
+                layer.fd_config.model_config.redundant_experts_num,
             )
 
     def process_loaded_weights(self, layer, weights) -> None:
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
index f4048fc09..c13b971ea 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
@@ -436,7 +436,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
         up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
         down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
 
-        up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
+        up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
             state_dict,
             up_gate_proj_expert_weight_key,
             down_proj_expert_weight_key,
@@ -444,8 +444,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
         down_proj_weight_scale = []
-        for i in range(layer.num_local_experts):
-            expert_idx = layer.expert_id_offset + i
+        for expert_idx in logical_expert_ids:
             up_gate_proj_weight_scale.append(
                 get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
             )
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
index f7259645b..d53e62028 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -71,7 +71,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
         down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
 
-        up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
+        up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
             state_dict,
             up_gate_proj_expert_weight_key,
             down_proj_expert_weight_key,
@@ -79,13 +79,25 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
         down_proj_weight_scale = []
-        for i in range(layer.num_local_experts):
-            expert_idx = layer.expert_id_offset + i
+        for expert_idx in logical_expert_ids:
+            up_gate_proj_expert_weight_scale_key_name = up_gate_proj_expert_weight_scale_key.format(expert_idx)
+            down_proj_expert_weight_scale_key_name = down_proj_expert_weight_scale_key.format(expert_idx)
+
             up_gate_proj_weight_scale.append(
-                get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
+                get_tensor(
+                    state_dict.pop(up_gate_proj_expert_weight_scale_key_name)
+                    if up_gate_proj_expert_weight_scale_key_name in state_dict
+                    else up_gate_proj_expert_weight_scale_key_name,
+                    layer.fd_config.parallel_config.model_name_or_path,
+                )
             )
             down_proj_weight_scale.append(
-                get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
+                get_tensor(
+                    state_dict.pop(down_proj_expert_weight_scale_key_name)
+                    if down_proj_expert_weight_scale_key_name in state_dict
+                    else down_proj_expert_weight_scale_key_name,
+                    layer.fd_config.parallel_config.model_name_or_path,
+                )
             )
 
         up_gate_proj_weight = (
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py
index cc2932d4e..041283bbd 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py
@@ -88,7 +88,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
         up_gate_proj_expert_code_zp_key = layer.weight_key_map.get("up_gate_proj_expert_code_zp_key", None)
         down_proj_expert_code_zp_key = layer.weight_key_map.get("down_proj_expert_code_zp_key", None)
 
-        up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
+        up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
             state_dict,
             up_gate_proj_expert_weight_key,
             down_proj_expert_weight_key,
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index a1d689961..574df2159 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -20,6 +20,7 @@ from paddleformers.utils.log import logger
 
 from fastdeploy import envs
 from fastdeploy.model_executor.layers.utils import get_tensor
+from fastdeploy.worker.experts_manager import RedundantExpertManger
 
 
 def get_moe_method():
@@ -117,7 +118,15 @@ class FusedMoE(nn.Layer):
         # now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
         self.quant_method = get_moe_method()
 
+        self.redundant_table_manger = None
         if self.ep_size > 1:
+            if fd_config.model_config.enable_redundant_experts is True:
+                self.redundant_table_manger = RedundantExpertManger(
+                    n_routed_experts=fd_config.model_config.moe_num_experts,
+                    num_hidden_layers=fd_config.model_config.num_hidden_layers,
+                    redundant_experts_num=fd_config.model_config.redundant_experts_num,
+                    ep_size=self.ep_size,
+                )
             self.quant_method.init_ep(self)
 
         if fd_config.load_config.dynamic_load_weight:
@@ -222,12 +231,28 @@ class FusedMoE(nn.Layer):
             up_gate_proj_expert_weight_key (str): The key of up_gate_proj expert weight.
             down_proj_expert_weight_key (str): The key of down_proj expert weight.
         """
+        logical_expert_ids = [
+            i
+            for i in range(
+                self.expert_id_offset,
+                self.expert_id_offset + self.num_local_experts,
+            )
+        ]
+        if self.redundant_table_manger is not None:
+            (
+                ep_rank_to_expert_id_list,
+                expert_id_to_ep_rank_array,
+                expert_in_rank_num_list,
+                tokens_per_expert_stats_list,
+            ) = self.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(self.layer_idx)
+            logical_expert_ids = ep_rank_to_expert_id_list[
+                self.expert_id_offset : self.expert_id_offset + self.num_local_experts
+            ]
         up_gate_proj_weights = []
         down_proj_weights = []
         is_ffn_merged = up_gate_proj_expert_weight_key.format(self.expert_id_offset) in state_dict
         if is_ffn_merged:
-            for i in range(self.num_local_experts):
-                expert_idx = self.expert_id_offset + i
+            for expert_idx in logical_expert_ids:
                 down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
                 up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx)
                 up_gate_proj_weights.append(
@@ -253,8 +278,7 @@ class FusedMoE(nn.Layer):
         else:
             gate_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "gate_proj")
             up_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "up_proj")
-            for j in range(self.num_local_experts):
-                expert_idx = self.expert_id_offset + j
+            for expert_idx in logical_expert_ids:
                 gate_expert_weight_key_name = gate_expert_weight_key.format(expert_idx)
                 up_expert_weight_key_name = up_expert_weight_key.format(expert_idx)
                 down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
@@ -285,7 +309,7 @@ class FusedMoE(nn.Layer):
                         self.fd_config.parallel_config.model_name_or_path,
                     )
                 )
-        return up_gate_proj_weights, down_proj_weights
+        return up_gate_proj_weights, down_proj_weights, logical_expert_ids
 
     def extract_moe_ffn_weights(self, state_dict: dict):
         """
@@ -308,7 +332,7 @@ class FusedMoE(nn.Layer):
         assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
         assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
 
-        up_gate_proj_weights, down_proj_weights = self.load_experts_weight(
+        up_gate_proj_weights, down_proj_weights, logical_expert_ids = self.load_experts_weight(
             state_dict,
             up_gate_proj_expert_weight_key,
             down_proj_expert_weight_key,
@@ -329,33 +353,36 @@ class FusedMoE(nn.Layer):
             gate_correction_bias_tensor = get_tensor(state_dict.pop(gate_correction_bias_key)).astype("float32")
         return gate_correction_bias_tensor
 
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, is_rearrange: bool = False):
         """
         load_state_dict function.
         """
-        self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
-        if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
-            self.moe_use_gate_correction_bias = True
-        else:
-            self.moe_use_gate_correction_bias = False
-        if self.moe_use_gate_correction_bias:
-            gate_correction_bias_tensor = self.extract_gate_correction_bias(self.gate_correction_bias_key, state_dict)
-            self.gate_correction_bias = self.create_parameter(
-                shape=gate_correction_bias_tensor.shape,
+        if not is_rearrange:
+            self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
+            if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
+                self.moe_use_gate_correction_bias = True
+            else:
+                self.moe_use_gate_correction_bias = False
+            if self.moe_use_gate_correction_bias:
+                gate_correction_bias_tensor = self.extract_gate_correction_bias(
+                    self.gate_correction_bias_key, state_dict
+                )
+                self.gate_correction_bias = self.create_parameter(
+                    shape=gate_correction_bias_tensor.shape,
+                    dtype="float32",
+                )
+                self.gate_correction_bias.set_value(gate_correction_bias_tensor)
+
+            gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
+            assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"
+
+            gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))
+
+            self.gate_weight = self.create_parameter(
+                shape=gate_weight_tensor.shape,
                 dtype="float32",
             )
-            self.gate_correction_bias.set_value(gate_correction_bias_tensor)
-
-        gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
-        assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"
-
-        gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))
-
-        self.gate_weight = self.create_parameter(
-            shape=gate_weight_tensor.shape,
-            dtype="float32",
-        )
-        self.gate_weight.set_value(gate_weight_tensor.astype("float32"))
+            self.gate_weight.set_value(gate_weight_tensor.astype("float32"))
 
         if self.fd_config.model_config.is_quantized:
             self.quant_method.process_prequanted_weights(self, state_dict)
diff --git a/fastdeploy/worker/experts_manager.py b/fastdeploy/worker/experts_manager.py
index bb86e4479..0e7fd726c 100644
--- a/fastdeploy/worker/experts_manager.py
+++ b/fastdeploy/worker/experts_manager.py
@@ -37,7 +37,7 @@ class RedundantExpertManger:
         ep_size: int,
     ) -> None:
         """Initialize a redundant expert manager"""
-        self.num_expert = n_routed_experts
+        self.num_expert = n_routed_experts if isinstance(n_routed_experts, int) else n_routed_experts[0]
         self.redundant_experts_num = redundant_experts_num
         self.num_hidden_layers = num_hidden_layers