support w4afp8 moe offline permute & load (#5613)

This commit is contained in:
Sunny-bot1
2025-12-22 15:12:57 +08:00
committed by GitHub
parent 81384ef29e
commit 40f3897a4e
3 changed files with 70 additions and 38 deletions

View File

@@ -772,8 +772,9 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
if not layer.moe_quant_config.moe_dynamic_quant:
up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.load_experts_weight(
@@ -793,7 +794,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
if isinstance(state_dict, list):
state_dict = dict(state_dict)
if layer.ep_size > 1:
if layer.ep_size > 1 and not layer.moe_quant_config.moe_dynamic_quant:
for expert_idx in ep_rank_to_expert_id_list:
scale_tensor = get_tensor(
(
@@ -826,44 +827,54 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
layer.fd_config.model_config.model,
)
)
up_gate_proj_in_scale.append(
get_tensor(
(
state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
else up_gate_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
if not layer.moe_quant_config.moe_dynamic_quant:
up_gate_proj_in_scale.append(
get_tensor(
(
state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
else up_gate_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
)
down_proj_in_scale.append(
get_tensor(
(
state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
else down_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
down_proj_in_scale.append(
get_tensor(
(
state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
else down_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
)
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).squeeze()
up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).squeeze()
down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).squeeze()
if not layer.moe_quant_config.moe_dynamic_quant:
up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).squeeze()
up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).squeeze()
down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).squeeze()
name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale,
"up_gate_proj_in_scale_all_experts": up_gate_proj_in_scale_all_experts,
"up_gate_proj_in_scale": up_gate_proj_in_scale,
"down_proj_in_scale": down_proj_in_scale,
}
if not layer.moe_quant_config.moe_dynamic_quant:
name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale,
"up_gate_proj_in_scale_all_experts": up_gate_proj_in_scale_all_experts,
"up_gate_proj_in_scale": up_gate_proj_in_scale,
"down_proj_in_scale": down_proj_in_scale,
}
else:
name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale,
}
for name, tensor in name_tensor_map.items():
getattr(layer, name).set_value(tensor)
@@ -1020,11 +1031,27 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
# weight_scales
if layer.is_quantized:
if not layer.moe_quant_config.moe_dynamic_quant:
up_gate_proj_weight_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
down_proj_weight_scale_shape = [layer.num_local_experts, layer.hidden_size]
else:
up_gate_proj_weight_scale_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2 // 128,
layer.hidden_size // 128,
128,
]
down_proj_weight_scale_shape = [
layer.num_local_experts,
layer.hidden_size // 128,
layer.moe_intermediate_size // 128,
128,
]
setattr(
layer,
"up_gate_proj_weight_scale",
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
shape=up_gate_proj_weight_scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
@@ -1033,7 +1060,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
layer,
"down_proj_weight_scale",
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
shape=down_proj_weight_scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),

View File

@@ -212,8 +212,10 @@ class FusedMoE(nn.Layer):
self._dtype = self._helper.get_default_dtype()
self.weight_dtype = self._dtype
self.is_quantized = fd_config.model_config.is_quantized and not (
fd_config.quant_config.name() == "mix_quant" and fd_config.quant_config.moe_quant_type is None
self.is_moe_quantized = getattr(self.fd_config.model_config, "is_moe_quantized", False)
self.is_quantized = self.is_moe_quantized or (
fd_config.model_config.is_quantized
and not (fd_config.quant_config.name() == "mix_quant" and fd_config.quant_config.moe_quant_type is None)
)
moe_quant_config = fd_config.quant_config
self.moe_quant_config = moe_quant_config

View File

@@ -40,6 +40,7 @@ class MixQuantConfig(QuantConfigBase):
is_quantized: bool = False,
hadamard_block_size: int = 128,
moe_dynamic_quant: bool = False,
is_moe_quantized: bool = False,
) -> None:
super().__init__()
self.dense_quant_type = dense_quant_type
@@ -59,6 +60,7 @@ class MixQuantConfig(QuantConfigBase):
self.is_quantized = is_quantized
self.hadamard_block_size = hadamard_block_size
self.moe_dynamic_quant = moe_dynamic_quant
self.is_moe_quantized = is_moe_quantized
def name(self) -> str:
return "mix_quant"
@@ -76,6 +78,7 @@ class MixQuantConfig(QuantConfigBase):
config.get("is_quantized", False),
config.get("hadamard_block_size", 128),
config.get("moe_dynamic_quant", False),
config.get("is_moe_quantized", False),
)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -102,7 +105,7 @@ class MixQuantConfig(QuantConfigBase):
.from_config(
{
"is_permuted": self.is_permuted,
"is_quantized": not self.is_checkpoint_bf16,
"is_quantized": not self.is_checkpoint_bf16 or self.is_moe_quantized,
"hadamard_block_size": self.hadamard_block_size,
}
)