mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
refactor rl get_name_mappings_to_training (#2847)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* refactor rl get_name_mappings_to_training * fix tp>1 * change variable name(ffn1->up_gate_proj/ffn2->down_proj) * change variable name(linear_weight->weight/linear_bias->bias) * add rl names mapping for vl * fix ernie 0.3B error * fix develop code * fix
This commit is contained in:
@@ -78,8 +78,8 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer):
|
||||
layer.linear_weight_shape.reverse()
|
||||
layer.linear_weight_scale = layer.create_parameter(
|
||||
layer.weight_shape.reverse()
|
||||
layer.weight_scale = layer.create_parameter(
|
||||
shape=[
|
||||
(layer.output_size + self.quant_config.weight_block_size[0] -
|
||||
1) // self.quant_config.weight_block_size[0],
|
||||
@@ -95,8 +95,8 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
weight_tensor = weights.transpose([1, 0])
|
||||
quanted_weight_tensor, weight_block_scale_tensor = (
|
||||
per_block_cast_to_fp8(weight_tensor))
|
||||
layer.linear_weight.copy_(quanted_weight_tensor, False)
|
||||
layer.linear_weight_scale.set_value(weight_block_scale_tensor)
|
||||
layer.weight.copy_(quanted_weight_tensor, False)
|
||||
layer.weight_scale.set_value(weight_block_scale_tensor)
|
||||
|
||||
def process_prequanted_weights(self, layer, state_dict):
|
||||
"""
|
||||
@@ -106,10 +106,10 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
|
||||
|
||||
quant_weight = quant_weight.transpose([1, 0]).contiguous()
|
||||
layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)
|
||||
layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False)
|
||||
|
||||
weight_scale = weight_scale.transpose([1, 0])
|
||||
layer.linear_weight_scale.set_value(weight_scale)
|
||||
layer.weight_scale.set_value(weight_scale)
|
||||
|
||||
def apply(self, layer, x):
|
||||
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
|
||||
@@ -119,9 +119,9 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt(
|
||||
(x, x_scale_tensor),
|
||||
(layer.linear_weight, layer.linear_weight_scale),
|
||||
(layer.weight, layer.weight_scale),
|
||||
linear_out,
|
||||
)
|
||||
if layer.with_bias:
|
||||
linear_out = paddle.add(linear_out, layer.linear_bias)
|
||||
linear_out = paddle.add(linear_out, layer.bias)
|
||||
return linear_out
|
||||
|
@@ -96,7 +96,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
|
||||
act_scale = get_tensor(state_dict.pop(layer.act_scale_key))
|
||||
|
||||
quant_weight = quant_weight.transpose([1, 0]).contiguous()
|
||||
layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)
|
||||
layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False)
|
||||
|
||||
self.act_scale = act_scale.item()
|
||||
self.total_scale = (act_scale * weight_scale).item()
|
||||
@@ -118,7 +118,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
|
||||
|
||||
linear_out = cutlass_fp8_fp8_half_gemm_fused(
|
||||
fp8_x,
|
||||
layer.linear_weight,
|
||||
layer.weight,
|
||||
transpose_x=False,
|
||||
transpose_y=True,
|
||||
bias=None,
|
||||
|
@@ -63,8 +63,8 @@ class W4AFP8LinearMethod(QuantMethodBase):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer):
|
||||
layer.linear_weight_shape.reverse()
|
||||
layer.linear_weight_shape[0] //= 2
|
||||
layer.weight_shape.reverse()
|
||||
layer.weight_shape[0] //= 2
|
||||
layer.weight_dtype = "int8"
|
||||
pass
|
||||
|
||||
@@ -77,16 +77,16 @@ class W4AFP8LinearMethod(QuantMethodBase):
|
||||
scale_dtype="float16",
|
||||
))
|
||||
weight_scale_tensor = paddle.view(weight_scale_tensor, layer._dtype)
|
||||
layer.linear_weight.set_value(quanted_weight_tensor)
|
||||
layer.linear_weight_scale.set_value(weight_scale_tensor)
|
||||
layer.weight.set_value(quanted_weight_tensor)
|
||||
layer.weight_scale.set_value(weight_scale_tensor)
|
||||
|
||||
def apply(self, layer, x):
|
||||
linear_out = fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16(
|
||||
x,
|
||||
layer.linear_weight,
|
||||
layer.linear_weight_scale,
|
||||
layer.weight,
|
||||
layer.weight_scale,
|
||||
zero_points=None,
|
||||
bias=layer.linear_bias if layer.add_bias else None,
|
||||
bias=layer.bias if layer.add_bias else None,
|
||||
out_scale=self.quant_config.weight_scale_dict.get(layer.prefix +
|
||||
".weight_scale")
|
||||
/ (self.quant_config.act_scale_dict.get(layer.prefix +
|
||||
|
@@ -69,7 +69,7 @@ class W8A8LinearMethod(QuantMethodBase):
|
||||
self.smooth_quant_method = SmoothQuantLinearMethod(quant_config)
|
||||
|
||||
def create_weights(self, layer):
|
||||
layer.linear_weight_shape.reverse()
|
||||
layer.weight_shape.reverse()
|
||||
layer.weight_dtype = "int8"
|
||||
if self.quant_config.use_smooth_quant:
|
||||
self.smooth_quant_method.create_weights(layer)
|
||||
@@ -101,21 +101,21 @@ class W8A8LinearMethod(QuantMethodBase):
|
||||
if self.skip_quant:
|
||||
logger.debug(f"{layer.prefix} skip quant")
|
||||
weight_tensor = weights.cast(layer._dtype)
|
||||
layer.linear_weight.set_value(weight_tensor)
|
||||
layer.weight.set_value(weight_tensor)
|
||||
else:
|
||||
weight_tensor = weights.transpose([1, 0])
|
||||
weight_tensor = paddle.cast(weight_tensor, "int8")
|
||||
layer.linear_weight.set_value(weight_tensor)
|
||||
layer.weight.set_value(weight_tensor)
|
||||
|
||||
def apply(self, layer, x):
|
||||
if self.skip_quant:
|
||||
linear_out = paddle.matmul(x, layer.linear_weight, False, True)
|
||||
linear_out = paddle.matmul(x, layer.weight, False, True)
|
||||
return linear_out
|
||||
if self.quant_config.use_gemm_dequant:
|
||||
linear_out = fastdeploy.model_executor.ops.gpu.gemm_dequant(
|
||||
x, layer.linear_weight, layer.linear_out_scale, layer._dtype)
|
||||
x, layer.weight, layer.linear_out_scale, layer._dtype)
|
||||
else:
|
||||
linear_out = paddle.matmul(x, layer.linear_weight, False, True)
|
||||
linear_out = paddle.matmul(x, layer.weight, False, True)
|
||||
linear_out = fastdeploy.model_executor.ops.gpu.dequant_int8(
|
||||
linear_out, layer.linear_out_scale, layer._dtype)
|
||||
return linear_out
|
||||
|
@@ -77,12 +77,12 @@ class WeightOnlyConfig(QuantConfigBase):
|
||||
return GCUWeightOnlyLinearMethod(self)
|
||||
elif current_platform.is_dcu():
|
||||
if isinstance(layer, FusedMoE):
|
||||
from fastdeploy.model_executor.layers.backends import (
|
||||
DCUTritonWeightOnlyMoEMethod)
|
||||
from fastdeploy.model_executor.layers.backends import \
|
||||
DCUTritonWeightOnlyMoEMethod
|
||||
return DCUTritonWeightOnlyMoEMethod(self)
|
||||
else:
|
||||
from fastdeploy.model_executor.layers.backends import (
|
||||
DCUWeightOnlyLinearMethod)
|
||||
from fastdeploy.model_executor.layers.backends import \
|
||||
DCUWeightOnlyLinearMethod
|
||||
return DCUWeightOnlyLinearMethod(self)
|
||||
else:
|
||||
if isinstance(layer, FusedMoE):
|
||||
@@ -152,14 +152,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
|
||||
def create_weights(self, layer):
|
||||
|
||||
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
|
||||
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
|
||||
weight_scale_shape = [layer.weight_shape[1]]
|
||||
|
||||
layer.linear_weight_shape.reverse()
|
||||
layer.weight_shape.reverse()
|
||||
if self.quant_config.name() == "wint4":
|
||||
layer.linear_weight_shape[0] //= 2
|
||||
layer.weight_shape[0] //= 2
|
||||
layer.weight_dtype = "int8"
|
||||
layer.linear_weight_scale = layer.create_parameter(
|
||||
shape=linear_weight_scale_shape,
|
||||
layer.weight_scale = layer.create_parameter(
|
||||
shape=weight_scale_shape,
|
||||
dtype=layer._dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
@@ -171,9 +171,9 @@ class WeightOnlyLinearMethod(QuantMethodBase):
|
||||
def apply(self, layer, x):
|
||||
linear_out = weight_only_linear(
|
||||
x,
|
||||
weight=layer.linear_weight,
|
||||
bias=layer.linear_bias if layer.add_bias else None,
|
||||
weight_scale=layer.linear_weight_scale,
|
||||
weight=layer.weight,
|
||||
bias=layer.bias if layer.add_bias else None,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_dtype="int8"
|
||||
if self.quant_config.name() == "wint8" else "int4",
|
||||
arch=self.quant_config.weight_only_linear_arch,
|
||||
@@ -204,8 +204,8 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
|
||||
"""
|
||||
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
|
||||
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
|
||||
layer.linear_weight.set_value(quant_weight)
|
||||
layer.linear_weight_scale.set_value(
|
||||
layer.weight.set_value(quant_weight)
|
||||
layer.weight_scale.set_value(
|
||||
weight_scale.astype(paddle.get_default_dtype()))
|
||||
|
||||
def process_loaded_weights(self, layer, weight) -> None:
|
||||
@@ -216,6 +216,6 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
|
||||
arch=self.quant_config.weight_only_linear_arch,
|
||||
)
|
||||
|
||||
layer.linear_weight.set_value(quanted_weight_tensor)
|
||||
layer.linear_weight_scale.set_value(
|
||||
layer.weight.set_value(quanted_weight_tensor)
|
||||
layer.weight_scale.set_value(
|
||||
weight_scale_tensor.astype(paddle.get_default_dtype()))
|
||||
|
@@ -70,11 +70,11 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
|
||||
def create_weights(self, layer):
|
||||
"""
|
||||
"""
|
||||
layer.linear_weight_shape.reverse()
|
||||
layer.weight_shape.reverse()
|
||||
layer.weight_dtype = "float8_e4m3fn"
|
||||
# TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
|
||||
self.skip_quant = False
|
||||
layer.linear_weight_scale = layer.create_parameter(
|
||||
layer.weight_scale = layer.create_parameter(
|
||||
shape=[1],
|
||||
dtype="float32",
|
||||
is_bias=False,
|
||||
@@ -86,7 +86,7 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
|
||||
"""
|
||||
if self.skip_quant:
|
||||
weight_tensor = weights.cast(layer._dtype)
|
||||
layer.linear_weight.set_value(weight_tensor)
|
||||
layer.weight.set_value(weight_tensor)
|
||||
return
|
||||
if weights.dtype != paddle.float8_e4m3fn:
|
||||
self.use_per_token_if_dynamic = True
|
||||
@@ -95,22 +95,22 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
|
||||
weight_tensor,
|
||||
use_per_token_if_dynamic=False,
|
||||
)
|
||||
layer.linear_weight.copy_(qweight, False)
|
||||
layer.linear_weight_scale.set_value(weight_scale)
|
||||
layer.weight.copy_(qweight, False)
|
||||
layer.weight_scale.set_value(weight_scale)
|
||||
|
||||
def apply(self, layer, x):
|
||||
"""
|
||||
"""
|
||||
if self.skip_quant:
|
||||
linear_out = paddle.matmul(x, layer.linear_weight, False, True)
|
||||
linear_out = paddle.matmul(x, layer.weight, False, True)
|
||||
return linear_out
|
||||
if self.use_per_token_if_dynamic:
|
||||
out_type = x.dtype
|
||||
a_q, a_scales = scaled_fp8_quant(
|
||||
x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)
|
||||
linear_out = cutlass_scaled_mm(a_q, layer.linear_weight, a_scales,
|
||||
layer.linear_weight_scale, out_type,
|
||||
layer.linear_bias)
|
||||
linear_out = cutlass_scaled_mm(a_q, layer.weight, a_scales,
|
||||
layer.weight_scale, out_type,
|
||||
layer.bias)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return linear_out
|
||||
|
Reference in New Issue
Block a user