mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-21 15:49:31 +08:00
refactor rl get_name_mappings_to_training (#2847)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* refactor rl get_name_mappings_to_training
* fix tp>1
* change variable name (ffn1 -> up_gate_proj / ffn2 -> down_proj)
* change variable name (linear_weight -> weight / linear_bias -> bias)
* add rl names mapping for vl
* fix ernie 0.3B error
* fix develop code
* fix
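The rename touches both the layer attribute names and the checkpoint keys that get_name_mappings_to_training has to translate for RL weight reloading. As a rough illustration (not the FastDeploy API; the helper and keys below are hypothetical), the old-to-new translation amounts to:

    # Hypothetical sketch of the old -> new attribute-name translation.
    OLD_TO_NEW = {
        "moe_ffn1_weight": "up_gate_proj_weight",
        "moe_ffn2_weight": "down_proj_weight",
        "moe_ffn1_weight_scale": "up_gate_proj_weight_scale",
        "moe_ffn2_weight_scale": "down_proj_weight_scale",
        "linear_weight": "weight",
        "linear_bias": "bias",
    }

    def rename_state_dict_keys(state_dict):
        """Return a copy of state_dict with the old attribute names replaced."""
        renamed = {}
        for key, value in state_dict.items():
            new_key = key
            for old, new in OLD_TO_NEW.items():
                if old in new_key:
                    new_key = new_key.replace(old, new)
            renamed[new_key] = value
        return renamed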
@@ -19,12 +19,10 @@ from paddle import nn
 from fastdeploy.distributed.communication_op import \
     tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
                                                     get_tensor)
-from fastdeploy.model_executor.layers.quantization.quant_base import \
-    QuantMethodBase
 from fastdeploy.utils import ceil_div
+from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase


 class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
     """
@@ -36,9 +34,9 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
         Triton Group Gemm to compute Fused MoE.
         """
         self.quant_method = quant_method
-        self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
+        self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
         self.added_scale_attrs = [
-            "moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
+            "up_gate_proj_weight_scale", "down_proj_weight_scale"
         ]

     def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
@@ -49,26 +47,26 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
         """
         Triton MoE create weight process.
         """
-        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
-        assert len(ffn1_weights) == layer.num_local_experts
-        assert len(ffn2_weights) == layer.num_local_experts
+        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        assert len(up_gate_proj_weights) == layer.num_local_experts
+        assert len(down_proj_weights) == layer.num_local_experts
         assert self.quant_method.name() == "wint8"
-        assert ffn1_weights[0].shape == [
+        assert up_gate_proj_weights[0].shape == [
             layer.hidden_size, layer.moe_intermediate_size * 2
         ]
-        assert ffn2_weights[0].shape == [
+        assert down_proj_weights[0].shape == [
             layer.moe_intermediate_size, layer.hidden_size
         ]

-        ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
-        ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
+        up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
+        down_proj_tensor = paddle.stack(down_proj_weights, axis=0)

         if self.quant_method.name() == "wint8":
             max_bound = 127
         elif self.quant_method.name() == "wint4":
             max_bound = 7

-        for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
+        for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
             weight_name = self.added_weight_attrs[idx]
             scale_name = self.added_scale_attrs[idx]
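For reference, the wint8/wint4 branches above boil down to per-channel abs-max quantization of each stacked expert weight. A standalone, simplified sketch (not the exact DCU kernel code):

    import paddle

    def per_channel_quant(weight, max_bound):
        """Abs-max per-output-channel quantization: max_bound is 127 for wint8
        and 7 for wint4, matching the branches above."""
        scale = paddle.max(paddle.abs(weight), axis=0) / max_bound      # [N]
        quanted = paddle.clip(paddle.round(weight / scale), -max_bound, max_bound)
        return quanted.astype("int8"), scale

    # One [hidden_size, moe_intermediate_size * 2] tensor per local expert,
    # stacked to [E, K, N] the same way up_gate_proj_tensor is built above.
    experts = [paddle.randn([64, 128]) for _ in range(4)]
    stacked = paddle.stack(experts, axis=0)
    quanted, scales = zip(*(per_channel_quant(w, 127.0) for w in experts))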
@@ -150,10 +148,10 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
         fused_moe_kernel_paddle[grid](
             x,
-            layer.moe_ffn1_weight,
+            layer.up_gate_proj_weight,
             intermediate_cache1,
             None,
-            layer.moe_ffn1_weight_scale,
+            layer.up_gate_proj_weight_scale,
             None,
             sorted_token_ids,
             expert_ids,
@@ -164,17 +162,17 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
             token_num * top_k,
             stride_am=x.strides[0],
             stride_ak=x.strides[1],
-            stride_be=layer.moe_ffn1_weight.strides[0],
-            stride_bk=layer.moe_ffn1_weight.strides[1],
-            stride_bn=layer.moe_ffn1_weight.strides[2],
+            stride_be=layer.up_gate_proj_weight.strides[0],
+            stride_bk=layer.up_gate_proj_weight.strides[1],
+            stride_bn=layer.up_gate_proj_weight.strides[2],
             stride_cm=intermediate_cache1.strides[0],
             stride_cn=intermediate_cache1.strides[1],
             #
             stride_asm=-1,
             stride_ask=-1,
-            stride_bse=layer.moe_ffn1_weight_scale.strides[0],
+            stride_bse=layer.up_gate_proj_weight_scale.strides[0],
             stride_bsk=-1,
-            stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
+            stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
             group_n=-1,
             group_k=-1,
             # Meta-parameters
@@ -197,10 +195,10 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
                 ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
         fused_moe_kernel_paddle[grid](
             intermediate_cache2,
-            layer.moe_ffn2_weight,
+            layer.down_proj_weight,
             intermediate_cache3,
             None,
-            layer.moe_ffn2_weight_scale,
+            layer.down_proj_weight_scale,
             topk_weights,
             sorted_token_ids,
             expert_ids,
@@ -211,16 +209,16 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
             token_num * top_k,
             stride_am=intermediate_cache2.strides[0],
             stride_ak=intermediate_cache2.strides[1],
-            stride_be=layer.moe_ffn2_weight.strides[0],
-            stride_bk=layer.moe_ffn2_weight.strides[1],
-            stride_bn=layer.moe_ffn2_weight.strides[2],
+            stride_be=layer.down_proj_weight.strides[0],
+            stride_bk=layer.down_proj_weight.strides[1],
+            stride_bn=layer.down_proj_weight.strides[2],
             stride_cm=intermediate_cache3.strides[0],
             stride_cn=intermediate_cache3.strides[1],
             stride_asm=-1,
             stride_ask=-1,
-            stride_bse=layer.moe_ffn2_weight_scale.strides[0],
+            stride_bse=layer.down_proj_weight_scale.strides[0],
             stride_bsk=-1,
-            stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
+            stride_bsn=layer.down_proj_weight_scale.strides[1],
             group_n=-1,
             group_k=-1,
             # Meta-parameters
@@ -16,8 +16,8 @@
 import paddle
 from paddle.nn.quant import weight_dequantize

-from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig, GPUWeightOnlyLinearMethod
+from fastdeploy.model_executor.layers.quantization.weight_only import (
+    GPUWeightOnlyLinearMethod, WeightOnlyConfig)


 class DCUWeightOnlyLinearMethod(GPUWeightOnlyLinearMethod):
@@ -35,12 +35,12 @@ class DCUWeightOnlyLinearMethod(GPUWeightOnlyLinearMethod):

     def apply(self, layer, x):
         dequant_out = weight_dequantize(
-            x=layer.linear_weight,
-            scale=layer.linear_weight_scale,
+            x=layer.weight,
+            scale=layer.weight_scale,
             algo=self.quant_config.algo,
             out_dtype=paddle.get_default_dtype()
         )
         linear_out = paddle.matmul(x, dequant_out)
-        if layer.linear_bias is not None:
-            linear_out = paddle.add(linear_out, layer.linear_bias)
+        if layer.bias is not None:
+            linear_out = paddle.add(linear_out, layer.bias)
         return linear_out
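The apply() change above is a pure rename (linear_weight -> weight, linear_bias -> bias); the computation stays dequantize-then-matmul. A simplified stand-in for what that path computes with per-channel scales (the real code calls paddle.nn.quant.weight_dequantize, whose packed layout this sketch ignores):

    import paddle

    def weight_only_apply(x, quant_weight, weight_scale, bias=None):
        """Dequantize a per-channel int8 weight, then matmul and optional bias add."""
        dequant = quant_weight.astype(x.dtype) * weight_scale   # [K, N] * [N]
        out = paddle.matmul(x, dequant)
        if bias is not None:
            out = paddle.add(out, bias)
        return out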
@@ -50,11 +50,11 @@ class GCUFusedMoeMethod(MoEMethodBase):
         Paddle gcu create weight process.
         """
         # bf16
-        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
-        stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
-        stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
+        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
+        stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
         for idx, weight_tensor in enumerate(
-                [stacked_ffn1_weights, stacked_ffn2_weights]):
+                [stacked_up_gate_proj_weights, stacked_down_proj_weights]):
             # shape [E, K, N] -> [E, N, K]
             weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1])
             weight_name = self.added_weight_attrs[idx]
@@ -117,16 +117,16 @@ class GCUFusedMoeMethod(MoEMethodBase):
             dtype=x.dtype,
         )

-        ffn1_B_scale = layer.moe_ffn1_weight_scale if enable_quant else None
-        ffn1_B_zeros = layer.moe_ffn1_weight_zeros if enable_quant else None
+        up_gate_proj_B_scale = layer.up_gate_proj_weight_scale if enable_quant else None
+        up_gate_proj_B_zeros = layer.up_gate_proj_weight_zeros if enable_quant else None

         invoke_fused_moe_kernel(
             x, # input
-            layer.moe_ffn1_weight, # weight
+            layer.up_gate_proj_weight, # weight
             intermediate_cache1, # output
             None, # A_scale
-            ffn1_B_scale, # B_scale
-            ffn1_B_zeros, # B_zp
+            up_gate_proj_B_scale, # B_scale
+            up_gate_proj_B_zeros, # B_zp
             topk_weights,
             topk_indices,
             sorted_token_ids,
@@ -154,16 +154,16 @@ class GCUFusedMoeMethod(MoEMethodBase):
             dtype=x.dtype,
         )

-        ffn2_B_scale = layer.moe_ffn2_weight_scale if enable_quant else None
-        ffn2_B_zeros = layer.moe_ffn2_weight_zeros if enable_quant else None
+        down_proj_B_scale = layer.down_proj_weight_scale if enable_quant else None
+        down_proj_B_zeros = layer.down_proj_weight_zeros if enable_quant else None

         invoke_fused_moe_kernel(
             intermediate_cache2, # input
-            layer.moe_ffn2_weight, # weight
+            layer.down_proj_weight, # weight
             intermediate_cache3, # output
             None, # A_scale
-            ffn2_B_scale, # B_scale
-            ffn2_B_zeros, # B_zp
+            down_proj_B_scale, # B_scale
+            down_proj_B_zeros, # B_zp
             topk_weights,
             topk_indices,
             sorted_token_ids,
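The two kernel launches correspond to the two GEMMs of each expert's FFN, which is what the new names spell out: up_gate_proj produces the concatenated gate/up halves, down_proj projects back to the hidden size. An unfused, single-expert reference sketch (assuming the gate half comes first in the concatenated weight; the fused kernels handle routing, quantization and the exact activation layout themselves):

    import paddle
    import paddle.nn.functional as F

    def moe_expert_forward(x, up_gate_proj_weight, down_proj_weight):
        """x: [T, hidden]; up_gate_proj_weight: [hidden, 2 * inter];
        down_proj_weight: [inter, hidden]."""
        gate_up = paddle.matmul(x, up_gate_proj_weight)      # [T, 2 * inter]
        gate, up = paddle.chunk(gate_up, 2, axis=-1)
        hidden = F.silu(gate) * up                           # SwiGLU
        return paddle.matmul(hidden, down_proj_weight)       # [T, hidden]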
@@ -251,7 +251,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
             "GCUWeightOnlyMoEMethod only support weight_only_int4, but got:{self.quant_config.algo}"

         self.added_qzeros_attrs = [
-            "moe_ffn1_weight_zeros", "moe_ffn2_weight_zeros"
+            "up_gate_proj_weight_zeros", "down_proj_weight_zeros"
         ]
         self.group_size = 64
@@ -265,41 +265,41 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
         """
         Paddle gcu process prequanted weights.
         """
-        ffn1_expert_weight_key = layer.weight_key_map.get(
-            "ffn1_expert_weight_key", None)
-        ffn2_expert_weight_key = layer.weight_key_map.get(
-            "ffn2_expert_weight_key", None)
-        ffn1_expert_weight_scale_key = layer.weight_key_map.get(
-            "ffn1_expert_weight_scale_key", None)
-        ffn2_expert_weight_scale_key = layer.weight_key_map.get(
-            "ffn2_expert_weight_scale_key", None)
+        up_gate_proj_expert_weight_key = layer.weight_key_map.get(
+            "up_gate_proj_expert_weight_key", None)
+        down_proj_expert_weight_key = layer.weight_key_map.get(
+            "down_proj_expert_weight_key", None)
+        up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
+            "up_gate_proj_expert_weight_scale_key", None)
+        down_proj_expert_weight_scale_key = layer.weight_key_map.get(
+            "down_proj_expert_weight_scale_key", None)

-        ffn1_weights, ffn2_weights = layer.load_experts_weight(
-            state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
-        # self.check(layer, ffn1_weights, ffn2_weights)
-        ffn1_weight_scale = []
-        ffn2_weight_scale = []
+        up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
+            state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
+        # self.check(layer, up_gate_proj_weights, down_proj_weights)
+        up_gate_proj_weight_scale = []
+        down_proj_weight_scale = []
         for i in range(layer.num_experts):
             expert_idx = layer.expert_id_offset + i
-            ffn1_weight_scale.append(
+            up_gate_proj_weight_scale.append(
                 get_tensor(
                     state_dict.pop(
-                        ffn1_expert_weight_scale_key.format(expert_idx))))
-            ffn2_weight_scale.append(
+                        up_gate_proj_expert_weight_scale_key.format(expert_idx))))
+            down_proj_weight_scale.append(
                 get_tensor(
                     state_dict.pop(
-                        ffn2_expert_weight_scale_key.format(expert_idx))))
+                        down_proj_expert_weight_scale_key.format(expert_idx))))

-        ffn1_weight = paddle.stack(ffn1_weights, axis=0)
-        ffn2_weight = paddle.stack(ffn2_weights, axis=0)
-        ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
-        ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
+        up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
+        down_proj_weight = paddle.stack(down_proj_weights, axis=0)
+        up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
+        down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)

         name_tensor_map = {
-            "moe_ffn1_weight": ffn1_weight,
-            "moe_ffn2_weight": ffn2_weight,
-            "moe_ffn1_weight_scale": ffn1_weight_scale,
-            "moe_ffn2_weight_scale": ffn2_weight_scale
+            "up_gate_proj_weight": up_gate_proj_weight,
+            "down_proj_weight": down_proj_weight,
+            "up_gate_proj_weight_scale": up_gate_proj_weight_scale,
+            "down_proj_weight_scale": down_proj_weight_scale
         }
         for name, tensor in name_tensor_map.items():
             create_and_set_parameter(layer, name, tensor)
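For context on the renamed *_expert_weight_key entries: weight_key_map carries per-expert checkpoint key templates that are filled with .format(expert_idx) in the loop above. The templates below are made up purely for illustration; the real keys depend on the model definition:

    weight_key_map = {
        "up_gate_proj_expert_weight_key": "mlp.experts.{}.up_gate_proj.weight",
        "down_proj_expert_weight_key": "mlp.experts.{}.down_proj.weight",
        "up_gate_proj_expert_weight_scale_key": "mlp.experts.{}.up_gate_proj.weight_scale",
        "down_proj_expert_weight_scale_key": "mlp.experts.{}.down_proj.weight_scale",
    }

    expert_id_offset, num_experts = 0, 4
    for i in range(num_experts):
        expert_idx = expert_id_offset + i
        scale_key = weight_key_map["up_gate_proj_expert_weight_scale_key"].format(expert_idx)
        # e.g. "mlp.experts.3.up_gate_proj.weight_scale" when expert_idx == 3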
@@ -310,8 +310,8 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
         """
         Paddle cutlass create weight process.
         """
-        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
-        self.check(layer, ffn1_weights, ffn2_weights)
+        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        self.check(layer, up_gate_proj_weights, down_proj_weights)


         def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
@@ -329,7 +329,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
             )


-        for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
+        for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
             weight_name = self.added_weight_attrs[idx]
             scale_name = self.added_scale_attrs[idx]
             zeros_name = self.added_qzeros_attrs[idx]
@@ -365,8 +365,8 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
             dict_ = dict(shared_dict)

             for k, v in dict_.items():
-                weight_list[k] = v[0].to(ffn1_weights[0].place)
-                weight_scale_list[k] = v[1].to(ffn1_weights[0].place)
+                weight_list[k] = v[0].to(up_gate_proj_weights[0].place)
+                weight_scale_list[k] = v[1].to(up_gate_proj_weights[0].place)
         else:
             remain_weights_start_idx = 0
@@ -38,14 +38,14 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):

     def create_weights(self, layer):
         # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
-        linear_weight_scale_shape = [layer.linear_weight_shape[1]]
+        weight_scale_shape = [layer.weight_shape[1]]

-        layer.linear_weight_shape.reverse()
+        layer.weight_shape.reverse()
         if self.quant_config.name() == "wint4":
-            layer.linear_weight_shape[0] //= 2
+            layer.weight_shape[0] //= 2
         layer.weight_dtype = "int8"
-        layer.linear_weight_scale = layer.create_parameter(
-            shape=linear_weight_scale_shape,
+        layer.weight_scale = layer.create_parameter(
+            shape=weight_scale_shape,
             dtype=layer._dtype,
             is_bias=False,
         )
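The create_weights() hunks (GCU here, XPU further down) keep the same shape bookkeeping under the new names: a per-channel scale sized by the output dimension, a transposed stored weight, and a halved first dimension for int4 since two values pack into one int8. A standalone sketch of that arithmetic:

    def quantized_weight_shapes(weight_shape, algo):
        """weight_shape is [in_dim, out_dim] before quantization."""
        scale_shape = [weight_shape[1]]            # one scale per output channel
        stored_shape = list(reversed(weight_shape))
        if algo in ("wint4", "weight_only_int4"):
            stored_shape[0] //= 2                  # two int4 values per int8 byte
        return stored_shape, scale_shape

    print(quantized_weight_shapes([4096, 12288], "wint4"))   # ([6144, 4096], [12288])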
@@ -61,8 +61,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         """
         quant_weight = get_tensor(state_dict.pop(layer.weight_key))
         weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
-        layer.linear_weight.set_value(quant_weight)
-        layer.linear_weight_scale.set_value(
+        layer.weight.set_value(quant_weight)
+        layer.weight_scale.set_value(
             weight_scale.astype(paddle.get_default_dtype()))
@@ -73,8 +73,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
             self.group_size, # group_size
         )

-        layer.linear_weight.set_value(quanted_weight_tensor)
-        layer.linear_weight_scale.set_value(
+        layer.weight.set_value(quanted_weight_tensor)
+        layer.weight_scale.set_value(
             weight_scale_tensor.astype(paddle.get_default_dtype()))
@@ -82,8 +82,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
     def apply(self, layer, x):
         linear_out = linear_quant(
             lhs=x,
-            rhs=layer.linear_weight,
-            scale=layer.linear_weight_scale,
+            rhs=layer.weight,
+            scale=layer.weight_scale,
             bias=None,
             group_size=self.group_size,
         )
@@ -37,13 +37,13 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         Create weights for linear layer on XPU
         """
         # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
-        linear_weight_scale_shape = [layer.linear_weight_shape[1]]
-        layer.linear_weight_shape.reverse()
+        weight_scale_shape = [layer.weight_shape[1]]
+        layer.weight_shape.reverse()
         if self.quant_config.name() == "weight_only_int4":
-            layer.linear_weight_shape[0] //= 2
+            layer.weight_shape[0] //= 2
         layer.weight_dtype = "int8"
-        layer.linear_weight_scale = layer.create_parameter(
-            shape=linear_weight_scale_shape,
+        layer.weight_scale = layer.create_parameter(
+            shape=weight_scale_shape,
             dtype="float32",
             is_bias=False,
         )
@@ -55,6 +55,6 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         """
         quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(
             weight, self.quant_config.algo, -1, -1)
-        layer.linear_weight.set_value(
+        layer.weight.set_value(
             paddle.transpose(quanted_weight_tensor, [1, 0]))
-        layer.linear_weight_scale.set_value(weight_scale_tensor)
+        layer.weight_scale.set_value(weight_scale_tensor)