[XPU] Support V1 loader in weight_only model (#4808)

* support v1 loader in wint8

* code style

* update

---------

Co-authored-by: root <root@gajl-bbc-onlinec-com-1498356.gajl.baidu.com>
yinwei
2025-11-05 17:09:11 +08:00
committed by GitHub
parent cc8f5312f5
commit ea1dd0e735
2 changed files with 176 additions and 21 deletions
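For context: with the v1 loader (load_choices == "default_v1"), weight-only layers are first created as full-precision parameters, checkpoint shards are copied in while a TensorTracker records progress, and quantization only happens afterwards in process_weights_after_loading. The sketch below is a hypothetical pure-Paddle stand-in for the weight_quantize_xpu op used throughout this commit, assuming per-channel symmetric int8 quantization; the real XPU kernel's numerics and int4 packing may differ.

    import paddle

    def weight_quantize_ref(w):
        # w: [in_dim, out_dim] float tensor (the layout quantized in this commit).
        # Per-channel quantization: one scale per output channel.
        scale = paddle.clip(paddle.max(paddle.abs(w), axis=0), min=1e-8)
        q = paddle.round(w / scale * 127.0).astype("int8")
        return q, scale.astype("float32")  # q: [in_dim, out_dim] int8, scale: [out_dim] fp32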


@@ -29,7 +29,12 @@ from fastdeploy.model_executor.ops.xpu import (
weight_quantize_xpu,
xpu_moe_layer,
)
from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
from fastdeploy.model_executor.utils import (
TensorTracker,
default_weight_loader,
free_tensor,
set_weight_attrs,
)
class XPUMoEMethod(MoEMethodBase):
@@ -62,15 +67,17 @@ class XPUMoEMethod(MoEMethodBase):
"""
Create weights for the MoE layer.
"""
if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in ["w16a16"]:
if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in [
"w16a16",
"weight_only_int8",
"weight_only_int4",
]:
self.up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
self.down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=layer.weight_dtype,
@@ -86,18 +93,21 @@ class XPUMoEMethod(MoEMethodBase):
set_weight_attrs(
layer.up_gate_proj_weight,
{
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=False),
},
)
set_weight_attrs(
layer.down_proj_weight,
{
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=True),
},
)
if layer.with_bias:
layer.up_gate_proj_bias = layer.create_parameter(
shape=[layer.num_experts, layer.moe_intermediate_size * 2],
@@ -128,6 +138,15 @@ class XPUMoEMethod(MoEMethodBase):
"model_format": extra_weight_attrs.get("model_format", ""),
},
)
if self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]:
self.up_gate_proj_scale_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
]
self.down_proj_scale_shape = [
layer.num_local_experts,
layer.hidden_size,
]
else:
self.up_gate_proj_weight_shape = [
@@ -531,6 +550,87 @@ class XPUWeightOnlyMoEMethod(XPUMoEMethod):
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
getattr(layer, scale_name).set_value(quanted_weight_scale)
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
else:
weight_type = "down"
# 1. Set up shapes and dtypes
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
if weight_type == "gate_up":
weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
else:
weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size,
]
weight_dtype = "int8"
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
if self.moe_quant_type in ["weight_only_int4"]:
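# int4 packs two values per int8 byte, so the packed (last) dim is halved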
weight_shape[-1] //= 2
scale_dtype = "float32"
# 2. Quantize each expert's weight
weight_list = []
weight_scale_list = []
for expert_id in range(layer.num_local_experts):
quant_weight, scale = weight_quantize_xpu(
getattr(layer, unquantized_weight_name)[expert_id].transpose([1, 0]), self.moe_quant_type, -1, -1
)
weight_list.append(quant_weight.transpose([1, 0]))
weight_scale_list.append(scale)
quanted_weight = paddle.stack(weight_list, axis=0)
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
free_tensor(getattr(layer, unquantized_weight_name))
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=scale_shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).set_value(quanted_weight)
getattr(layer, scale_name).set_value(quanted_weight_scale)
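Taken together, the loop above transposes each expert slice from [out_dim, in_dim] to [in_dim, out_dim], quantizes it, transposes the int8 result back, and stacks along the expert axis. A minimal sketch of that shape bookkeeping, reusing the hypothetical weight_quantize_ref from the top of this page (w3 is a stand-in for the unquantized up_gate_proj weight of shape [num_local_experts, 2 * moe_intermediate_size, hidden_size]):

    qs, ss = [], []
    for e in range(w3.shape[0]):
        q, s = weight_quantize_ref(w3[e].transpose([1, 0]))  # quantize as [in, out]
        qs.append(q.transpose([1, 0]))                       # store back as [out, in]
        ss.append(s)                                         # one scale per output channel
    quant_w = paddle.stack(qs, axis=0)  # [E, 2*inter, hidden]; the real int4 path halves the last dim
    scales = paddle.stack(ss, axis=0)   # [E, 2*inter], matching up_gate_proj_scale_shape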
class XPUW4A8MoEMethod(XPUMoEMethod):
"""


@@ -17,11 +17,17 @@
import paddle
from paddle import nn
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
MergedReplicatedLinear,
QKVParallelLinear,
)
from fastdeploy.model_executor.layers.quantization.weight_only import (
WeightOnlyConfig,
WeightOnlyLinearMethod,
)
from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
@@ -41,22 +47,48 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
Create weights for linear layer on XPU
"""
# With per-channel quantization, the scale shape equals the weight's output dim.
weight_scale_shape = [layer.weight_shape[1]]
layer.weight_shape.reverse()
if self.quant_config.name() == "weight_only_int4":
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype="float32",
is_bias=False,
)
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
quant_attrs = extra_weight_attrs
if isinstance(layer, (MergedColumnParallelLinear, QKVParallelLinear, MergedReplicatedLinear)):
quant_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(
shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim", True)
),
}
set_weight_attrs(
layer.weight,
quant_attrs,
)
else:
# With per-channel quantization, the scale shape equals the weight's output dim.
weight_scale_shape = [layer.weight_shape[1]]
layer.weight_shape.reverse()
if self.quant_config.name() == "weight_only_int4":
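# int4 packs two values per int8 byte, so the packed dim is halved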
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype="float32",
is_bias=False,
)
def process_loaded_weights(self, layer: nn.Layer, weight: paddle.Tensor) -> None:
"""
@@ -76,3 +108,26 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
weight_scale_tensor = paddle.concat(weight_scale_tensors, axis=0)
layer.weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0]))
layer.weight_scale.set_value(weight_scale_tensor)
def process_weights_after_loading(self, layer) -> None:
if not self.quant_config.is_checkpoint_bf16:
return
quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(layer.weight, self.quant_config.algo, -1, -1)
free_tensor(layer.weight)
layer.weight = layer.create_parameter(
shape=quanted_weight_tensor.shape[::-1],
dtype="int8",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=weight_scale_tensor.shape,
dtype=weight_scale_tensor.dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0]))
layer.weight_scale.copy_(weight_scale_tensor, False)
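End to end on the linear side: the v1 branch keeps layer.weight at its original [in_dim, out_dim] shape during loading, and process_weights_after_loading above replaces it with the transposed int8 tensor (hence quanted_weight_tensor.shape[::-1]) plus per-output-channel float32 scales. A small sanity sketch under the same hypothetical weight_quantize_ref defined at the top of this page (sizes are made up; the real weight is BF16):

    w = paddle.randn([4096, 8192], dtype="float32")  # [in_dim, out_dim]
    q, s = weight_quantize_ref(w)                    # q: [4096, 8192] int8, s: [8192] fp32
    w_int8 = paddle.transpose(q, [1, 0])             # stored layout: [out_dim, in_dim]
    assert list(w_int8.shape) == [8192, 4096]        # == q.shape[::-1]
    assert list(s.shape) == [8192]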