Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[XPU] Support V1 loader in weight_only model (#4808)
* support v1 loader in wint8
* code style
* update

Co-authored-by: root <root@gajl-bbc-onlinec-com-1498356.gajl.baidu.com>
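In short: on XPU, the default_v1 loader previously covered only w16a16 MoE weights. This commit extends it to weight_only_int8 and weight_only_int4 for both the MoE method and the plain weight-only linear method. Parameters are first created at their unquantized bf16 shapes, shard copies are tracked via TensorTracker, and once loading completes process_weights_after_loading quantizes with weight_quantize_xpu, frees the bf16 tensor via free_tensor, and re-creates the parameter at its quantized shape. The diff touches two files: the XPU MoE quantization method and the XPU weight-only linear method.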
@@ -29,7 +29,12 @@ from fastdeploy.model_executor.ops.xpu import (
     weight_quantize_xpu,
     xpu_moe_layer,
 )
-from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
+from fastdeploy.model_executor.utils import (
+    TensorTracker,
+    default_weight_loader,
+    free_tensor,
+    set_weight_attrs,
+)


 class XPUMoEMethod(MoEMethodBase):
@@ -62,15 +67,17 @@ class XPUMoEMethod(MoEMethodBase):
         """
         create weight process.
         """
-        if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in ["w16a16"]:
+        if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in [
+            "w16a16",
+            "weight_only_int8",
+            "weight_only_int4",
+        ]:
             self.up_gate_proj_weight_shape = [
                 layer.num_local_experts,
                 layer.moe_intermediate_size * 2,
                 layer.hidden_size,
             ]
             self.down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
             extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}

             layer.up_gate_proj_weight = layer.create_parameter(
                 shape=self.up_gate_proj_weight_shape,
                 dtype=layer.weight_dtype,
@@ -86,18 +93,21 @@ class XPUMoEMethod(MoEMethodBase):
             set_weight_attrs(
                 layer.up_gate_proj_weight,
                 {
                     "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
                     "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
                     "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
+                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=False),
                 },
             )
             set_weight_attrs(
                 layer.down_proj_weight,
                 {
                     "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
                     "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
                     "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
+                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=True),
                 },
             )

         if layer.with_bias:
             layer.up_gate_proj_bias = layer.create_parameter(
                 shape=[layer.num_experts, layer.moe_intermediate_size * 2],
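The tensor_track attribute is what lets the v1 loader know when a bf16 weight has been fully assembled from its shards. The sketch below is a toy stand-in for fastdeploy.model_executor.utils.TensorTracker, not the real implementation: only the constructor arguments and is_fully_copied() appear in this diff, and mark_copied() is a hypothetical helper for recording shard writes.

# Toy stand-in for TensorTracker -- illustrates the contract used above,
# not FastDeploy's actual implementation.
class ToyTensorTracker:
    def __init__(self, shape, output_dim):
        # output_dim records which axis the sharded copies arrive along
        # (False for up_gate_proj, True for down_proj in the diff above).
        self.shape = list(shape)
        self.output_dim = output_dim
        self.total = 1
        for s in shape:
            self.total *= s
        self.copied = 0  # elements written so far

    def mark_copied(self, numel):  # hypothetical helper, not in the diff
        self.copied += numel

    def is_fully_copied(self):
        # process_weights_after_loading branches on this to decide whether
        # the up_gate weight (vs. down_proj) is the one ready to quantize.
        return self.copied >= self.total


tracker = ToyTensorTracker(shape=[8, 6144, 4096], output_dim=False)
tracker.mark_copied(8 * 6144 * 4096)
assert tracker.is_fully_copied()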
@@ -128,6 +138,15 @@ class XPUMoEMethod(MoEMethodBase):
                     "model_format": extra_weight_attrs.get("model_format", ""),
                 },
             )
+            if self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]:
+                self.up_gate_proj_scale_shape = [
+                    layer.num_local_experts,
+                    layer.moe_intermediate_size * 2,
+                ]
+                self.down_proj_scale_shape = [
+                    layer.num_local_experts,
+                    layer.hidden_size,
+                ]

         else:
             self.up_gate_proj_weight_shape = [
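For concreteness, a quick trace of the shapes the v1 path creates. The sizes below are placeholders chosen only to illustrate the arithmetic, not values from the PR:

# Hypothetical sizes, used only to trace the shape logic above.
num_local_experts = 8
hidden_size = 4096
moe_intermediate_size = 3072

# bf16 parameters created up front under load_choices == "default_v1":
up_gate_proj_weight_shape = [num_local_experts, moe_intermediate_size * 2, hidden_size]
down_proj_weight_shape = [num_local_experts, hidden_size, moe_intermediate_size]

# Per-channel weight-only scales: one float32 scale per output channel,
# per expert, so the input dim of each weight drops away.
up_gate_proj_scale_shape = [num_local_experts, moe_intermediate_size * 2]
down_proj_scale_shape = [num_local_experts, hidden_size]

print(up_gate_proj_weight_shape)  # [8, 6144, 4096]
print(up_gate_proj_scale_shape)   # [8, 6144]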
@@ -531,6 +550,87 @@ class XPUWeightOnlyMoEMethod(XPUMoEMethod):
         quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
         getattr(layer, scale_name).set_value(quanted_weight_scale)


+    def process_weights_after_loading(self, layer):
+        """Quantize the fully loaded bf16 expert weights to weight-only int8/int4."""
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+        weight_id_map = {"gate_up": 0, "down": 1}
+        if (
+            hasattr(layer.up_gate_proj_weight, "tensor_track")
+            and layer.up_gate_proj_weight.tensor_track is not None
+            and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
+        ):
+            weight_type = "gate_up"
+        else:
+            weight_type = "down"
+
+        # 1. init shape and type
+        # weight
+        weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
+        unquantized_weight_name = weight_name.replace("quant_weight", "weight")
+        if weight_type == "gate_up":
+            weight_shape = [
+                layer.num_local_experts,
+                layer.moe_intermediate_size * 2,
+                layer.hidden_size,
+            ]
+        else:
+            weight_shape = [
+                layer.num_local_experts,
+                layer.hidden_size,
+                layer.moe_intermediate_size,
+            ]
+        weight_dtype = "int8"
+        # scale
+        scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
+        scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
+        if self.moe_quant_type in ["weight_only_int4"]:
+            weight_shape[-1] //= 2
+        scale_dtype = "float32"
+
+        # 2. create tmp tensor
+        # weight = paddle.empty(weight_shape, dtype=weight_dtype)
+        # scale = paddle.empty(scale_shape, dtype=scale_dtype)
+
+        # 3. quantize weight
+        weight_list = []
+        weight_scale_list = []
+        for expert_id in range(layer.num_local_experts):
+            quant_weight, scale = weight_quantize_xpu(
+                getattr(layer, unquantized_weight_name)[expert_id].transpose([1, 0]), self.moe_quant_type, -1, -1
+            )
+            weight_list.append(quant_weight.transpose([1, 0]))
+            weight_scale_list.append(scale)
+        quanted_weight = paddle.stack(weight_list, axis=0)
+        quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
+
+        free_tensor(getattr(layer, unquantized_weight_name))
+
+        # create weight
+        setattr(
+            layer,
+            weight_name,
+            layer.create_parameter(
+                shape=weight_shape,
+                dtype=weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        # create scale
+        setattr(
+            layer,
+            scale_name,
+            layer.create_parameter(
+                shape=scale_shape,
+                dtype=scale_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+
+        getattr(layer, weight_name).set_value(quanted_weight)
+        getattr(layer, scale_name).set_value(quanted_weight_scale)
+
+
 class XPUW4A8MoEMethod(XPUMoEMethod):
     """
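weight_quantize_xpu is an XPU custom op, so the following is a mental model only: the per-expert loop above amounts to symmetric per-output-channel absmax int8 quantization, applied to each transposed expert slice and stacked back. A hedged numpy sketch follows; toy_weight_quantize is an assumption about the scheme, and the real op's layout, rounding, and the int4 packing that halves weight_shape[-1] are not reproduced here:

import numpy as np

def toy_weight_quantize(w):
    """Symmetric per-output-channel absmax int8 quantization: a toy stand-in
    for the weight_quantize_xpu custom op (real layout/rounding may differ).
    w: [in_dim, out_dim], like the transposed expert weight in the loop above."""
    scale = np.abs(w).max(axis=0) / 127.0        # one scale per output channel
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

rng = np.random.default_rng(0)
num_local_experts, in_dim, out_dim = 4, 64, 32
experts = rng.standard_normal((num_local_experts, out_dim, in_dim)).astype(np.float32)

weight_list, scale_list = [], []
for expert_id in range(num_local_experts):
    # mirror the diff: transpose to [in_dim, out_dim], quantize, transpose back
    q, s = toy_weight_quantize(experts[expert_id].T)
    weight_list.append(q.T)
    scale_list.append(s)

quanted_weight = np.stack(weight_list, axis=0)   # int8,    [experts, out_dim, in_dim]
quanted_scale = np.stack(scale_list, axis=0)     # float32, [experts, out_dim]

# dequantized values land close to the originals
recon = quanted_weight.astype(np.float32) * quanted_scale[:, :, None]
assert np.allclose(recon, experts, atol=quanted_scale.max())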
@@ -17,11 +17,17 @@
 import paddle
 from paddle import nn

+from fastdeploy.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    MergedReplicatedLinear,
+    QKVParallelLinear,
+)
 from fastdeploy.model_executor.layers.quantization.weight_only import (
     WeightOnlyConfig,
     WeightOnlyLinearMethod,
 )
 from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu
+from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs


 class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
@@ -41,22 +47,48 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         Create weights for linear layer on XPU
         """
-        # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
-        weight_scale_shape = [layer.weight_shape[1]]
-        layer.weight_shape.reverse()
-        if self.quant_config.name() == "weight_only_int4":
-            layer.weight_shape[0] //= 2
-        layer.weight_dtype = "int8"
-        layer.weight = layer.create_parameter(
-            shape=layer.weight_shape,
-            dtype=layer.weight_dtype,
-            is_bias=False,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-        layer.weight_scale = layer.create_parameter(
-            shape=weight_scale_shape,
-            dtype="float32",
-            is_bias=False,
-        )
+        if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
+            layer.weight = layer.create_parameter(
+                shape=layer.weight_shape,
+                dtype=layer.weight_dtype,
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
+            quant_attrs = extra_weight_attrs
+            if (
+                isinstance(layer, MergedColumnParallelLinear)
+                or isinstance(layer, QKVParallelLinear)
+                or isinstance(layer, MergedReplicatedLinear)
+            ):
+                quant_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(
+                        shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim", True)
+                    ),
+                }
+            set_weight_attrs(
+                layer.weight,
+                quant_attrs,
+            )
+        else:
+            # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
+            weight_scale_shape = [layer.weight_shape[1]]
+            layer.weight_shape.reverse()
+            if self.quant_config.name() == "weight_only_int4":
+                layer.weight_shape[0] //= 2
+            layer.weight_dtype = "int8"
+            layer.weight = layer.create_parameter(
+                shape=layer.weight_shape,
+                dtype=layer.weight_dtype,
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            layer.weight_scale = layer.create_parameter(
+                shape=weight_scale_shape,
+                dtype="float32",
+                is_bias=False,
+            )

     def process_loaded_weights(self, layer: nn.Layer, weight: paddle.Tensor) -> None:
         """
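In the else branch (pre-quantized checkpoints, behavior unchanged by this PR), the weight is registered transposed and int4 packs two values per int8 byte. A quick trace with placeholder dimensions (4096/6144 are made up, not from the PR):

# Placeholder dims to trace the legacy branch above.
weight_shape = [4096, 6144]              # [in_dim, out_dim] as registered by the layer
weight_scale_shape = [weight_shape[1]]   # per-output-channel scale -> [6144]

weight_shape.reverse()                   # stored transposed -> [6144, 4096]
quant_name = "weight_only_int4"
if quant_name == "weight_only_int4":
    weight_shape[0] //= 2                # two int4 values packed per int8 -> [3072, 4096]

print(weight_shape, weight_scale_shape)  # [3072, 4096] [6144]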
@@ -76,3 +108,26 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         weight_scale_tensor = paddle.concat(weight_scale_tensors, axis=0)
         layer.weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0]))
         layer.weight_scale.set_value(weight_scale_tensor)
+
+    def process_weights_after_loading(self, layer) -> None:
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+
+        quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(layer.weight, self.quant_config.algo, -1, -1)
+
+        free_tensor(layer.weight)
+
+        layer.weight = layer.create_parameter(
+            shape=quanted_weight_tensor.shape[::-1],
+            dtype="int8",
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+        layer.weight_scale = layer.create_parameter(
+            shape=weight_scale_tensor.shape,
+            dtype=weight_scale_tensor.dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+        layer.weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0]))
+        layer.weight_scale.copy_(weight_scale_tensor, False)
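Note the shape bookkeeping in this new method: the parameter is created with quanted_weight_tensor.shape[::-1] and then filled with the transpose, so the stored int8 weight is just the quantized tensor flipped to [out, in] layout, and the two shapes stay consistent (the False in copy_ is presumably the blocking flag). A minimal numpy illustration with placeholder dims:

import numpy as np

# Placeholder array; the real tensor comes from the weight_quantize_xpu op.
quanted = np.arange(12, dtype=np.int8).reshape(3, 4)   # e.g. [in_dim, out_dim]

# create_parameter(shape=quanted.shape[::-1]) + set_value(transpose(...)):
weight = np.empty(quanted.shape[::-1], dtype=np.int8)  # [out_dim, in_dim]
weight[...] = quanted.T
assert weight.shape == (4, 3)
assert (weight == quanted.T).all()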