support model loading for w4a8 offline quant (#3064)

支持W4A8 EP 对离线量化权重的load
This commit is contained in:
Yuan Xiaolan
2025-07-29 21:54:37 +08:00
committed by GitHub
parent be0a0f2bb2
commit 3214fb5393
4 changed files with 80 additions and 10 deletions

View File

@@ -332,6 +332,66 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
self.moe_quant_type = "w4a8" self.moe_quant_type = "w4a8"
self.pack_num = 2 self.pack_num = 2
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
"""
Paddle cutlass process prequanted weights.
"""
up_gate_proj_expert_weight_key = layer.weight_key_map.get("up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
up_gate_proj_in_scale_all_experts = []
up_gate_proj_in_scale = []
down_proj_in_scale = []
if layer.ep_size > 1:
for expert_idx in range(layer.num_experts):
scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
up_gate_proj_in_scale_all_experts.append(scale_tensor)
for expert_idx in logical_expert_ids:
up_gate_proj_weight_scale.append(
get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
)
down_proj_weight_scale.append(
get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
)
up_gate_proj_in_scale.append(
get_tensor(state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx)))
)
down_proj_in_scale.append(get_tensor(state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))))
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0)
up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0)
down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0)
name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale,
"up_gate_proj_in_scale_all_experts": up_gate_proj_in_scale_all_experts,
"up_gate_proj_in_scale": up_gate_proj_in_scale,
"down_proj_in_scale": down_proj_in_scale,
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
def create_weights(self, layer: nn.Layer, state_dict): def create_weights(self, layer: nn.Layer, state_dict):
""" """
Paddle cutlass create weight process. Paddle cutlass create weight process.

View File

@@ -19,9 +19,6 @@ from paddle import nn
from paddleformers.utils.log import logger from paddleformers.utils.log import logger
from fastdeploy import envs from fastdeploy import envs
from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
CutlassW4A8MoEMethod,
)
from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.worker.experts_manager import RedundantExpertManger from fastdeploy.worker.experts_manager import RedundantExpertManger
@@ -388,12 +385,12 @@ class FusedMoE(nn.Layer):
self.gate_weight.set_value(gate_weight_tensor.astype("float32")) self.gate_weight.set_value(gate_weight_tensor.astype("float32"))
if self.fd_config.model_config.is_quantized: if self.fd_config.model_config.is_quantized:
if isinstance(self.quant_method, CutlassW4A8MoEMethod): if getattr(self.fd_config.quant_config, "is_permuted", False):
self.quant_method.create_weights(self, state_dict)
else:
self.quant_method.process_prequanted_weights(self, state_dict) self.quant_method.process_prequanted_weights(self, state_dict)
else: else:
self.quant_method.create_weights(self, state_dict) self.quant_method.create_weights(self, state_dict)
else:
self.quant_method.create_weights(self, state_dict)
def forward(self, x: paddle.Tensor): def forward(self, x: paddle.Tensor):
""" """

View File

@@ -36,6 +36,7 @@ class MixQuantConfig(QuantConfigBase):
image_moe_quant_type: str = None, image_moe_quant_type: str = None,
is_channel_wise: bool = False, is_channel_wise: bool = False,
has_zero_point: bool = False, has_zero_point: bool = False,
is_permuted: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.dense_quant_type = dense_quant_type self.dense_quant_type = dense_quant_type
@@ -50,6 +51,7 @@ class MixQuantConfig(QuantConfigBase):
self.quant_max_bound = 0 self.quant_max_bound = 0
self.quant_min_bound = 0 self.quant_min_bound = 0
self.quant_round_type = 0 self.quant_round_type = 0
self.is_permuted = is_permuted
def name(self) -> str: def name(self) -> str:
return "mix_quant" return "mix_quant"
@@ -63,14 +65,23 @@ class MixQuantConfig(QuantConfigBase):
config.get("image_moe_quant_type", None), config.get("image_moe_quant_type", None),
config.get("is_channel_wise", False), config.get("is_channel_wise", False),
config.get("has_zero_point", False), config.get("has_zero_point", False),
config.get("is_permuted", False),
) )
def get_quant_method(self, layer) -> Optional[QuantMethodBase]: def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
if layer.moe_tag == "Image": if layer.moe_tag == "Image":
return get_quantization_config(self.image_moe_quant_type).from_config({}).get_quant_method(layer) return (
get_quantization_config(self.image_moe_quant_type)
.from_config(layer.fd_config.quant_config)
.get_quant_method(layer)
)
else: else:
return get_quantization_config(self.moe_quant_type).from_config({}).get_quant_method(layer) return (
get_quantization_config(self.moe_quant_type)
.from_config(layer.fd_config.quant_config)
.get_quant_method(layer)
)
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
if self.kv_cache_quant_type is not None: if self.kv_cache_quant_type is not None:
return ( return (

View File

@@ -25,15 +25,17 @@ class W4A8Config(QuantConfigBase):
quantization config for weight 4bits and activation 8bits quantization config for weight 4bits and activation 8bits
""" """
def __init__(self) -> None: def __init__(self, is_permuted) -> None:
super().__init__() super().__init__()
self.is_permuted = is_permuted
def name(self) -> str: def name(self) -> str:
return "w4a8" return "w4a8"
@classmethod @classmethod
def from_config(cls, config: dict) -> "W4A8Config": def from_config(cls, config: dict) -> "W4A8Config":
return cls() is_permuted = getattr(config, "is_permuted", False)
return cls(is_permuted)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]: def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):