[NewFeatures] support eplb (#3547)
* [NewFeatures] support eplb
* fix eplb
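This commit wires EPLB (expert-parallelism load balancing, implemented here with redundant experts managed by RedundantExpertManger) through the MoE weight-loading path. The central interface change, repeated across the backends below, is that FusedMoE.extract_moe_ffn_weights / load_experts_weight now return four values instead of two, and process_prequanted_weights gains an is_rearrange flag. Below is a minimal sketch of the new calling convention; the stand-in layer class and its values are hypothetical, and only the four-tuple shape and the two unpacking styles are taken from the hunks that follow.

# Sketch of the new four-value convention used throughout this diff. FakeMoELayer
# is a hypothetical stand-in for FusedMoE; the meanings of the two new fields are
# inferred from how the hunks below use them.
class FakeMoELayer:
    def extract_moe_ffn_weights(self, state_dict):
        up_gate_proj_weights = [state_dict["up.0"], state_dict["up.1"]]
        down_proj_weights = [state_dict["down.0"], state_dict["down.1"]]
        logical_expert_ids = [0, 1]          # logical experts owned by this EP rank
        ep_rank_to_expert_id_list = [0, 1]   # logical expert id for every expert slot
        return up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list


layer = FakeMoELayer()
state = {"up.0": 1.0, "up.1": 2.0, "down.0": 3.0, "down.1": 4.0}

# Backends that do not use EPLB discard the new fields (the bulk of the one-line changes):
up, down, _, _ = layer.extract_moe_ffn_weights(state)

# EPLB-aware paths (e.g. the w4a8 Cutlass backend below) keep them:
up, down, logical_ids, rank_map = layer.extract_moe_ffn_weights(state)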
@@ -38,7 +38,7 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
             "down_proj_weight_scale",
         ]
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
         """process_prequanted_weights"""
         pass
 
@@ -46,7 +46,7 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
         """
         Triton MoE create weight process.
         """
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         assert len(up_gate_proj_weights) == layer.num_local_experts
         assert len(down_proj_weights) == layer.num_local_experts
         assert self.quant_method.name() == "wint8"
@@ -49,7 +49,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
         self.group_size = -1
 
     def process_loaded_weights(self, layer: nn.Layer, state_dict):
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
         stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
         layer.up_gate_proj_weight.set_value(paddle.transpose(stacked_up_gate_proj_weights, [0, 2, 1]))
@@ -254,7 +254,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
         self.quant_multi_process_group_size = int(os.getenv("FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE", 8))
         logger.info(f"GCUWeightOnlyMoEMethod quant_multi_process_group_size: {self.quant_multi_process_group_size}")
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle gcu process prequanted weights.
         """
@@ -299,7 +299,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
         """
         Paddle cutlass create weight process.
         """
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         self.check(layer, up_gate_proj_weights, down_proj_weights)
 
         def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
@@ -59,7 +59,7 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
             is_bias=False,
         )
 
-    def process_prequanted_weights(self, layer, state_dict) -> None:
+    def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False) -> None:
         """
         Process pre-quantized weights before applying them to the model
         Args:
@@ -41,7 +41,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
             "down_proj_weight_scale",
         ]
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
         """process_prequanted_weights"""
         pass
 
@@ -50,7 +50,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
         """
         Triton MoE create weight process.
         """
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         assert len(up_gate_proj_weights) == layer.num_local_experts
         assert len(down_proj_weights) == layer.num_local_experts
 
@@ -74,7 +74,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         """
 
     def process_loaded_weights(self, layer: nn.Layer, state_dict):
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
+            layer.extract_moe_ffn_weights(state_dict)
+        )
         stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
         stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
 
@@ -333,7 +335,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
         self.moe_quant_type = "w4a8"
         self.pack_num = 2
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
@@ -349,6 +351,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
                 state_dict,
                 up_gate_proj_expert_weight_key,
                 down_proj_expert_weight_key,
+                is_rearrange,
             )
         )
 
@@ -358,22 +361,62 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
         up_gate_proj_in_scale = []
         down_proj_in_scale = []
 
+        if isinstance(state_dict, list):
+            state_dict = dict(state_dict)
+
         if layer.ep_size > 1:
             for expert_idx in ep_rank_to_expert_id_list:
-                scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
+                scale_tensor = get_tensor(
+                    (
+                        state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)]
+                        if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
+                        else up_gate_proj_expert_in_scale_key.format(expert_idx)
+                    ),
+                    layer.fd_config.model_config.model,
+                )
                 up_gate_proj_in_scale_all_experts.append(scale_tensor)
 
         for expert_idx in logical_expert_ids:
             up_gate_proj_weight_scale.append(
-                get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
+                get_tensor(
+                    (
+                        state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx))
+                        if up_gate_proj_expert_weight_scale_key.format(expert_idx) in state_dict
+                        else up_gate_proj_expert_weight_scale_key.format(expert_idx)
+                    ),
+                    layer.fd_config.model_config.model,
+                )
             )
             down_proj_weight_scale.append(
-                get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
+                get_tensor(
+                    (
+                        state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx))
+                        if down_proj_expert_weight_scale_key.format(expert_idx) in state_dict
+                        else down_proj_expert_weight_scale_key.format(expert_idx)
+                    ),
+                    layer.fd_config.model_config.model,
+                )
             )
             up_gate_proj_in_scale.append(
-                get_tensor(state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx)))
+                get_tensor(
+                    (
+                        state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
+                        if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
+                        else up_gate_proj_expert_in_scale_key.format(expert_idx)
+                    ),
+                    layer.fd_config.model_config.model,
+                )
+            )
+            down_proj_in_scale.append(
+                get_tensor(
+                    (
+                        state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
+                        if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
+                        else down_proj_expert_in_scale_key.format(expert_idx)
+                    ),
+                    layer.fd_config.model_config.model,
+                )
             )
-            down_proj_in_scale.append(get_tensor(state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))))
 
         up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
         down_proj_weight = paddle.stack(down_proj_weights, axis=0)
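Note: the repeated ternaries in this hunk all follow one pattern: if a scale key is present in the (possibly partial) state dict its tensor is popped, otherwise the key string itself is passed to get_tensor together with layer.fd_config.model_config.model so the tensor can be resolved from the checkpoint on demand. A self-contained sketch of that pattern follows; my_get_tensor is a hypothetical stand-in for FastDeploy's get_tensor, whose real signature and lazy-loading behaviour should be checked against the source.

import numpy as np


def my_get_tensor(value_or_key, model_dir=None):
    # Hypothetical stand-in: when given a key string, pretend to read that named
    # tensor from the checkpoint under model_dir; otherwise return the tensor as-is.
    if isinstance(value_or_key, str):
        return np.zeros([1], dtype="float32")
    return value_or_key


state_dict = {"experts.0.up_gate_proj.activation_scale": np.ones([1], dtype="float32")}
key_template = "experts.{}.up_gate_proj.activation_scale"
scales = []
for expert_idx in (0, 1):
    key = key_template.format(expert_idx)
    # Same shape as the ternaries above: pop the tensor if present, otherwise pass
    # the key (plus a placeholder model directory) so it can be loaded on demand.
    scales.append(my_get_tensor(state_dict.pop(key) if key in state_dict else key, "/path/to/model"))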
@@ -435,7 +478,9 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
         """
         Paddle cutlass load weight process.
         """
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
+            layer.extract_moe_ffn_weights(state_dict)
+        )
         self.check(layer, up_gate_proj_weights, down_proj_weights)
         for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
             weight_name = self.added_weight_attrs[idx]
@@ -446,7 +491,9 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
             quanted_weight = paddle.stack(weight_list, axis=0)
             getattr(layer, weight_name).set_value(quanted_weight)
 
-        self.load_w4a8_scale_weights(layer, layer.weight_key_map, state_dict)
+        self.load_w4a8_scale_weights(
+            layer, layer.weight_key_map, state_dict, logical_expert_ids, ep_rank_to_expert_id_list
+        )
 
     def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
         """
@@ -499,7 +546,14 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
             ),
         )
 
-    def load_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict):
+    def load_w4a8_scale_weights(
+        self,
+        layer: nn.Layer,
+        weight_key_map: dict,
+        state_dict: dict,
+        logical_expert_ids: paddle.Tensor,
+        ep_rank_to_expert_id_list: list,
+    ):
         """
         Get w4a8 weights from state dict and process them.
         Args:
@@ -508,8 +562,15 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
            state_dict (dict): The state dict.
         """
 
-        def _extract_scale_tensor(state_dict, key_template, expert_idx):
-            return get_tensor(state_dict.pop(key_template.format(expert_idx)))
+        def _extract_scale_tensor(layer: nn.Layer, state_dict, key_template, expert_idx):
+            return get_tensor(
+                (
+                    state_dict.pop(key_template.format(expert_idx))
+                    if key_template.format(expert_idx) in state_dict
+                    else key_template.format(expert_idx)
+                ),
+                layer.fd_config.model_config.model,
+            )
 
         def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
             processed_in_scale = 1 / paddle.concat(in_scales)
@@ -551,17 +612,23 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
 
         # 2. Extract scale tensor from state dict
         if layer.ep_size > 1:
-            for expert_idx in range(layer.num_experts):
-                scale_tensor = get_tensor(state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)])
+            for expert_idx in ep_rank_to_expert_id_list:
+                scale_tensor = get_tensor(
+                    (
+                        state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)]
+                        if scale_key_map["up_gate_proj_in_scale"].format(expert_idx) in state_dict
+                        else scale_key_map["up_gate_proj_in_scale"].format(expert_idx)
+                    ),
+                    layer.fd_config.model_config.model,
+                )
                 up_gate_proj_in_scales_all_experts.append(1 / scale_tensor)
             getattr(layer, "up_gate_proj_in_scale_all_experts").set_value(
                 paddle.concat(up_gate_proj_in_scales_all_experts)
             )
 
-        for local_expert_idx in range(layer.num_local_experts):
-            expert_idx = local_expert_idx + layer.expert_id_offset
+        for expert_idx in logical_expert_ids:
             for name, scale_key_template in scale_key_map.items():
-                scale_tensor = _extract_scale_tensor(state_dict, scale_key_template, expert_idx)
+                scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
                 scale_weight_map[name].append(scale_tensor)
 
         # 3. Process scale tensor and set to layer
@@ -845,7 +912,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
         self.moe_quant_type = self.quant_config.algo
         self.pack_num = 1
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
@@ -855,9 +922,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
         down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
 
         up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
-            state_dict,
-            up_gate_proj_expert_weight_key,
-            down_proj_expert_weight_key,
+            state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
         )
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
@@ -1065,7 +1130,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
         """
         Paddle cutlass load weight process.
         """
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         self.check(layer, up_gate_proj_weights, down_proj_weights)
         for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
             weight_name = self.added_weight_attrs[idx]
@@ -99,7 +99,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         """
         deepgemm create weight process.
         """
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
 
         self.check(layer, up_gate_proj_weights, down_proj_weights)
 
@@ -124,7 +124,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous()
             getattr(layer, scale_name).set_value(quanted_weight_scale)
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
@@ -134,9 +134,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
 
         up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
-            state_dict,
-            up_gate_proj_expert_weight_key,
-            down_proj_expert_weight_key,
+            state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
        )
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
@@ -197,7 +197,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
         """
         Marlin MoE load weight process.
         """
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         assert len(up_gate_proj_weights) == layer.num_local_experts
         assert len(down_proj_weights) == layer.num_local_experts
         assert up_gate_proj_weights[0].shape == [
@@ -48,7 +48,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             "down_proj_weight_scale",
         ]
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
         """process_prequanted_weights"""
         pass
 
@@ -112,7 +112,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         """
         Triton MoE load weight process.
         """
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         assert len(up_gate_proj_weights) == layer.num_local_experts
         assert len(down_proj_weights) == layer.num_local_experts
 
@@ -311,7 +311,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             "down_proj_in_scale",
         ]
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
         """process_prequanted_weights"""
 
         up_gate_proj_tensor, down_proj_tensor = layer.extract_moe_ffn_weights(state_dict)
@@ -595,7 +595,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             "down_proj_weight_scale",
         ]
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
         """process_prequanted_weights"""
 
         raise NotImplementedError
@@ -667,7 +667,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         """
         Triton MoE create weight process.
         """
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
 
         self.check(layer, up_gate_proj_weights, down_proj_weights)
         for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
@@ -73,7 +73,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
         """
         pass
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
@@ -34,7 +34,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod):
 
     def process_loaded_weights(self, layer: nn.Layer, state_dict):
 
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         for weights in [up_gate_proj_weights, down_proj_weights]:
             for idx, weight in enumerate(weights):
                 weights[idx] = weight.transpose([1, 0])
@@ -119,7 +119,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
         """
         Paddle cutlass create weight process.
         """
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         assert len(up_gate_proj_weights) == layer.num_local_experts
         assert len(down_proj_weights) == layer.num_local_experts
         assert up_gate_proj_weights[0].shape == [
@@ -80,6 +80,7 @@ class FusedMoE(nn.Layer):
         layer_idx: int = -1,
         moe_tag: str = "",
         gate_correction_bias=None,
+        redundant_table_manger: RedundantExpertManger = None,
         weight_key_map: dict = {},
     ):
         """
@@ -147,15 +148,8 @@ class FusedMoE(nn.Layer):
             self.moe_quant_type = moe_quant_config.name()
         else:
             self.quant_method = get_moe_method()
-        self.redundant_table_manger = None
+        self.redundant_table_manger = redundant_table_manger
         if self.ep_size > 1:
-            if fd_config.model_config.enable_redundant_experts is True:
-                self.redundant_table_manger = RedundantExpertManger(
-                    n_routed_experts=fd_config.model_config.moe_num_experts,
-                    num_hidden_layers=fd_config.model_config.num_hidden_layers,
-                    redundant_experts_num=fd_config.model_config.redundant_experts_num,
-                    ep_size=self.ep_size,
-                )
             self.quant_method.init_ep(self)
 
         if fd_config.load_config.dynamic_load_weight:
@@ -423,6 +417,7 @@ class FusedMoE(nn.Layer):
         state_dict: dict,
         up_gate_proj_expert_weight_key: str,
         down_proj_expert_weight_key: str,
+        is_rearrange: bool = False,
     ):
         """
         Load experts weight from state_dict.
@@ -451,7 +446,12 @@ class FusedMoE(nn.Layer):
         ]
         up_gate_proj_weights = []
         down_proj_weights = []
-        is_ffn_merged = up_gate_proj_expert_weight_key.format(self.expert_id_offset) in state_dict
+        if isinstance(state_dict, list):
+            state_dict = dict(state_dict)
+        is_ffn_merged = (
+            up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else self.expert_id_offset)
+            in state_dict
+        )
         if is_ffn_merged:
             for expert_idx in logical_expert_ids:
                 down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
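Note: load_experts_weight now also accepts a state_dict given as a list of (name, tensor) pairs and normalizes it to a dict before key lookups, presumably so rearranged weights can be handed over in that form. A tiny illustrative sketch with made-up names and values:

items = [("experts.0.up_gate_proj.weight", 1.0), ("experts.0.down_proj.weight", 2.0)]
state_dict = dict(items) if isinstance(items, list) else items
assert "experts.0.up_gate_proj.weight" in state_dict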
@@ -533,11 +533,13 @@ class FusedMoE(nn.Layer):
         assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
         assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
 
-        up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = self.load_experts_weight(
+        up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
+            self.load_experts_weight(
                 state_dict,
                 up_gate_proj_expert_weight_key,
                 down_proj_expert_weight_key,
             )
+        )
         assert (
             len(up_gate_proj_weights) == self.num_local_experts
         ), "up_gate_proj_weights length should be equal to num_local_experts."
@@ -545,7 +547,7 @@ class FusedMoE(nn.Layer):
             len(down_proj_weights) == self.num_local_experts
         ), "down_proj_weights length should be equal to num_local_experts."
 
-        return up_gate_proj_weights, down_proj_weights
+        return up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list
 
     def extract_gate_correction_bias(self, gate_correction_bias_key, state_dict):
         """
@@ -561,7 +563,7 @@ class FusedMoE(nn.Layer):
         if is_supported_moe_backend is not None and is_supported_moe_backend(self.quant_method):
             if self.fd_config.model_config.is_quantized:
                 if getattr(self.fd_config.quant_config, "is_permuted", True):
-                    self.quant_method.process_prequanted_weights(self, state_dict)
+                    self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
                 else:
                     self.quant_method.process_loaded_weights(self, state_dict)
             else:
@@ -569,7 +571,7 @@ class FusedMoE(nn.Layer):
         else:
             if self.fd_config.model_config.is_quantized:
                 if getattr(self.fd_config.quant_config, "is_permuted", True):
-                    self.quant_method.process_prequanted_weights(self, state_dict)
+                    self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
                 else:
                     self.quant_method.create_weights(self, state_dict)
             else:
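Note: in both branches the new is_rearrange flag is forwarded into process_prequanted_weights, so a quantized backend can tell a first-time load from an EPLB weight refresh. A toy sketch of that call chain follows; the classes are stand-ins for FusedMoE and its quant method, and only the flag forwarding mirrors the hunks above:

class ToyQuantMethod:
    def process_prequanted_weights(self, layer, state_dict, is_rearrange=False):
        # A real backend would requantize/repack here; the flag tells it whether
        # this is a refresh of an already-initialized layer.
        print("prequanted load, rearranged:", is_rearrange)


class ToyFusedMoE:
    def __init__(self):
        self.quant_method = ToyQuantMethod()

    def load_state_dict(self, state_dict, is_rearrange=False):
        self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)


ToyFusedMoE().load_state_dict({}, is_rearrange=True)  # EPLB weight refresh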
@@ -108,7 +108,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
         layer.weight.copy_(quanted_weight_tensor, False)
         layer.weight_scale.set_value(weight_block_scale_tensor)
 
-    def process_prequanted_weights(self, layer, state_dict):
+    def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False):
         """
         process_prequanted_weights
         """
@@ -90,7 +90,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
             default_initializer=paddle.nn.initializer.Constant(0),
         )
 
-    def process_prequanted_weights(self, layer, state_dict) -> None:
+    def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False) -> None:
         """
         Process pre-quantized weights before applying them to the model
         Args:
@@ -305,7 +305,7 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
     ) -> None:
         super().__init__(quant_config)
 
-    def process_prequanted_weights(self, layer, state_dict) -> None:
+    def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False) -> None:
         """
         Process pre-quantized weights before applying them to the model
         Args:
@@ -127,7 +127,11 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
             num_local_ffn_keys.append(down_proj_in_scale_key)
 
         # for EP w4a8, we need all expert's activation_scale for up_gate_proj
-        for j in range(fd_config.model_config.moe_num_experts):
+        num_experts = fd_config.model_config.moe_num_experts
+        if isinstance(num_experts, list):
+            num_experts = num_experts[0]
+
+        for j in range(num_experts):
             up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
             num_local_ffn_keys.append(up_gate_proj_in_scale_key)
 
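Note: moe_num_experts may be a plain int or, in some configs, a list; the loop above only needs a single count, so the first entry is used. A minimal sketch with illustrative values:

moe_num_experts = [64, 64]  # could equally be a plain 64
num_experts = moe_num_experts
if isinstance(num_experts, list):
    num_experts = num_experts[0]
assert num_experts == 64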
@@ -49,6 +49,7 @@ from fastdeploy.model_executor.models.model_base import ModelForCasualLM
 from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
 from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
 from fastdeploy.model_executor.models.utils import WeightMeta
+from fastdeploy.worker.experts_manager import RedundantExpertManger
 
 
 class Ernie4_5_MLP(nn.Layer):
@@ -97,7 +98,9 @@ class Ernie4_5_MLP(nn.Layer):
 
 
 class Ernie4_5_MoE(nn.Layer):
-    def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None:
+    def __init__(
+        self, fd_config: FDConfig, layer_id: int, prefix: str, redundant_table_manger: RedundantExpertManger = None
+    ) -> None:
         super().__init__()
         moe_quant_type = ""
         if hasattr(fd_config.quant_config, "moe_quant_type"):
@@ -175,6 +178,7 @@ class Ernie4_5_MoE(nn.Layer):
             top_k=fd_config.model_config.moe_k,
             layer_idx=layer_id,
             gate_correction_bias=None,
+            redundant_table_manger=redundant_table_manger,
             weight_key_map=weight_key_map,
         )
 
@@ -209,6 +213,9 @@ class Ernie4_5_MoE(nn.Layer):
         if self.num_shared_experts > 0:
             self.shared_experts.load_state_dict(state_dict)
 
+    def update_state_dict(self, state_dict):
+        self.fused_moe.load_state_dict(state_dict, True)
+
     def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int):
         token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size
         # AllGather will hang when the data shapes on multi-ranks are different!
@@ -287,6 +294,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
     def __init__(
         self,
         fd_config: FDConfig,
+        redundant_table_manger: RedundantExpertManger = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -305,6 +313,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
             self.mlp = Ernie4_5_MoE(
                 fd_config=fd_config,
                 layer_id=layer_id,
+                redundant_table_manger=redundant_table_manger,
                 prefix=f"{prefix}.mlp",
             )
         else:
@@ -334,6 +343,9 @@ class Ernie4_5_DecoderLayer(nn.Layer):
         self.input_layernorm.load_state_dict(state_dict)
         self.post_attention_layernorm.load_state_dict(state_dict)
 
+    def update_state_dict(self, state_dict):
+        self.mlp.update_state_dict(state_dict)
+
     def forward(
         self,
         forward_meta: ForwardMeta,
@@ -374,6 +386,15 @@ class Ernie4_5_Model(nn.Layer):
 
         self.num_layers = fd_config.model_config.num_hidden_layers
         fd_config.model_config.pretrained_config.prefix_name = "ernie"
+        self.fd_config = fd_config
+        self.redundant_table_manger = None
+        if fd_config.model_config.enable_redundant_experts is True:
+            self.redundant_table_manger = RedundantExpertManger(
+                n_routed_experts=fd_config.model_config.moe_num_experts,
+                num_hidden_layers=fd_config.model_config.num_hidden_layers,
+                redundant_experts_num=fd_config.model_config.redundant_experts_num,
+                ep_size=fd_config.parallel_config.expert_parallel_size,
+            )
 
         self.embed_tokens = VocabParallelEmbedding(
             fd_config=fd_config,
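Note: ownership of RedundantExpertManger moves here. Ernie4_5_Model now builds a single manager when enable_redundant_experts is set and threads it down through Ernie4_5_DecoderLayer and Ernie4_5_MoE into each FusedMoE, which previously built its own manager inside __init__ when ep_size > 1 (see the FusedMoE hunk earlier). A toy sketch of the wiring direction only; all classes and values below are stand-ins, not FastDeploy APIs:

class ToyManager:
    def __init__(self, n_routed_experts, num_hidden_layers, redundant_experts_num, ep_size):
        self.shape = (n_routed_experts, num_hidden_layers, redundant_experts_num, ep_size)


class ToyMoELayer:
    def __init__(self, redundant_table_manger=None):
        self.redundant_table_manger = redundant_table_manger


class ToyModel:
    def __init__(self, enable_redundant_experts):
        # Built once at the model level ...
        self.redundant_table_manger = (
            ToyManager(64, 28, 8, ep_size=8) if enable_redundant_experts else None
        )
        # ... and shared by every MoE layer instead of being re-created per layer.
        self.layers = [ToyMoELayer(self.redundant_table_manger) for _ in range(28)]


model = ToyModel(enable_redundant_experts=True)
assert model.layers[0].redundant_table_manger is model.redundant_table_manger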
@@ -387,6 +408,7 @@ class Ernie4_5_Model(nn.Layer):
             [
                 Ernie4_5_DecoderLayer(
                     fd_config=fd_config,
+                    redundant_table_manger=self.redundant_table_manger,
                     prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}",
                 )
                 for i in range(self.num_layers)
@@ -415,6 +437,22 @@ class Ernie4_5_Model(nn.Layer):
             logger.info(f"Start load layer {i}")
             self.layers[i].load_state_dict(state_dict)
 
+    def update_state_dict(self, state_dict):
+        """
+        Update model parameters from a given state dictionary.
+
+        Args:
+            state_dict (dict[str, np.ndarray | paddle.Tensor]):
+                A dictionary containing model parameters, where keys are parameter names
+                and values are NumPy arrays or PaddlePaddle tensors.
+        """
+        for i in range(
+            self.fd_config.model_config.moe_layer_start_index,
+            self.fd_config.model_config.num_hidden_layers,
+        ):
+            logger.info(f"Start update layer {i}")
+            self.layers[i].update_state_dict(state_dict)
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
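Note: update_state_dict is the entry point for applying a rearranged expert table: it walks only the MoE layers (from moe_layer_start_index to num_hidden_layers), and each layer forwards the new weights down to FusedMoE.load_state_dict with is_rearrange set to True. A toy usage sketch; the classes below are stand-ins and only the layer-range walk mirrors the code above:

class ToyLayer:
    def update_state_dict(self, state_dict):
        print("refreshed", len(state_dict), "tensors")


class ToyModel:
    def __init__(self):
        self.layers = [ToyLayer() for _ in range(4)]
        self.moe_layer_start_index = 1  # dense layers before this index are skipped

    def update_state_dict(self, state_dict):
        for i in range(self.moe_layer_start_index, len(self.layers)):
            self.layers[i].update_state_dict(state_dict)


ToyModel().update_state_dict({"experts.0.up_gate_proj.weight": 0.0})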
@@ -86,8 +86,8 @@ class Ernie4_5_VLMoeBlock(nn.Layer):
     ) -> None:
         super().__init__()
         moe_quant_type = ""
-        if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
-            moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
+        if hasattr(fd_config.quant_config, "moe_quant_type"):
+            moe_quant_type = fd_config.quant_config.moe_quant_type
 
         if moe_quant_type == "tensor_wise_fp8" or (
             moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized