load hadamard_block_size from config (#3797)

Author: Yuan Xiaolan
Date: 2025-09-05 17:07:58 +08:00
Committed by: GitHub
Parent: 41aee08982
Commit: 2cf55168ca

10 changed files with 60 additions and 30 deletions


@@ -127,6 +127,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
             self.moe_quant_type,
             used_in_ep_low_latency,
             estimate_total_token_nums,
+            getattr(layer.moe_quant_config, "hadamard_block_size", 128),
         )
 
     def apply_ep_prefill(
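
The getattr fallback above is what keeps the new argument backward compatible: quant configs created before this change carry no hadamard_block_size attribute, so the lookup falls back to the previously hard-coded 128. A minimal, self-contained sketch of that behavior (the stub configs below are hypothetical, not from the repository):

    from types import SimpleNamespace

    old_cfg = SimpleNamespace()                        # pre-change config: attribute absent
    new_cfg = SimpleNamespace(hadamard_block_size=64)  # config that overrides the default

    print(getattr(old_cfg, "hadamard_block_size", 128))  # 128 (fallback)
    print(getattr(new_cfg, "hadamard_block_size", 128))  # 64 (from config)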


@@ -38,6 +38,7 @@ class MixQuantConfig(QuantConfigBase):
         has_zero_point: bool = False,
         is_permuted: bool = True,
         is_checkpoint_bf16: bool = False,
+        hadamard_block_size: int = 128,
     ) -> None:
         super().__init__()
         self.dense_quant_type = dense_quant_type
@@ -54,6 +55,7 @@ class MixQuantConfig(QuantConfigBase):
         self.quant_round_type = 0
         self.is_permuted = is_permuted
         self.is_checkpoint_bf16 = is_checkpoint_bf16
+        self.hadamard_block_size = hadamard_block_size
 
     def name(self) -> str:
         return "mix_quant"
@@ -69,6 +71,7 @@ class MixQuantConfig(QuantConfigBase):
             config.get("has_zero_point", False),
             config.get("is_permuted", True),
             config.get("is_checkpoint_bf16", False),
+            config.get("hadamard_block_size", 128),
         )
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -76,13 +79,25 @@ class MixQuantConfig(QuantConfigBase):
             if layer.moe_tag == "Image":
                 return (
                     get_quantization_config(self.image_moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
+                    .from_config(
+                        {
+                            "is_permuted": self.is_permuted,
+                            "self.is_checkpoint_bf16": self.is_checkpoint_bf16,
+                            "hadamard_block_size": self.hadamard_block_size,
+                        }
+                    )
                     .get_quant_method(layer)
                 )
             else:
                 return (
                     get_quantization_config(self.moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
+                    .from_config(
+                        {
+                            "is_permuted": self.is_permuted,
+                            "self.is_checkpoint_bf16": self.is_checkpoint_bf16,
+                            "hadamard_block_size": self.hadamard_block_size,
+                        }
+                    )
                     .get_quant_method(layer)
                 )
         elif isinstance(layer, Attention):
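
The two expanded from_config calls above forward hadamard_block_size from MixQuantConfig into whichever per-branch MoE quant config is selected, so a value set at the mix-quant level no longer silently resets to 128 in the nested config. (The "self.is_checkpoint_bf16" dict key is the literal string carried over unchanged from the old one-line call; the from_config implementations shown in this commit read only "is_permuted" and "hadamard_block_size".) A condensed, self-contained sketch of the forwarding pattern, with hypothetical stub classes in place of the real ones:

    class MoEQuantStub:
        # Stands in for the config returned by get_quantization_config().
        @classmethod
        def from_config(cls, config: dict) -> "MoEQuantStub":
            obj = cls()
            obj.hadamard_block_size = config.get("hadamard_block_size", 128)
            return obj

    class MixStub:
        # Stands in for MixQuantConfig; the real class takes more parameters.
        def __init__(self, hadamard_block_size: int = 128) -> None:
            self.hadamard_block_size = hadamard_block_size

        def build_moe_config(self) -> MoEQuantStub:
            # Forward the value so the nested config does not fall back to 128.
            return MoEQuantStub.from_config({"hadamard_block_size": self.hadamard_block_size})

    assert MixStub(64).build_moe_config().hadamard_block_size == 64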


@@ -25,9 +25,10 @@ class W4A8Config(QuantConfigBase):
     quantization config for weight 4bits and activation 8bits
     """
 
-    def __init__(self, is_permuted) -> None:
+    def __init__(self, is_permuted, hadamard_block_size) -> None:
         super().__init__()
         self.is_permuted = is_permuted
+        self.hadamard_block_size = hadamard_block_size
 
     def name(self) -> str:
         return "w4a8"
@@ -35,7 +36,8 @@ class W4A8Config(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "W4A8Config":
         is_permuted = config.get("is_permuted", True)
-        return cls(is_permuted)
+        hadamard_block_size = config.get("hadamard_block_size", 128)
+        return cls(is_permuted, hadamard_block_size)
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if isinstance(layer, FusedMoE):
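
With the new parameter threaded through, a caller-supplied dict can now override the block size while older dicts keep the previous behavior. A usage sketch (assuming W4A8Config is imported as in the repository; only the two keys read above are passed):

    cfg = W4A8Config.from_config({"is_permuted": True, "hadamard_block_size": 64})
    assert cfg.hadamard_block_size == 64

    legacy = W4A8Config.from_config({})  # no keys present: defaults apply
    assert legacy.is_permuted is True
    assert legacy.hadamard_block_size == 128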


@@ -31,7 +31,7 @@ class W4AFP8Config(QuantConfigBase):
     quantization config for weight 4bits and activation fp8
     """
 
-    def __init__(self, weight_scale_dict, act_scale_dict, is_permuted) -> None:
+    def __init__(self, weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size) -> None:
         super().__init__()
         self.weight_scale_dict = weight_scale_dict
         self.act_scale_dict = act_scale_dict
@@ -39,6 +39,7 @@ class W4AFP8Config(QuantConfigBase):
         self.quant_min_bound = -448
         self.quant_round_type = 1
         self.is_permuted = is_permuted
+        self.hadamard_block_size = hadamard_block_size
 
     def name(self) -> str:
         return "w4afp8"
@@ -48,7 +49,8 @@ class W4AFP8Config(QuantConfigBase):
         weight_scale_dict = config.get("weight_scale_dict", None)
         act_scale_dict = config.get("act_scale_dict", None)
         is_permuted = config.get("is_permuted", True)
-        return cls(weight_scale_dict, act_scale_dict, is_permuted)
+        hadamard_block_size = config.get("hadamard_block_size", 128)
+        return cls(weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size)
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if isinstance(layer, FusedMoE):
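
Putting the pieces together: a serialized quant config supplies hadamard_block_size, from_config stores it on the config object, and CutlassMoEMethod later reads it back with the getattr shown in the first hunk. An end-to-end sketch under those assumptions (the dict below is illustrative; real checkpoints carry more fields):

    quant_cfg = W4AFP8Config.from_config(
        {
            "weight_scale_dict": None,
            "act_scale_dict": None,
            "is_permuted": True,
            "hadamard_block_size": 64,
        }
    )

    # ...later, inside the MoE method, the same value reaches the kernel call:
    block_size = getattr(quant_cfg, "hadamard_block_size", 128)
    assert block_size == 64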