Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[Feature] Support_eplb (#2997)
* [Feature] support_eplb
* [Feature] support_eplb
* [Fix] fix mm ep
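For context: EPLB here refers to expert-parallel load balancing, where hot experts are replicated onto redundant slots so that no single EP rank becomes the bottleneck for a step. A toy calculation of the effect (all numbers and the rank layout below are invented for illustration):

```python
tokens_per_expert = [10, 12, 40, 10]  # hot expert 2 dominates the batch
ranks = {0: [0, 1], 1: [2, 3]}        # 2 experts per rank, no redundancy
load = {r: sum(tokens_per_expert[e] for e in experts) for r, experts in ranks.items()}
print(load)  # {0: 22, 1: 50} -> rank 1 is the bottleneck

# Replicate expert 2 onto a redundant slot on rank 0 and split its tokens
# across the two replicas (20 each):
load_eplb = {0: 10 + 12 + 20, 1: 20 + 10}
print(load_eplb)  # {0: 42, 1: 30} -> the per-rank load is much flatter
```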
@@ -20,6 +20,7 @@ from paddleformers.utils.log import logger
 from fastdeploy import envs
 from fastdeploy.model_executor.layers.utils import get_tensor
+from fastdeploy.worker.experts_manager import RedundantExpertManger


 def get_moe_method():
@@ -117,7 +118,15 @@ class FusedMoE(nn.Layer):
         # now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
         self.quant_method = get_moe_method()

+        self.redundant_table_manger = None
         if self.ep_size > 1:
+            if fd_config.model_config.enable_redundant_experts is True:
+                self.redundant_table_manger = RedundantExpertManger(
+                    n_routed_experts=fd_config.model_config.moe_num_experts,
+                    num_hidden_layers=fd_config.model_config.num_hidden_layers,
+                    redundant_experts_num=fd_config.model_config.redundant_experts_num,
+                    ep_size=self.ep_size,
+                )
             self.quant_method.init_ep(self)

         if fd_config.load_config.dynamic_load_weight:
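The constructor call above passes n_routed_experts, num_hidden_layers, redundant_experts_num, and ep_size, which suggests the manager maintains, per layer, a placement table of n_routed_experts + redundant_experts_num physical slots. A minimal standalone sketch of such a table, assuming a naive round-robin replication policy (the actual policy inside RedundantExpertManger may differ):

```python
def build_naive_table(n_routed_experts: int, redundant_experts_num: int, ep_size: int):
    """Hypothetical helper: map physical slots onto logical expert ids."""
    total_slots = n_routed_experts + redundant_experts_num
    assert total_slots % ep_size == 0, "slots must divide evenly across EP ranks"
    # The first n_routed_experts slots cover every expert once; the remaining
    # redundant slots wrap around and replicate the lowest-numbered experts.
    return [slot % n_routed_experts for slot in range(total_slots)]

# e.g. 4 experts plus 2 redundant copies on 2 EP ranks (3 slots per rank):
print(build_naive_table(4, 2, 2))  # [0, 1, 2, 3, 0, 1]
```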
@@ -222,12 +231,28 @@ class FusedMoE(nn.Layer):
             up_gate_proj_expert_weight_key (str): The key of up_gate_proj expert weight.
             down_proj_expert_weight_key (str): The key of down_proj expert weight.
         """
+        logical_expert_ids = [
+            i
+            for i in range(
+                self.expert_id_offset,
+                self.expert_id_offset + self.num_local_experts,
+            )
+        ]
+        if self.redundant_table_manger is not None:
+            (
+                ep_rank_to_expert_id_list,
+                expert_id_to_ep_rank_array,
+                expert_in_rank_num_list,
+                tokens_per_expert_stats_list,
+            ) = self.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(self.layer_idx)
+            logical_expert_ids = ep_rank_to_expert_id_list[
+                self.expert_id_offset : self.expert_id_offset + self.num_local_experts
+            ]
         up_gate_proj_weights = []
         down_proj_weights = []
         is_ffn_merged = up_gate_proj_expert_weight_key.format(self.expert_id_offset) in state_dict
         if is_ffn_merged:
-            for i in range(self.num_local_experts):
-                expert_idx = self.expert_id_offset + i
+            for expert_idx in logical_expert_ids:
                 down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
                 up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx)
                 up_gate_proj_weights.append(
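The change above swaps the fixed range(expert_id_offset, expert_id_offset + num_local_experts) enumeration for a slice of ep_rank_to_expert_id_list, so a rank can end up loading replicas of experts that another rank also serves. A self-contained mimic of that lookup (the table contents below are illustrative, not FastDeploy internals):

```python
def local_logical_experts(ep_rank_to_expert_id_list, expert_id_offset, num_local_experts):
    # Without a redundant table this would just be range(offset, offset + n);
    # with one, the same slice is taken from the physical->logical mapping.
    return ep_rank_to_expert_id_list[expert_id_offset : expert_id_offset + num_local_experts]

# 4 logical experts with 2 redundant copies spread over 2 ranks, 3 slots each:
table = [0, 2, 3, 1, 0, 2]
print(local_logical_experts(table, 0, 3))  # rank 0 loads experts [0, 2, 3]
print(local_logical_experts(table, 3, 3))  # rank 1 loads experts [1, 0, 2]
```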
@@ -253,8 +278,7 @@ class FusedMoE(nn.Layer):
         else:
             gate_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "gate_proj")
             up_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "up_proj")
-            for j in range(self.num_local_experts):
-                expert_idx = self.expert_id_offset + j
+            for expert_idx in logical_expert_ids:
                 gate_expert_weight_key_name = gate_expert_weight_key.format(expert_idx)
                 up_expert_weight_key_name = up_expert_weight_key.format(expert_idx)
                 down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
@@ -285,7 +309,7 @@ class FusedMoE(nn.Layer):
                         self.fd_config.parallel_config.model_name_or_path,
                     )
                 )
-        return up_gate_proj_weights, down_proj_weights
+        return up_gate_proj_weights, down_proj_weights, logical_expert_ids

     def extract_moe_ffn_weights(self, state_dict: dict):
         """
@@ -308,7 +332,7 @@ class FusedMoE(nn.Layer):
         assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
         assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."

-        up_gate_proj_weights, down_proj_weights = self.load_experts_weight(
+        up_gate_proj_weights, down_proj_weights, logical_expert_ids = self.load_experts_weight(
             state_dict,
             up_gate_proj_expert_weight_key,
             down_proj_expert_weight_key,
@@ -329,33 +353,36 @@ class FusedMoE(nn.Layer):
         gate_correction_bias_tensor = get_tensor(state_dict.pop(gate_correction_bias_key)).astype("float32")
         return gate_correction_bias_tensor

-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, is_rearrange: bool = False):
         """
         load_state_dict function.
         """
-        self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
-        if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
-            self.moe_use_gate_correction_bias = True
-        else:
-            self.moe_use_gate_correction_bias = False
-        if self.moe_use_gate_correction_bias:
-            gate_correction_bias_tensor = self.extract_gate_correction_bias(self.gate_correction_bias_key, state_dict)
-            self.gate_correction_bias = self.create_parameter(
-                shape=gate_correction_bias_tensor.shape,
-                dtype="float32",
-            )
-            self.gate_correction_bias.set_value(gate_correction_bias_tensor)
-
-        gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
-        assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"
-
-        gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))
-
-        self.gate_weight = self.create_parameter(
-            shape=gate_weight_tensor.shape,
-            dtype="float32",
-        )
-        self.gate_weight.set_value(gate_weight_tensor.astype("float32"))
+        if not is_rearrange:
+            self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
+            if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
+                self.moe_use_gate_correction_bias = True
+            else:
+                self.moe_use_gate_correction_bias = False
+            if self.moe_use_gate_correction_bias:
+                gate_correction_bias_tensor = self.extract_gate_correction_bias(
+                    self.gate_correction_bias_key, state_dict
+                )
+                self.gate_correction_bias = self.create_parameter(
+                    shape=gate_correction_bias_tensor.shape,
+                    dtype="float32",
+                )
+                self.gate_correction_bias.set_value(gate_correction_bias_tensor)
+
+            gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
+            assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"
+
+            gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))
+
+            self.gate_weight = self.create_parameter(
+                shape=gate_weight_tensor.shape,
+                dtype="float32",
+            )
+            self.gate_weight.set_value(gate_weight_tensor.astype("float32"))

         if self.fd_config.model_config.is_quantized:
             self.quant_method.process_prequanted_weights(self, state_dict)
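The new is_rearrange flag makes load_state_dict re-entrant after the redundancy table changes: on a rearrange pass, the gate weight and gate correction bias loaded at startup are kept, and only the expert weights are refreshed. A toy model of that contract (ToyMoELayer and its fields are hypothetical stand-ins for FusedMoE):

```python
class ToyMoELayer:
    def __init__(self):
        self.gate_loads = 0
        self.expert_loads = 0

    def load_state_dict(self, state_dict, is_rearrange: bool = False):
        if not is_rearrange:
            # First load only: gate weight / bias come from the checkpoint once.
            self.gate_loads += 1
        # Expert weights are (re)loaded on every call, so a rebalanced expert
        # placement can be applied without touching the gate parameters.
        self.expert_loads += 1

layer = ToyMoELayer()
layer.load_state_dict({})                     # initial load
layer.load_state_dict({}, is_rearrange=True)  # reload after EPLB rearrangement
assert layer.gate_loads == 1 and layer.expert_loads == 2
```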