Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
【New Feature】W4afp8 supports per group quantization (#4987)
* w4afp8: support per-group quantization
* code style
* fix transpose
* revert fast hadamard

---------

Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
Co-authored-by: plusNew001 <95567040+plusNew001@users.noreply.github.com>
@@ -275,6 +275,7 @@ class DeepEPEngine:
        topk_idx: paddle.Tensor,
        expertwise_scale,
        use_fp8: bool = False,
        quant_group_size: int = 128,
    ):
        if self.deepep_engine is None:
            raise RuntimeError("DeepEP buffer not initialized!")
@@ -294,6 +295,7 @@ class DeepEPEngine:
            use_fp8=use_fp8,
            async_finish=False,
            return_recv_hook=True,
            num_per_channel=quant_group_size,
        )

        return packed_recv_x, recv_expert_count, handle, dispatch_hook
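For readers unfamiliar with the term, per-group quantization here means that each group of `quant_group_size` contiguous channels (128 by default) shares one scale, rather than one scale per tensor or per channel. Below is a minimal, illustrative Paddle sketch of that idea. The helper name `per_group_fp8_quant`, keeping the quantized values in the input dtype, and the E4M3 maximum of 448.0 are assumptions for illustration; this is not the fused DeepEP low-latency kernel that `num_per_channel=quant_group_size` configures.

```python
import paddle

def per_group_fp8_quant(x: paddle.Tensor, group_size: int = 128):
    """Illustrative per-group quantization: one scale per group of `group_size` channels.

    x: [num_tokens, hidden_size] with hidden_size divisible by group_size.
    Returns (quantized values, per-group scales of shape [num_tokens, hidden_size // group_size]).
    """
    num_tokens, hidden_size = x.shape
    assert hidden_size % group_size == 0, "hidden_size must be divisible by group_size"
    grouped = x.reshape([num_tokens, hidden_size // group_size, group_size])
    # One scale per (token, group): map the group's max magnitude onto the FP8 E4M3 range.
    amax = grouped.abs().max(axis=-1, keepdim=True).clip(min=1e-8)
    scale = amax / 448.0  # 448 is the largest finite float8_e4m3 value
    q = (grouped / scale).clip(-448.0, 448.0)
    # A real kernel would store `q` as float8_e4m3; it is kept in the input dtype here.
    return q.reshape([num_tokens, hidden_size]), scale.squeeze(-1)
```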
@@ -30,7 +30,10 @@ if current_platform.is_cuda():
    from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce

    try:
        from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
        from fastdeploy.model_executor.ops.gpu import (
            w4afp8_gemm_scale_permute,
            w4afp8_gemm_weight_convert,
        )
    except:
        logger.warning("import w4afp8_gemm_scale_permute Failed!")
elif current_platform.is_iluvatar():
@@ -81,6 +84,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        expert_idx_per_token: paddle.Tensor,
        used_in_ep_low_latency: bool = False,
        estimate_total_token_nums: int = -1,
        dequant_scale: paddle.Tensor = None,
    ):
        """
        Paddle Cutlass compute Fused MoE.
@@ -106,7 +110,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
            token_nums_per_expert,
            getattr(layer, self.added_weight_attrs[0]),
            getattr(layer, self.added_weight_attrs[1]),
            # None,
            dequant_scale,
            (layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
            (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
            (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
@@ -118,6 +122,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
            getattr(layer.moe_quant_config, "hadamard_block_size", 128),
            layer.activation,
        )

        if layer.with_bias:
            down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0)
            ffn_out_without_down_proj_bias = paddle.add(ffn_out_without_down_proj_bias, down_proj_bias_expand)
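To make the role of `dequant_scale` concrete: for W4AFP8 the fused kernel effectively multiplies back in the per-group INT4 weight scale, the per-group activation scale produced by dynamic quantization (used when no static `up_gate_proj_in_scale` exists), and any bias. The reference below spells this arithmetic out for a single expert in plain Paddle. The shapes, the helper name, and the explicit separation into dequantize-then-matmul steps are illustrative assumptions; the Cutlass kernel fuses all of this into the GEMM.

```python
import paddle

def w4afp8_ffn_reference(x_q, x_scale, w_int4, w_scale, group_size=128):
    """Slow reference for one expert's GEMM under per-group quantization.

    x_q:     [tokens, hidden]               FP8 activations (stored as float here)
    x_scale: [tokens, hidden // group_size] per-group activation dequant scale
    w_int4:  [hidden, inter]                signed 4-bit weights (stored as int8 here)
    w_scale: [inter,  hidden // group_size] per-group weight scale
    """
    tokens, hidden = x_q.shape
    groups = hidden // group_size
    # Dequantize activations group by group along the hidden dimension.
    x = (x_q.reshape([tokens, groups, group_size]) * x_scale.unsqueeze(-1)).reshape([tokens, hidden])
    # Dequantize weights group by group along the same (reduction) dimension.
    w = w_int4.astype("float32").reshape([groups, group_size, -1]) * w_scale.transpose([1, 0]).unsqueeze(1)
    return paddle.matmul(x, w.reshape([hidden, -1]))


if __name__ == "__main__":
    t, h, inter, g = 4, 256, 512, 128
    x_q = paddle.rand([t, h])
    x_scale = paddle.rand([t, h // g])
    w_int4 = paddle.randint(-7, 8, [h, inter]).astype("int8")
    w_scale = paddle.rand([inter, h // g])
    print(w4afp8_ffn_reference(x_q, x_scale, w_int4, w_scale, g).shape)  # [4, 512]
```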
@@ -157,6 +162,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
            dst_indices,
            cumsum_idx_gpu,
            expert_idx_per_token,
            dequant_scale,
        ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
            recv_x,
            recv_topk_idx,
@@ -173,11 +179,17 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        else:
            expert_idx_per_token = expert_idx_per_token.cast("int64")

        if hasattr(layer, "up_gate_proj_in_scale"):
            dequant_scale = None

        ffn_out = self.compute_ffn(
            layer,
            permute_input,
            recv_num_tokens_per_expert_list_cumsum,
            expert_idx_per_token,
            False,
            -1,
            dequant_scale,
        )

        # prmt back per rank
@@ -213,10 +225,19 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        if hasattr(layer, "up_gate_proj_in_scale_all_experts"):  # only used in w4a8
            expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
        use_fp8 = self.moe_quant_type == "w4afp8"
        quant_group_size = -1 if self.moe_quant_type == "w4afp8" else 128
        # 2. EP Dispatch
        permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
            x, topk_idx, topk_weights, expertwise_scale=expertwise_scale, use_fp8=use_fp8
            x,
            topk_idx,
            topk_weights,
            expertwise_scale=expertwise_scale,
            use_fp8=use_fp8,
            quant_group_size=quant_group_size,
        )
        dequant_scale = None
        if self.moe_quant_type == "w4afp8" and expertwise_scale is None:
            (permute_input, dequant_scale) = permute_input
        # 3. Compute ffn
        if self.moe_quant_type == "w4a8" or self.moe_quant_type == "w4afp8":
            num_local_experts, max_num, _ = permute_input.shape
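Putting the decoder-path pieces above together: when `expertwise_scale` is absent and the quant type is `w4afp8`, the dispatcher quantizes activations on the fly and its first return value is itself a `(fp8_activations, per_group_scales)` pair, which the caller unpacks into `permute_input` and `dequant_scale`. A compact, self-contained restatement of that caller-side unpacking (placeholder values only):

```python
def unpack_decoder_dispatch(dispatch_out, moe_quant_type: str, expertwise_scale):
    """Caller-side unpacking of the EP decoder dispatch result.

    dispatch_out is either a plain tensor (static scales / no FP8) or a
    (fp8_activations, per_group_scales) pair under dynamic quantization.
    """
    dequant_scale = None
    permute_input = dispatch_out
    if moe_quant_type == "w4afp8" and expertwise_scale is None:
        # Dynamic quantization path: keep the scales so compute_ffn can dequantize later.
        permute_input, dequant_scale = dispatch_out
    return permute_input, dequant_scale


# Example: a dynamically quantized dispatch result is a (values, scales) pair.
permute_input, dequant_scale = unpack_decoder_dispatch(
    ("fp8_values_placeholder", "group_scales_placeholder"),
    moe_quant_type="w4afp8",
    expertwise_scale=None,
)
assert dequant_scale == "group_scales_placeholder"
```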
@@ -233,6 +254,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
                expert_idx_per_token,
                True,
                estimate_total_token_nums,
                dequant_scale,
            )

        # 4. EP combine
@@ -266,6 +288,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
            topk_weights,
            topk_idx,
            expert_idx_per_token,
            dequant_scale,
        ) = moe_expert_dispatch(
            x,
            gate_out,
@@ -286,19 +309,21 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
            topk_weights,
            topk_idx,
            expert_idx_per_token,
            dequant_scale,
        ) = moe_expert_dispatch(
            x,
            gate_out,
            layer.gate_correction_bias,
            (
                layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None
            ),  # if set, permute_input will be int8_t
            (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
            layer.top_k,
            False,
            self.moe_quant_type,
            topk_only_mode=False,
        )

        if hasattr(layer, "up_gate_proj_in_scale"):
            dequant_scale = None

        if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
            # only w4a8 needs expert_idx_per_token
            # Others do not need this tensor, so we set it to None.
@@ -306,7 +331,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        else:
            expert_idx_per_token = expert_idx_per_token.cast("int64")

        ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, expert_idx_per_token)
        ffn_out = self.compute_ffn(
            layer, permute_input, token_nums_per_expert, expert_idx_per_token, False, -1, dequant_scale
        )

        # The reduce step normalizes the top-k weights and applies routed_scaling_factor.
        fused_moe_out = moe_expert_reduce(
@@ -851,7 +878,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
            weight_name = self.added_weight_attrs[idx]
            weight_list = []
            for i in range(layer.num_local_experts):
                quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80)
                quant_weight = w4afp8_gemm_weight_convert(weight_tensor[i])
                weight_list.append(quant_weight)
            quanted_weight = paddle.stack(weight_list, axis=0)
            getattr(layer, weight_name).set_value(quanted_weight)
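The switch from `weight_quantize(..., algo=..., arch=80)` to `w4afp8_gemm_weight_convert` hands layout conversion of the INT4 expert weights to the dedicated W4AFP8 op. For intuition only, the sketch below shows the kind of symmetric per-group INT4 quantization that produces such weights and scales in the first place. The function name, the round-to-nearest scheme, and storing one 4-bit value per int8 are assumptions for illustration and do not describe the layout `w4afp8_gemm_weight_convert` actually emits.

```python
import paddle

def quantize_weight_int4_per_group(w: paddle.Tensor, group_size: int = 128):
    """Illustrative symmetric per-group INT4 weight quantization.

    w: [out_features, in_features]; each group of `group_size` input channels
    shares one scale. Returns (int8 tensor holding 4-bit values,
    per-group scales of shape [out_features, in_features // group_size]).
    """
    out_f, in_f = w.shape
    groups = in_f // group_size
    grouped = w.reshape([out_f, groups, group_size])
    # Symmetric scale so the group's max magnitude maps to the INT4 maximum (7).
    scale = grouped.abs().max(axis=-1, keepdim=True).clip(min=1e-8) / 7.0
    q = paddle.round(grouped / scale).clip(-7.0, 7.0).astype("int8")
    # Real kernels pack two 4-bit values per byte and permute the layout for the
    # GEMM; here each 4-bit value simply occupies one int8 slot.
    return q.reshape([out_f, in_f]), scale.squeeze(-1)
```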
@@ -869,7 +896,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
        """

        self.default_dtype = layer._helper.get_default_dtype()
        if layer.ep_size > 1:
        if layer.ep_size > 1 and not layer.moe_quant_config.moe_dynamic_quant:
            setattr(
                layer,
                "up_gate_proj_in_scale_all_experts",
@@ -881,16 +908,17 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
            )

        # in_scales
        for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
            setattr(
                layer,
                in_scale_name,
                layer.create_parameter(
                    shape=[layer.num_local_experts],
                    dtype="float32",
                    default_initializer=paddle.nn.initializer.Constant(0),
                ),
            )
        if not layer.moe_quant_config.moe_dynamic_quant:
            for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
                setattr(
                    layer,
                    in_scale_name,
                    layer.create_parameter(
                        shape=[layer.num_local_experts],
                        dtype="float32",
                        default_initializer=paddle.nn.initializer.Constant(0),
                    ),
                )

        # weight_scales
        setattr(
@@ -948,10 +976,57 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
            return weight_scale

        def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], processed_in_scale: paddle.Tensor):
            processed_weight_scale = (
                paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None]
            )
            processed_weight_scale = _permute_weight_scale(processed_weight_scale)
            if processed_in_scale is not None:
                processed_weight_scale = paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9))
                if len(processed_weight_scale.shape) == 3:
                    processed_weight_scale = (
                        processed_weight_scale.transpose([0, 2, 1]) / processed_in_scale[:, None, None]
                    )
                else:
                    processed_weight_scale = processed_weight_scale / processed_in_scale[:, None]
            else:
                processed_weight_scale = paddle.stack(weight_scales, axis=0) / (440 * 7 * 2 ** (-9))

            if len(processed_weight_scale.shape) == 3:
                if name == "up_gate_proj_weight_scale" and processed_weight_scale.shape[-1] * 128 != layer.hidden_size:
                    assert (
                        layer.hidden_size // 128 % processed_weight_scale.shape[-1] == 0
                    ), "weight_scale_group_size must be a multiple of 128"
                    # If it is a multiple of 128, repeat to 128
                    processed_weight_scale = processed_weight_scale.repeat_interleave(
                        layer.hidden_size // 128 // processed_weight_scale.shape[-1], axis=-1
                    )
                elif (
                    name == "down_proj_weight_scale"
                    and processed_weight_scale.shape[-1] * 128 != layer.moe_intermediate_size
                ):
                    assert (
                        layer.moe_intermediate_size // 128 % processed_weight_scale.shape[-1] == 0
                    ), "weight_scale_group_size must be a multiple of 128"
                    # If it is a multiple of 128, repeat to 128
                    processed_weight_scale = processed_weight_scale.repeat_interleave(
                        layer.moe_intermediate_size // 128 // processed_weight_scale.shape[-1], axis=-1
                    )

                origin_shape = processed_weight_scale.shape
                processed_weight_scale = processed_weight_scale.transpose([0, 2, 1])
                processed_weight_scale = processed_weight_scale.reshape([-1, processed_weight_scale.shape[-1]])
                processed_weight_scale = _permute_weight_scale(processed_weight_scale)
                processed_weight_scale = processed_weight_scale.reshape(
                    [origin_shape[0], origin_shape[2], origin_shape[1] // 128, 128]
                )
                processed_weight_scale = processed_weight_scale.transpose([0, 2, 1, 3])
                setattr(
                    layer,
                    name,
                    layer.create_parameter(
                        shape=processed_weight_scale.shape,
                        dtype="float32",
                        default_initializer=paddle.nn.initializer.Constant(0),
                    ),
                )
            else:
                processed_weight_scale = _permute_weight_scale(processed_weight_scale)
            getattr(layer, name).set_value(processed_weight_scale)

        # 1. Init scale containers and maps
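The shape bookkeeping in `_process_weight_scale` is easiest to see with numbers. If a checkpoint stores one scale per 256-channel group while the kernel expects 128-channel groups, the scale tensor is widened with `repeat_interleave` and then transposed and permuted into the layout the GEMM expects. The 448 and 7 divisors presumably correspond to the FP8 E4M3 maximum and the INT4 maximum, with 2**(-9) as a fixed extra shift, but that reading is inferred from the constants rather than documented here. A tiny self-contained demonstration of the group-width expansion step only, with hypothetical sizes:

```python
import paddle

# Hypothetical sizes: hidden_size = 1024, checkpoint group size = 256 -> 4 scale groups;
# the kernel expects 128-channel groups -> 8 scale groups.
num_experts, n, checkpoint_groups, hidden_size = 2, 16, 4, 1024
weight_scale = paddle.rand([num_experts, n, checkpoint_groups], dtype="float32")

target_groups = hidden_size // 128
assert target_groups % checkpoint_groups == 0, "checkpoint group size must be a multiple of 128"

# Each coarse scale is repeated so neighbouring 128-channel groups reuse the same value.
expanded = weight_scale.repeat_interleave(target_groups // checkpoint_groups, axis=-1)
print(expanded.shape)  # [2, 16, 8]
```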
@@ -978,7 +1053,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
                raise ValueError(f"scale {name} should not be none in w4a8 mode.")

        # 2. Extract scale tensor from state dict
        if layer.ep_size > 1:
        if layer.ep_size > 1 and not layer.moe_quant_config.moe_dynamic_quant:
            for expert_idx in ep_rank_to_expert_id_list:
                scale_tensor = get_tensor(
                    (
@@ -998,16 +1073,15 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
                scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
                scale_weight_map[name].append(scale_tensor)

        # 3. Process scale tensor and set to layer
        in_scales = []
        for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
            in_scales.append(_process_in_scale(in_scale_name, scale_weight_map[in_scale_name]))

        for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
            in_scale_name = weight_scale_name.replace("_weight_scale", "_in_scale")
            in_scale = None
            if hasattr(layer, in_scale_name) and in_scale_name in scale_weight_map.keys():
                in_scale = _process_in_scale(in_scale_name, scale_weight_map[in_scale_name])
            _process_weight_scale(
                weight_scale_name,
                scale_weight_map[weight_scale_name],
                in_scales[i],
                in_scale,
            )
@@ -275,6 +275,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
            topk_weights,
            topk_idx,
            expert_idx_per_token,
            dequant_scale,
        ) = moe_expert_dispatch(
            x,
            gate_out,
@@ -39,6 +39,7 @@ class MixQuantConfig(QuantConfigBase):
        is_permuted: bool = True,
        is_quantized: bool = False,
        hadamard_block_size: int = 128,
        moe_dynamic_quant: bool = False,
    ) -> None:
        super().__init__()
        self.dense_quant_type = dense_quant_type
@@ -55,7 +56,9 @@ class MixQuantConfig(QuantConfigBase):
        self.quant_round_type = 0
        self.is_permuted = is_permuted
        self.is_checkpoint_bf16 = not is_quantized
        self.is_quantized = is_quantized
        self.hadamard_block_size = hadamard_block_size
        self.moe_dynamic_quant = moe_dynamic_quant

    def name(self) -> str:
        return "mix_quant"
@@ -72,6 +75,7 @@ class MixQuantConfig(QuantConfigBase):
            config.get("is_permuted", True),
            config.get("is_quantized", False),
            config.get("hadamard_block_size", 128),
            config.get("moe_dynamic_quant", False),
        )

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
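With the new field wired through `MixQuantConfig.from_config`, enabling dynamic per-group quantization for the MoE path comes down to an extra key in the quantization config. A hedged example of what such a config section might look like; only the keys visible in the diff above (`is_permuted`, `is_quantized`, `hadamard_block_size`, `moe_dynamic_quant`, plus the `dense_quant_type` attribute) are certain, and the concrete quant-type values are illustrative rather than taken from a shipped model config:

```python
# Hypothetical quantization block of a model config consumed by MixQuantConfig.from_config.
quantization_config = {
    "dense_quant_type": "wfp8afp8",   # illustrative value
    "moe_quant_type": "w4afp8",       # illustrative; the MoE path this commit extends
    "is_permuted": True,
    "is_quantized": False,
    "hadamard_block_size": 128,
    "moe_dynamic_quant": True,        # new flag: quantize MoE activations per group at runtime
}
```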