【New Feature】W4afp8 supports per group quantization (#4987)

* w4afp8 支持per group

* code style

* fix transpose

* revert fast hardmard

---------

Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
Co-authored-by: plusNew001 <95567040+plusNew001@users.noreply.github.com>
This commit is contained in:
yangjianfengo1
2025-11-13 19:17:27 +08:00
committed by GitHub
parent a5e949d9d0
commit ae7bee8122
21 changed files with 3114 additions and 2248 deletions

View File

@@ -275,6 +275,7 @@ class DeepEPEngine:
topk_idx: paddle.Tensor,
expertwise_scale,
use_fp8: bool = False,
quant_group_size: int = 128,
):
if self.deepep_engine is None:
raise RuntimeError("DeepEP buffer not initialized!")
@@ -294,6 +295,7 @@ class DeepEPEngine:
use_fp8=use_fp8,
async_finish=False,
return_recv_hook=True,
num_per_channel=quant_group_size,
)
return packed_recv_x, recv_expert_count, handle, dispatch_hook

View File

@@ -30,7 +30,10 @@ if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
try:
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
from fastdeploy.model_executor.ops.gpu import (
w4afp8_gemm_scale_permute,
w4afp8_gemm_weight_convert,
)
except:
logger.warning("import w4afp8_gemm_scale_permute Failed!")
elif current_platform.is_iluvatar():
@@ -81,6 +84,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
expert_idx_per_token: paddle.Tensor,
used_in_ep_low_latency: bool = False,
estimate_total_token_nums: int = -1,
dequant_scale: paddle.Tensor = None,
):
"""
Paddle Cutlass compute Fused MoE.
@@ -106,7 +110,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
token_nums_per_expert,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
# None,
dequant_scale,
(layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
@@ -118,6 +122,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
getattr(layer.moe_quant_config, "hadamard_block_size", 128),
layer.activation,
)
if layer.with_bias:
down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0)
ffn_out_without_down_proj_bias = paddle.add(ffn_out_without_down_proj_bias, down_proj_bias_expand)
@@ -157,6 +162,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
dst_indices,
cumsum_idx_gpu,
expert_idx_per_token,
dequant_scale,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
recv_x,
recv_topk_idx,
@@ -173,11 +179,17 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")
if hasattr(layer, "up_gate_proj_in_scale"):
dequant_scale = None
ffn_out = self.compute_ffn(
layer,
permute_input,
recv_num_tokens_per_expert_list_cumsum,
expert_idx_per_token,
False,
-1,
dequant_scale,
)
# prmt back per rank
@@ -213,10 +225,19 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
if hasattr(layer, "up_gate_proj_in_scale_all_experts"): # only use in w4a8
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
use_fp8 = self.moe_quant_type == "w4afp8"
quant_group_size = -1 if self.moe_quant_type == "w4afp8" else 128
# 2. EP Dispatch
permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
x, topk_idx, topk_weights, expertwise_scale=expertwise_scale, use_fp8=use_fp8
x,
topk_idx,
topk_weights,
expertwise_scale=expertwise_scale,
use_fp8=use_fp8,
quant_group_size=quant_group_size,
)
dequant_scale = None
if self.moe_quant_type == "w4afp8" and expertwise_scale is None:
(permute_input, dequant_scale) = permute_input
# 3. Compute ffn
if self.moe_quant_type == "w4a8" or self.moe_quant_type == "w4afp8":
num_local_experts, max_num, _ = permute_input.shape
@@ -233,6 +254,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
expert_idx_per_token,
True,
estimate_total_token_nums,
dequant_scale,
)
# 4. EP combine
@@ -266,6 +288,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
topk_weights,
topk_idx,
expert_idx_per_token,
dequant_scale,
) = moe_expert_dispatch(
x,
gate_out,
@@ -286,19 +309,21 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
topk_weights,
topk_idx,
expert_idx_per_token,
dequant_scale,
) = moe_expert_dispatch(
x,
gate_out,
layer.gate_correction_bias,
(
layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None
), # if set, permute_input will be int8_t
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
layer.top_k,
False,
self.moe_quant_type,
topk_only_mode=False,
)
if hasattr(layer, "up_gate_proj_in_scale"):
dequant_scale = None
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
# only w4a8 need expert_idx_per_token
# Other need not this tensor, so we make it None.
@@ -306,7 +331,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")
ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, expert_idx_per_token)
ffn_out = self.compute_ffn(
layer, permute_input, token_nums_per_expert, expert_idx_per_token, False, -1, dequant_scale
)
# reduce 中会做 topk 个 weight 的 norm 和 routed_scaling_factor
fused_moe_out = moe_expert_reduce(
@@ -851,7 +878,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
weight_name = self.added_weight_attrs[idx]
weight_list = []
for i in range(layer.num_local_experts):
quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80)
quant_weight = w4afp8_gemm_weight_convert(weight_tensor[i])
weight_list.append(quant_weight)
quanted_weight = paddle.stack(weight_list, axis=0)
getattr(layer, weight_name).set_value(quanted_weight)
@@ -869,7 +896,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
"""
self.default_dtype = layer._helper.get_default_dtype()
if layer.ep_size > 1:
if layer.ep_size > 1 and not layer.moe_quant_config.moe_dynamic_quant:
setattr(
layer,
"up_gate_proj_in_scale_all_experts",
@@ -881,16 +908,17 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
)
# in_scales
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
setattr(
layer,
in_scale_name,
layer.create_parameter(
shape=[layer.num_local_experts],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
if not layer.moe_quant_config.moe_dynamic_quant:
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
setattr(
layer,
in_scale_name,
layer.create_parameter(
shape=[layer.num_local_experts],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scales
setattr(
@@ -948,10 +976,57 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
return weight_scale
def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], processed_in_scale: paddle.Tensor):
processed_weight_scale = (
paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None]
)
processed_weight_scale = _permute_weight_scale(processed_weight_scale)
if processed_in_scale is not None:
processed_weight_scale = paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9))
if len(processed_weight_scale.shape) == 3:
processed_weight_scale = (
processed_weight_scale.transpose([0, 2, 1]) / processed_in_scale[:, None, None]
)
else:
processed_weight_scale = processed_weight_scale / processed_in_scale[:, None]
else:
processed_weight_scale = paddle.stack(weight_scales, axis=0) / (440 * 7 * 2 ** (-9))
if len(processed_weight_scale.shape) == 3:
if name == "up_gate_proj_weight_scale" and processed_weight_scale.shape[-1] * 128 != layer.hidden_size:
assert (
layer.hidden_size // 128 % processed_weight_scale.shape[-1] == 0
), "weight_scale_group_size must be a multiple of 128"
# If it is a multiple of 128, repeat to 128
processed_weight_scale = processed_weight_scale.repeat_interleave(
layer.hidden_size // 128 // processed_weight_scale.shape[-1], axis=-1
)
elif (
name == "down_proj_weight_scale"
and processed_weight_scale.shape[-1] * 128 != layer.moe_intermediate_size
):
assert (
layer.moe_intermediate_size // 128 % processed_weight_scale.shape[-1] == 0
), "weight_scale_group_size must be a multiple of 128"
# If it is a multiple of 128, repeat to 128
processed_weight_scale = processed_weight_scale.repeat_interleave(
layer.moe_intermediate_size // 128 // processed_weight_scale.shape[-1], axis=-1
)
origin_shape = processed_weight_scale.shape
processed_weight_scale = processed_weight_scale.transpose([0, 2, 1])
processed_weight_scale = processed_weight_scale.reshape([-1, processed_weight_scale.shape[-1]])
processed_weight_scale = _permute_weight_scale(processed_weight_scale)
processed_weight_scale = processed_weight_scale.reshape(
[origin_shape[0], origin_shape[2], origin_shape[1] // 128, 128]
)
processed_weight_scale = processed_weight_scale.transpose([0, 2, 1, 3])
setattr(
layer,
name,
layer.create_parameter(
shape=processed_weight_scale.shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
else:
processed_weight_scale = _permute_weight_scale(processed_weight_scale)
getattr(layer, name).set_value(processed_weight_scale)
# 1. Init scale containers and maps
@@ -978,7 +1053,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
raise ValueError(f"scale {name} should not be none in w4a8 mode.")
# 2. Extract scale tensor from state dict
if layer.ep_size > 1:
if layer.ep_size > 1 and not layer.moe_quant_config.moe_dynamic_quant:
for expert_idx in ep_rank_to_expert_id_list:
scale_tensor = get_tensor(
(
@@ -998,16 +1073,15 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
scale_weight_map[name].append(scale_tensor)
# 3. Process scale tensor and set to layer
in_scales = []
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
in_scales.append(_process_in_scale(in_scale_name, scale_weight_map[in_scale_name]))
for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
in_scale_name = weight_scale_name.replace("_weight_scale", "_in_scale")
in_scale = None
if hasattr(layer, in_scale_name) and in_scale_name in scale_weight_map.keys():
in_scale = _process_in_scale(in_scale_name, scale_weight_map[in_scale_name])
_process_weight_scale(
weight_scale_name,
scale_weight_map[weight_scale_name],
in_scales[i],
in_scale,
)

View File

@@ -275,6 +275,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
topk_weights,
topk_idx,
expert_idx_per_token,
dequant_scale,
) = moe_expert_dispatch(
x,
gate_out,

View File

@@ -39,6 +39,7 @@ class MixQuantConfig(QuantConfigBase):
is_permuted: bool = True,
is_quantized: bool = False,
hadamard_block_size: int = 128,
moe_dynamic_quant: bool = False,
) -> None:
super().__init__()
self.dense_quant_type = dense_quant_type
@@ -55,7 +56,9 @@ class MixQuantConfig(QuantConfigBase):
self.quant_round_type = 0
self.is_permuted = is_permuted
self.is_checkpoint_bf16 = not is_quantized
self.is_quantized = is_quantized
self.hadamard_block_size = hadamard_block_size
self.moe_dynamic_quant = moe_dynamic_quant
def name(self) -> str:
return "mix_quant"
@@ -72,6 +75,7 @@ class MixQuantConfig(QuantConfigBase):
config.get("is_permuted", True),
config.get("is_quantized", False),
config.get("hadamard_block_size", 128),
config.get("moe_dynamic_quant", False),
)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]: