mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] Support_eplb (#2997)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* [Feature] support_eplb * [Feature] support_eplb * [Fix] fix mm ep
This commit is contained in:
@@ -765,7 +765,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
* moe/fused_moe/moe_redundant_topk_select.cu
|
* moe/fused_moe/moe_redundant_topk_select.cu
|
||||||
* moe_redundant_topk_select
|
* moe_redundant_topk_select
|
||||||
*/
|
*/
|
||||||
m.def("f_moe_redundant_topk_select", &MoERedundantTopKSelectKernel,
|
m.def("moe_redundant_topk_select", &MoERedundantTopKSelectKernel,
|
||||||
py::arg("gating_logits"), py::arg("expert_id_to_ep_rank_array"),
|
py::arg("gating_logits"), py::arg("expert_id_to_ep_rank_array"),
|
||||||
py::arg("expert_in_rank_num_list"),
|
py::arg("expert_in_rank_num_list"),
|
||||||
py::arg("tokens_per_expert_stats_list"), py::arg("bias"),
|
py::arg("tokens_per_expert_stats_list"), py::arg("bias"),
|
||||||
|
@@ -254,7 +254,7 @@ std::vector<paddle::DataType> MoERedundantTopKSelectKernelInferDtype(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
PD_BUILD_OP(moe_redundant_topk_select)
|
PD_BUILD_STATIC_OP(moe_redundant_topk_select)
|
||||||
.Inputs({"gating_logits", "expert_id_to_ep_rank_array", "expert_in_rank_num_list", "tokens_per_expert_stats_list", paddle::Optional("bias")})
|
.Inputs({"gating_logits", "expert_id_to_ep_rank_array", "expert_in_rank_num_list", "tokens_per_expert_stats_list", paddle::Optional("bias")})
|
||||||
.Outputs({"topk_ids",
|
.Outputs({"topk_ids",
|
||||||
"topk_weights",
|
"topk_weights",
|
||||||
|
@@ -106,6 +106,8 @@ class ModelConfig:
|
|||||||
self.dtype = ""
|
self.dtype = ""
|
||||||
self.enable_logprob = False
|
self.enable_logprob = False
|
||||||
self.enable_mm = False
|
self.enable_mm = False
|
||||||
|
self.enable_redundant_experts = False
|
||||||
|
self.redundant_experts_num = 0
|
||||||
|
|
||||||
for key, value in args.items():
|
for key, value in args.items():
|
||||||
if hasattr(self, key):
|
if hasattr(self, key):
|
||||||
|
@@ -276,7 +276,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
|
|||||||
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
||||||
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
||||||
|
|
||||||
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
|
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
|
||||||
state_dict,
|
state_dict,
|
||||||
up_gate_proj_expert_weight_key,
|
up_gate_proj_expert_weight_key,
|
||||||
down_proj_expert_weight_key,
|
down_proj_expert_weight_key,
|
||||||
|
@@ -77,7 +77,7 @@ class DeepEPEngine:
|
|||||||
elif moe_phase == MoEPhase.PREFILL:
|
elif moe_phase == MoEPhase.PREFILL:
|
||||||
self.deepep_engine = deep_ep.Buffer(
|
self.deepep_engine = deep_ep.Buffer(
|
||||||
self.group,
|
self.group,
|
||||||
int(1e9),
|
int(5e8),
|
||||||
0,
|
0,
|
||||||
low_latency_mode=False,
|
low_latency_mode=False,
|
||||||
num_qps_per_rank=1,
|
num_qps_per_rank=1,
|
||||||
@@ -214,13 +214,15 @@ class EPRunner:
|
|||||||
num_max_dispatch_tokens_per_rank: int = 1,
|
num_max_dispatch_tokens_per_rank: int = 1,
|
||||||
ep_size: int = 1,
|
ep_size: int = 1,
|
||||||
ep_rank: int = 0,
|
ep_rank: int = 0,
|
||||||
|
redundant_experts_num: int = 0,
|
||||||
):
|
):
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
|
self.redundant_experts_num = redundant_experts_num
|
||||||
self.ep_engine = DeepEPEngine(
|
self.ep_engine = DeepEPEngine(
|
||||||
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
|
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
|
||||||
hidden=hidden,
|
hidden=hidden,
|
||||||
num_experts=num_experts,
|
num_experts=num_experts + redundant_experts_num,
|
||||||
moe_phase=moe_phase,
|
moe_phase=moe_phase,
|
||||||
ep_size=ep_size,
|
ep_size=ep_size,
|
||||||
ep_rank=ep_rank,
|
ep_rank=ep_rank,
|
||||||
@@ -230,13 +232,33 @@ class EPRunner:
|
|||||||
"""
|
"""
|
||||||
moe_select
|
moe_select
|
||||||
"""
|
"""
|
||||||
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
if layer.redundant_table_manger is not None:
|
||||||
gate_out,
|
(
|
||||||
layer.gate_correction_bias,
|
ep_rank_to_expert_id_list,
|
||||||
self.top_k,
|
expert_id_to_ep_rank_array,
|
||||||
True, # apply_norm_weight,
|
expert_in_rank_num_list,
|
||||||
False,
|
tokens_per_expert_stats_list,
|
||||||
)
|
) = layer.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(layer.layer_idx)
|
||||||
|
|
||||||
|
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
|
||||||
|
gating_logits=gate_out,
|
||||||
|
expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
|
||||||
|
expert_in_rank_num_list=expert_in_rank_num_list,
|
||||||
|
tokens_per_expert_stats_list=tokens_per_expert_stats_list,
|
||||||
|
bias=layer.gate_correction_bias,
|
||||||
|
moe_topk=self.top_k,
|
||||||
|
apply_norm_weight=True, # apply_norm_weight
|
||||||
|
enable_softmax_top_k_fused=False,
|
||||||
|
redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||||
|
gate_out,
|
||||||
|
layer.gate_correction_bias,
|
||||||
|
self.top_k,
|
||||||
|
True, # apply_norm_weight,
|
||||||
|
False,
|
||||||
|
)
|
||||||
return topk_idx, topk_weights
|
return topk_idx, topk_weights
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -266,6 +288,7 @@ class EPPrefillRunner(EPRunner):
|
|||||||
num_experts: int,
|
num_experts: int,
|
||||||
ep_size: int = 1,
|
ep_size: int = 1,
|
||||||
ep_rank: int = 0,
|
ep_rank: int = 0,
|
||||||
|
redundant_experts_num: int = 0,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
top_k,
|
top_k,
|
||||||
@@ -274,6 +297,7 @@ class EPPrefillRunner(EPRunner):
|
|||||||
MoEPhase.PREFILL,
|
MoEPhase.PREFILL,
|
||||||
ep_size=ep_size,
|
ep_size=ep_size,
|
||||||
ep_rank=ep_rank,
|
ep_rank=ep_rank,
|
||||||
|
redundant_experts_num=redundant_experts_num,
|
||||||
)
|
)
|
||||||
|
|
||||||
def dispatch(
|
def dispatch(
|
||||||
@@ -336,6 +360,7 @@ class EPDecoderRunner(EPRunner):
|
|||||||
num_max_dispatch_tokens_per_rank: int,
|
num_max_dispatch_tokens_per_rank: int,
|
||||||
ep_size: int = 1,
|
ep_size: int = 1,
|
||||||
ep_rank: int = 0,
|
ep_rank: int = 0,
|
||||||
|
redundant_experts_num: int = 0,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
top_k,
|
top_k,
|
||||||
@@ -345,6 +370,7 @@ class EPDecoderRunner(EPRunner):
|
|||||||
num_max_dispatch_tokens_per_rank,
|
num_max_dispatch_tokens_per_rank,
|
||||||
ep_size=ep_size,
|
ep_size=ep_size,
|
||||||
ep_rank=ep_rank,
|
ep_rank=ep_rank,
|
||||||
|
redundant_experts_num=redundant_experts_num,
|
||||||
)
|
)
|
||||||
|
|
||||||
def dispatch(
|
def dispatch(
|
||||||
|
@@ -55,6 +55,7 @@ class MoEMethodBase(QuantMethodBase):
|
|||||||
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
||||||
layer.ep_size,
|
layer.ep_size,
|
||||||
layer.ep_rank,
|
layer.ep_rank,
|
||||||
|
layer.fd_config.model_config.redundant_experts_num,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from .ep import EPPrefillRunner
|
from .ep import EPPrefillRunner
|
||||||
@@ -65,6 +66,7 @@ class MoEMethodBase(QuantMethodBase):
|
|||||||
layer.num_experts,
|
layer.num_experts,
|
||||||
layer.ep_size,
|
layer.ep_size,
|
||||||
layer.ep_rank,
|
layer.ep_rank,
|
||||||
|
layer.fd_config.model_config.redundant_experts_num,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_loaded_weights(self, layer, weights) -> None:
|
def process_loaded_weights(self, layer, weights) -> None:
|
||||||
|
@@ -436,7 +436,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
|
|||||||
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
||||||
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
||||||
|
|
||||||
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
|
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
|
||||||
state_dict,
|
state_dict,
|
||||||
up_gate_proj_expert_weight_key,
|
up_gate_proj_expert_weight_key,
|
||||||
down_proj_expert_weight_key,
|
down_proj_expert_weight_key,
|
||||||
@@ -444,8 +444,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
|
|||||||
# self.check(layer, up_gate_proj_weights, down_proj_weights)
|
# self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||||
up_gate_proj_weight_scale = []
|
up_gate_proj_weight_scale = []
|
||||||
down_proj_weight_scale = []
|
down_proj_weight_scale = []
|
||||||
for i in range(layer.num_local_experts):
|
for expert_idx in logical_expert_ids:
|
||||||
expert_idx = layer.expert_id_offset + i
|
|
||||||
up_gate_proj_weight_scale.append(
|
up_gate_proj_weight_scale.append(
|
||||||
get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
|
get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
|
||||||
)
|
)
|
||||||
|
@@ -71,7 +71,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
|||||||
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
||||||
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
||||||
|
|
||||||
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
|
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
|
||||||
state_dict,
|
state_dict,
|
||||||
up_gate_proj_expert_weight_key,
|
up_gate_proj_expert_weight_key,
|
||||||
down_proj_expert_weight_key,
|
down_proj_expert_weight_key,
|
||||||
@@ -79,13 +79,25 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
|||||||
# self.check(layer, up_gate_proj_weights, down_proj_weights)
|
# self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||||
up_gate_proj_weight_scale = []
|
up_gate_proj_weight_scale = []
|
||||||
down_proj_weight_scale = []
|
down_proj_weight_scale = []
|
||||||
for i in range(layer.num_local_experts):
|
for expert_idx in logical_expert_ids:
|
||||||
expert_idx = layer.expert_id_offset + i
|
up_gate_proj_expert_weight_scale_key_name = up_gate_proj_expert_weight_scale_key.format(expert_idx)
|
||||||
|
down_proj_expert_weight_scale_key_name = down_proj_expert_weight_scale_key.format(expert_idx)
|
||||||
|
|
||||||
up_gate_proj_weight_scale.append(
|
up_gate_proj_weight_scale.append(
|
||||||
get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
|
get_tensor(
|
||||||
|
state_dict.pop(up_gate_proj_expert_weight_scale_key_name)
|
||||||
|
if up_gate_proj_expert_weight_scale_key_name in state_dict
|
||||||
|
else up_gate_proj_expert_weight_scale_key_name,
|
||||||
|
layer.fd_config.parallel_config.model_name_or_path,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
down_proj_weight_scale.append(
|
down_proj_weight_scale.append(
|
||||||
get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
|
get_tensor(
|
||||||
|
state_dict.pop(down_proj_expert_weight_scale_key_name)
|
||||||
|
if down_proj_expert_weight_scale_key_name in state_dict
|
||||||
|
else down_proj_expert_weight_scale_key_name,
|
||||||
|
layer.fd_config.parallel_config.model_name_or_path,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
up_gate_proj_weight = (
|
up_gate_proj_weight = (
|
||||||
|
@@ -88,7 +88,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
|||||||
up_gate_proj_expert_code_zp_key = layer.weight_key_map.get("up_gate_proj_expert_code_zp_key", None)
|
up_gate_proj_expert_code_zp_key = layer.weight_key_map.get("up_gate_proj_expert_code_zp_key", None)
|
||||||
down_proj_expert_code_zp_key = layer.weight_key_map.get("down_proj_expert_code_zp_key", None)
|
down_proj_expert_code_zp_key = layer.weight_key_map.get("down_proj_expert_code_zp_key", None)
|
||||||
|
|
||||||
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
|
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
|
||||||
state_dict,
|
state_dict,
|
||||||
up_gate_proj_expert_weight_key,
|
up_gate_proj_expert_weight_key,
|
||||||
down_proj_expert_weight_key,
|
down_proj_expert_weight_key,
|
||||||
|
@@ -20,6 +20,7 @@ from paddleformers.utils.log import logger
|
|||||||
|
|
||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||||
|
from fastdeploy.worker.experts_manager import RedundantExpertManger
|
||||||
|
|
||||||
|
|
||||||
def get_moe_method():
|
def get_moe_method():
|
||||||
@@ -117,7 +118,15 @@ class FusedMoE(nn.Layer):
|
|||||||
# now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
|
# now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
|
||||||
self.quant_method = get_moe_method()
|
self.quant_method = get_moe_method()
|
||||||
|
|
||||||
|
self.redundant_table_manger = None
|
||||||
if self.ep_size > 1:
|
if self.ep_size > 1:
|
||||||
|
if fd_config.model_config.enable_redundant_experts is True:
|
||||||
|
self.redundant_table_manger = RedundantExpertManger(
|
||||||
|
n_routed_experts=fd_config.model_config.moe_num_experts,
|
||||||
|
num_hidden_layers=fd_config.model_config.num_hidden_layers,
|
||||||
|
redundant_experts_num=fd_config.model_config.redundant_experts_num,
|
||||||
|
ep_size=self.ep_size,
|
||||||
|
)
|
||||||
self.quant_method.init_ep(self)
|
self.quant_method.init_ep(self)
|
||||||
|
|
||||||
if fd_config.load_config.dynamic_load_weight:
|
if fd_config.load_config.dynamic_load_weight:
|
||||||
@@ -222,12 +231,28 @@ class FusedMoE(nn.Layer):
|
|||||||
up_gate_proj_expert_weight_key (str): The key of up_gate_proj expert weight.
|
up_gate_proj_expert_weight_key (str): The key of up_gate_proj expert weight.
|
||||||
down_proj_expert_weight_key (str): The key of down_proj expert weight.
|
down_proj_expert_weight_key (str): The key of down_proj expert weight.
|
||||||
"""
|
"""
|
||||||
|
logical_expert_ids = [
|
||||||
|
i
|
||||||
|
for i in range(
|
||||||
|
self.expert_id_offset,
|
||||||
|
self.expert_id_offset + self.num_local_experts,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if self.redundant_table_manger is not None:
|
||||||
|
(
|
||||||
|
ep_rank_to_expert_id_list,
|
||||||
|
expert_id_to_ep_rank_array,
|
||||||
|
expert_in_rank_num_list,
|
||||||
|
tokens_per_expert_stats_list,
|
||||||
|
) = self.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(self.layer_idx)
|
||||||
|
logical_expert_ids = ep_rank_to_expert_id_list[
|
||||||
|
self.expert_id_offset : self.expert_id_offset + self.num_local_experts
|
||||||
|
]
|
||||||
up_gate_proj_weights = []
|
up_gate_proj_weights = []
|
||||||
down_proj_weights = []
|
down_proj_weights = []
|
||||||
is_ffn_merged = up_gate_proj_expert_weight_key.format(self.expert_id_offset) in state_dict
|
is_ffn_merged = up_gate_proj_expert_weight_key.format(self.expert_id_offset) in state_dict
|
||||||
if is_ffn_merged:
|
if is_ffn_merged:
|
||||||
for i in range(self.num_local_experts):
|
for expert_idx in logical_expert_ids:
|
||||||
expert_idx = self.expert_id_offset + i
|
|
||||||
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
|
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
|
||||||
up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx)
|
up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx)
|
||||||
up_gate_proj_weights.append(
|
up_gate_proj_weights.append(
|
||||||
@@ -253,8 +278,7 @@ class FusedMoE(nn.Layer):
|
|||||||
else:
|
else:
|
||||||
gate_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "gate_proj")
|
gate_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "gate_proj")
|
||||||
up_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "up_proj")
|
up_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "up_proj")
|
||||||
for j in range(self.num_local_experts):
|
for expert_idx in logical_expert_ids:
|
||||||
expert_idx = self.expert_id_offset + j
|
|
||||||
gate_expert_weight_key_name = gate_expert_weight_key.format(expert_idx)
|
gate_expert_weight_key_name = gate_expert_weight_key.format(expert_idx)
|
||||||
up_expert_weight_key_name = up_expert_weight_key.format(expert_idx)
|
up_expert_weight_key_name = up_expert_weight_key.format(expert_idx)
|
||||||
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
|
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
|
||||||
@@ -285,7 +309,7 @@ class FusedMoE(nn.Layer):
|
|||||||
self.fd_config.parallel_config.model_name_or_path,
|
self.fd_config.parallel_config.model_name_or_path,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return up_gate_proj_weights, down_proj_weights
|
return up_gate_proj_weights, down_proj_weights, logical_expert_ids
|
||||||
|
|
||||||
def extract_moe_ffn_weights(self, state_dict: dict):
|
def extract_moe_ffn_weights(self, state_dict: dict):
|
||||||
"""
|
"""
|
||||||
@@ -308,7 +332,7 @@ class FusedMoE(nn.Layer):
|
|||||||
assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
|
assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
|
||||||
assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
|
assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
|
||||||
|
|
||||||
up_gate_proj_weights, down_proj_weights = self.load_experts_weight(
|
up_gate_proj_weights, down_proj_weights, logical_expert_ids = self.load_experts_weight(
|
||||||
state_dict,
|
state_dict,
|
||||||
up_gate_proj_expert_weight_key,
|
up_gate_proj_expert_weight_key,
|
||||||
down_proj_expert_weight_key,
|
down_proj_expert_weight_key,
|
||||||
@@ -329,33 +353,36 @@ class FusedMoE(nn.Layer):
|
|||||||
gate_correction_bias_tensor = get_tensor(state_dict.pop(gate_correction_bias_key)).astype("float32")
|
gate_correction_bias_tensor = get_tensor(state_dict.pop(gate_correction_bias_key)).astype("float32")
|
||||||
return gate_correction_bias_tensor
|
return gate_correction_bias_tensor
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict, is_rearrange: bool = False):
|
||||||
"""
|
"""
|
||||||
load_state_dict function.
|
load_state_dict function.
|
||||||
"""
|
"""
|
||||||
self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
|
if not is_rearrange:
|
||||||
if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
|
self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
|
||||||
self.moe_use_gate_correction_bias = True
|
if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
|
||||||
else:
|
self.moe_use_gate_correction_bias = True
|
||||||
self.moe_use_gate_correction_bias = False
|
else:
|
||||||
if self.moe_use_gate_correction_bias:
|
self.moe_use_gate_correction_bias = False
|
||||||
gate_correction_bias_tensor = self.extract_gate_correction_bias(self.gate_correction_bias_key, state_dict)
|
if self.moe_use_gate_correction_bias:
|
||||||
self.gate_correction_bias = self.create_parameter(
|
gate_correction_bias_tensor = self.extract_gate_correction_bias(
|
||||||
shape=gate_correction_bias_tensor.shape,
|
self.gate_correction_bias_key, state_dict
|
||||||
|
)
|
||||||
|
self.gate_correction_bias = self.create_parameter(
|
||||||
|
shape=gate_correction_bias_tensor.shape,
|
||||||
|
dtype="float32",
|
||||||
|
)
|
||||||
|
self.gate_correction_bias.set_value(gate_correction_bias_tensor)
|
||||||
|
|
||||||
|
gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
|
||||||
|
assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"
|
||||||
|
|
||||||
|
gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))
|
||||||
|
|
||||||
|
self.gate_weight = self.create_parameter(
|
||||||
|
shape=gate_weight_tensor.shape,
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
)
|
)
|
||||||
self.gate_correction_bias.set_value(gate_correction_bias_tensor)
|
self.gate_weight.set_value(gate_weight_tensor.astype("float32"))
|
||||||
|
|
||||||
gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
|
|
||||||
assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"
|
|
||||||
|
|
||||||
gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))
|
|
||||||
|
|
||||||
self.gate_weight = self.create_parameter(
|
|
||||||
shape=gate_weight_tensor.shape,
|
|
||||||
dtype="float32",
|
|
||||||
)
|
|
||||||
self.gate_weight.set_value(gate_weight_tensor.astype("float32"))
|
|
||||||
|
|
||||||
if self.fd_config.model_config.is_quantized:
|
if self.fd_config.model_config.is_quantized:
|
||||||
self.quant_method.process_prequanted_weights(self, state_dict)
|
self.quant_method.process_prequanted_weights(self, state_dict)
|
||||||
|
@@ -37,7 +37,7 @@ class RedundantExpertManger:
|
|||||||
ep_size: int,
|
ep_size: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a redundant expert manager"""
|
"""Initialize a redundant expert manager"""
|
||||||
self.num_expert = n_routed_experts
|
self.num_expert = n_routed_experts if isinstance(n_routed_experts, int) else n_routed_experts[0]
|
||||||
self.redundant_experts_num = redundant_experts_num
|
self.redundant_experts_num = redundant_experts_num
|
||||||
self.num_hidden_layers = num_hidden_layers
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user