mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-23 16:44:22 +08:00
refactor rl get_name_mappings_to_training (#2847)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* refactor rl get_name_mappings_to_training
* fix tp>1
* change variable name(ffn1->up_gate_proj/ffn2->down_proj)
* change variable name(linear_weight->weight/linear_bias->bias)
* add rl names mapping for vl
* fix ernie 0.3B error
* fix develop code
* fix
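The renaming above is mechanical and consistent across every backend touched by this commit. As a rough illustration only — the name pairs below are read off the hunks that follow, while the dictionary and helper function themselves are hypothetical and not part of FastDeploy's API — the old-to-new correspondence can be sketched in Python like this:

    # Illustrative sketch of the attribute renaming applied throughout this commit.
    # Only the name pairs come from the diff; the helper itself is an assumption.
    OLD_TO_NEW = {
        "moe_ffn1_weight": "up_gate_proj_weight",
        "moe_ffn2_weight": "down_proj_weight",
        "moe_ffn1_weight_scale": "up_gate_proj_weight_scale",
        "moe_ffn2_weight_scale": "down_proj_weight_scale",
        "linear_weight": "weight",
        "linear_bias": "bias",
        "word_embeddings": "embeddings",
        "out_linear": "linear",
    }

    def to_new_name(old_key: str) -> str:
        """Return the post-refactor name for a pre-refactor attribute name."""
        for old, new in OLD_TO_NEW.items():
            if old in old_key:
                return old_key.replace(old, new)
        return old_key

For example, to_new_name("moe_ffn1_weight_scale") yields "up_gate_proj_weight_scale", which matches the renamed attributes in the hunks below.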
@@ -19,12 +19,10 @@ from paddle import nn
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
get_tensor)
from fastdeploy.model_executor.layers.quantization.quant_base import \
QuantMethodBase
from fastdeploy.utils import ceil_div
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
"""
@@ -36,9 +34,9 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
Triton Group Gemm to compute Fused MoE.
"""
self.quant_method = quant_method
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
@@ -49,26 +47,26 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Triton MoE create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
assert self.quant_method.name() == "wint8"
assert ffn1_weights[0].shape == [
assert up_gate_proj_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_tensor = paddle.stack(down_proj_weights, axis=0)
if self.quant_method.name() == "wint8":
max_bound = 127
elif self.quant_method.name() == "wint4":
max_bound = 7
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
@@ -150,10 +148,10 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
fused_moe_kernel_paddle[grid](
x,
layer.moe_ffn1_weight,
layer.up_gate_proj_weight,
intermediate_cache1,
None,
layer.moe_ffn1_weight_scale,
layer.up_gate_proj_weight_scale,
None,
sorted_token_ids,
expert_ids,
@@ -164,17 +162,17 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
token_num * top_k,
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[1],
stride_bn=layer.moe_ffn1_weight.strides[2],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[1],
stride_bn=layer.up_gate_proj_weight.strides[2],
stride_cm=intermediate_cache1.strides[0],
stride_cn=intermediate_cache1.strides[1],
#
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
@@ -197,10 +195,10 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid](
intermediate_cache2,
layer.moe_ffn2_weight,
layer.down_proj_weight,
intermediate_cache3,
None,
layer.moe_ffn2_weight_scale,
layer.down_proj_weight_scale,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -211,16 +209,16 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
token_num * top_k,
stride_am=intermediate_cache2.strides[0],
stride_ak=intermediate_cache2.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[1],
stride_bn=layer.moe_ffn2_weight.strides[2],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[1],
stride_bn=layer.down_proj_weight.strides[2],
stride_cm=intermediate_cache3.strides[0],
stride_cn=intermediate_cache3.strides[1],
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
stride_bse=layer.down_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
stride_bsn=layer.down_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
@@ -16,8 +16,8 @@
import paddle
from paddle.nn.quant import weight_dequantize
from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig, GPUWeightOnlyLinearMethod
from fastdeploy.model_executor.layers.quantization.weight_only import (
GPUWeightOnlyLinearMethod, WeightOnlyConfig)
class DCUWeightOnlyLinearMethod(GPUWeightOnlyLinearMethod):
@@ -35,12 +35,12 @@ class DCUWeightOnlyLinearMethod(GPUWeightOnlyLinearMethod):
def apply(self, layer, x):
dequant_out = weight_dequantize(
x=layer.linear_weight,
scale=layer.linear_weight_scale,
x=layer.weight,
scale=layer.weight_scale,
algo=self.quant_config.algo,
out_dtype=paddle.get_default_dtype()
)
linear_out = paddle.matmul(x, dequant_out)
if layer.linear_bias is not None:
linear_out = paddle.add(linear_out, layer.linear_bias)
if layer.bias is not None:
linear_out = paddle.add(linear_out, layer.bias)
return linear_out
@@ -50,11 +50,11 @@ class GCUFusedMoeMethod(MoEMethodBase):
Paddle gcu create weight process.
"""
# bf16
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
for idx, weight_tensor in enumerate(
[stacked_ffn1_weights, stacked_ffn2_weights]):
[stacked_up_gate_proj_weights, stacked_down_proj_weights]):
# shape [E, K, N] -> [E, N, K]
weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1])
weight_name = self.added_weight_attrs[idx]
@@ -117,16 +117,16 @@ class GCUFusedMoeMethod(MoEMethodBase):
dtype=x.dtype,
)
ffn1_B_scale = layer.moe_ffn1_weight_scale if enable_quant else None
ffn1_B_zeros = layer.moe_ffn1_weight_zeros if enable_quant else None
up_gate_proj_B_scale = layer.up_gate_proj_weight_scale if enable_quant else None
up_gate_proj_B_zeros = layer.up_gate_proj_weight_zeros if enable_quant else None
invoke_fused_moe_kernel(
x, # input
layer.moe_ffn1_weight, # weight
layer.up_gate_proj_weight, # weight
intermediate_cache1, # output
None, # A_scale
ffn1_B_scale, # B_scale
ffn1_B_zeros, # B_zp
up_gate_proj_B_scale, # B_scale
up_gate_proj_B_zeros, # B_zp
topk_weights,
topk_indices,
sorted_token_ids,
@@ -154,16 +154,16 @@ class GCUFusedMoeMethod(MoEMethodBase):
dtype=x.dtype,
)
ffn2_B_scale = layer.moe_ffn2_weight_scale if enable_quant else None
ffn2_B_zeros = layer.moe_ffn2_weight_zeros if enable_quant else None
down_proj_B_scale = layer.down_proj_weight_scale if enable_quant else None
down_proj_B_zeros = layer.down_proj_weight_zeros if enable_quant else None
invoke_fused_moe_kernel(
intermediate_cache2, # input
layer.moe_ffn2_weight, # weight
layer.down_proj_weight, # weight
intermediate_cache3, # output
None, # A_scale
ffn2_B_scale, # B_scale
ffn2_B_zeros, # B_zp
down_proj_B_scale, # B_scale
down_proj_B_zeros, # B_zp
topk_weights,
topk_indices,
sorted_token_ids,
@@ -251,7 +251,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
"GCUWeightOnlyMoEMethod only support weight_only_int4, but got:{self.quant_config.algo}"
self.added_qzeros_attrs = [
"moe_ffn1_weight_zeros", "moe_ffn2_weight_zeros"
"up_gate_proj_weight_zeros", "down_proj_weight_zeros"
]
self.group_size = 64
@@ -265,41 +265,41 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
"""
Paddle gcu process prequanted weights.
"""
ffn1_expert_weight_key = layer.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = layer.weight_key_map.get(
"ffn2_expert_weight_key", None)
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
"ffn1_expert_weight_scale_key", None)
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
"ffn2_expert_weight_scale_key", None)
up_gate_proj_expert_weight_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = layer.weight_key_map.get(
"down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get(
"down_proj_expert_weight_scale_key", None)
ffn1_weights, ffn2_weights = layer.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
# self.check(layer, ffn1_weights, ffn2_weights)
ffn1_weight_scale = []
ffn2_weight_scale = []
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
# self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
for i in range(layer.num_experts):
expert_idx = layer.expert_id_offset + i
ffn1_weight_scale.append(
up_gate_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_scale_key.format(expert_idx))))
ffn2_weight_scale.append(
up_gate_proj_expert_weight_scale_key.format(expert_idx))))
down_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_scale_key.format(expert_idx))))
down_proj_expert_weight_scale_key.format(expert_idx))))
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
name_tensor_map = {
"moe_ffn1_weight": ffn1_weight,
"moe_ffn2_weight": ffn2_weight,
"moe_ffn1_weight_scale": ffn1_weight_scale,
"moe_ffn2_weight_scale": ffn2_weight_scale
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
@@ -310,8 +310,8 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
"""
Paddle cutlass create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)
def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
@@ -329,7 +329,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
zeros_name = self.added_qzeros_attrs[idx]
@@ -365,8 +365,8 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
dict_ = dict(shared_dict)
for k, v in dict_.items():
weight_list[k] = v[0].to(ffn1_weights[0].place)
weight_scale_list[k] = v[1].to(ffn1_weights[0].place)
weight_list[k] = v[0].to(up_gate_proj_weights[0].place)
weight_scale_list[k] = v[1].to(up_gate_proj_weights[0].place)
else:
remain_weights_start_idx = 0
@@ -38,14 +38,14 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
def create_weights(self, layer):
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
weight_scale_shape = [layer.weight_shape[1]]
layer.linear_weight_shape.reverse()
layer.weight_shape.reverse()
if self.quant_config.name() == "wint4":
layer.linear_weight_shape[0] //= 2
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype=layer._dtype,
is_bias=False,
)
@@ -61,8 +61,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
layer.linear_weight.set_value(quant_weight)
layer.linear_weight_scale.set_value(
layer.weight.set_value(quant_weight)
layer.weight_scale.set_value(
weight_scale.astype(paddle.get_default_dtype()))
@@ -73,8 +73,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
self.group_size, # group_size
)
layer.linear_weight.set_value(quanted_weight_tensor)
layer.linear_weight_scale.set_value(
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(
weight_scale_tensor.astype(paddle.get_default_dtype()))
@@ -82,8 +82,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
def apply(self, layer, x):
linear_out = linear_quant(
lhs=x,
rhs=layer.linear_weight,
scale=layer.linear_weight_scale,
rhs=layer.weight,
scale=layer.weight_scale,
bias=None,
group_size=self.group_size,
)
@@ -37,13 +37,13 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
Create weights for linear layer on XPU
"""
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
layer.linear_weight_shape.reverse()
weight_scale_shape = [layer.weight_shape[1]]
layer.weight_shape.reverse()
if self.quant_config.name() == "weight_only_int4":
layer.linear_weight_shape[0] //= 2
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype="float32",
is_bias=False,
)
@@ -55,6 +55,6 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(
weight, self.quant_config.algo, -1, -1)
layer.linear_weight.set_value(
layer.weight.set_value(
paddle.transpose(quanted_weight_tensor, [1, 0]))
layer.linear_weight_scale.set_value(weight_scale_tensor)
layer.weight_scale.set_value(weight_scale_tensor)
@@ -68,13 +68,13 @@ class VocabParallelEmbedding(nn.Layer):
self.params_dtype: str = params_dtype
if self.use_ep:
self.word_embeddings = nn.Embedding(
self.embeddings = nn.Embedding(
num_embeddings,
embedding_dim,
)
else:
if not self.column_cut:
self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding(
self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
num_embeddings,
embedding_dim,
mp_group=fleet.get_hybrid_communicate_group().
@@ -85,13 +85,13 @@ class VocabParallelEmbedding(nn.Layer):
)
else:
# column cut embedding
self.word_embeddings = nn.Embedding(
self.embeddings = nn.Embedding(
num_embeddings,
embedding_dim // self.world_size,
)
self.word_embeddings.weight.is_distributed = True
self.word_embeddings.weight.split_axis = 1
self.embeddings.weight.is_distributed = True
self.embeddings.weight.split_axis = 1
if not self.use_rope:
self.position_embeddings = nn.Embedding(
@@ -112,13 +112,12 @@ class VocabParallelEmbedding(nn.Layer):
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
a = state_dict[self.prefix + ".weight"]
if self.tie_word_embeddings:
self.word_embeddings.weight.set_value(
self.embeddings.weight.set_value(
get_tensor(state_dict[self.prefix + ".weight"]).astype(
paddle.get_default_dtype()))
else:
self.word_embeddings.weight.set_value(
self.embeddings.weight.set_value(
get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
paddle.get_default_dtype()))
@@ -134,10 +133,10 @@ class VocabParallelEmbedding(nn.Layer):
Tensor: Embedded tensor representation of the input IDs.
"""
if self.use_ep:
input_embedings = self.word_embeddings(ids_remove_padding)
input_embedings = self.embeddings(ids_remove_padding)
else:
if self.column_cut:
input_embedings = self.word_embeddings(ids_remove_padding)
input_embedings = self.embeddings(ids_remove_padding)
inputs_embeds_temp = []
paddle.distributed.all_gather(
inputs_embeds_temp,
@@ -148,6 +147,6 @@ class VocabParallelEmbedding(nn.Layer):
)
input_embedings = paddle.concat(inputs_embeds_temp, -1)
else:
input_embedings = self.word_embeddings(ids_remove_padding)
input_embedings = self.embeddings(ids_remove_padding)
return input_embedings
@@ -14,16 +14,13 @@
# limitations under the License.
"""
from paddleformers.utils.log import logger
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import (
ColumnParallelLinear,
VocabParallelEmbedding,
)
from paddle.distributed.fleet.meta_parallel import (ColumnParallelLinear,
VocabParallelEmbedding)
from paddleformers.utils.log import logger
from .utils import get_tensor
@@ -130,7 +127,7 @@ class HydraHead(nn.Layer):
]
)
self.word_embeddings = VocabParallelEmbedding(
self.embeddings = VocabParallelEmbedding(
vocab_size,
hidden_size,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
@@ -170,8 +167,8 @@ class HydraHead(nn.Layer):
get_tensor(state_dict.pop(f"1.{hydra_head_idx}.weight"))
)
self.word_embeddings.weight.set_value(
get_tensor(state_dict.pop("word_embeddings.weight"))
self.embeddings.weight.set_value(
get_tensor(state_dict.pop("embeddings.weight"))
)
def set_state_dict(self, state_dict):
@@ -183,7 +180,7 @@ class HydraHead(nn.Layer):
"""
is_custom = True
for key in state_dict.keys():
if key != "word_embeddings.weight" and (
if key != "embeddings.weight" and (
"hydra_mlp" in key or "hydra_head" in key
):
is_custom = False
@@ -207,7 +204,7 @@ class HydraHead(nn.Layer):
hidden_states: [batch_size, hidden_size] The hidden_states of the last accept_tokens
"""
hydra_inputs = [hidden_states]
input_embeds = self.word_embeddings(input_ids)
input_embeds = self.embeddings(input_ids)
for hydra_head_idx in range(self.hydra_num_heads):
hydra_inputs.append(input_embeds)
head_input = paddle.concat(hydra_inputs, axis=-1)
@@ -217,4 +214,4 @@ class HydraHead(nn.Layer):
_, topk_tokens = paddle.topk(probs, k=1, axis=-1)
next_tokens[:, 1 + hydra_head_idx : 2 + hydra_head_idx] = topk_tokens[:]
input_embeds = self.word_embeddings(next_tokens[:, 1 + hydra_head_idx])
input_embeds = self.embeddings(next_tokens[:, 1 + hydra_head_idx])
@@ -79,7 +79,7 @@ class LinearBase(nn.Layer):
self._dtype = self._helper.get_default_dtype()
self.weight_dtype = self._dtype
self.linear_weight_shape = [
self.weight_shape = [
self.input_size,
self.output_size,
]
@@ -96,16 +96,16 @@ class LinearBase(nn.Layer):
"""
if self.skip_quant:
self.weight_dtype = self._dtype
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
self.weight = self.create_parameter(
shape=self.weight_shape,
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
self.linear_bias = None
self.bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
self.bias = self.create_parameter(
shape=[self.output_size],
dtype=self._dtype,
is_bias=True,
@@ -136,7 +136,7 @@ class LinearBase(nn.Layer):
if self.fd_config.quant_config:
self.quant_method.process_loaded_weights(self, weight_tensor)
else:
self.linear_weight.set_value(weight_tensor)
self.weight.set_value(weight_tensor)
def load_state_dict(self, state_dict: dict):
"""
@@ -157,7 +157,7 @@ class LinearBase(nn.Layer):
if self.with_bias:
bias_tensor = paddle.to_tensor(
get_tensor(state_dict.pop(self.bias_key)))
self.linear_bias.set_value(bias_tensor)
self.bias.set_value(bias_tensor)
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
"""
@@ -175,9 +175,9 @@ class LinearBase(nn.Layer):
if self.fd_config.quant_config:
linear_out = self.quant_method.apply(self, x)
else:
linear_out = paddle.matmul(x, self.linear_weight)
linear_out = paddle.matmul(x, self.weight)
if self.with_bias:
linear_out = paddle.add(linear_out, self.linear_bias)
linear_out = paddle.add(linear_out, self.bias)
return linear_out
@@ -219,7 +219,7 @@ class ReplicatedLinear(LinearBase):
skip_quant=skip_quant)
self.hidden_size = fd_config.model_config.hidden_size
self.linear_weight_shape = [
self.weight_shape = [
self.input_size,
self.output_size,
]
@@ -272,7 +272,7 @@ class ColumnParallelLinear(LinearBase):
output_size,
self.nranks) # Split the output_size using TP inference.
self.hidden_size = fd_config.model_config.hidden_size
self.linear_weight_shape = [
self.weight_shape = [
self.input_size,
self.output_size,
]
@@ -286,26 +286,26 @@ class ColumnParallelLinear(LinearBase):
"""
if self.skip_quant:
self.weight_dtype = self._dtype
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
self.weight = self.create_parameter(
shape=self.weight_shape,
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
if self.nranks > 0:
# col parallel
_set_var_distributed(self.linear_weight, split_axis=1)
_set_var_distributed(self.weight, split_axis=1)
self.linear_bias = None
self.bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
self.bias = self.create_parameter(
shape=[self.output_size],
dtype=self._dtype,
is_bias=True,
)
if self.nranks > 0:
# col parallel
_set_var_distributed(self.linear_bias, split_axis=1)
_set_var_distributed(self.bias, split_axis=1)
# smooth quant
self.linear_shift = None
@@ -333,7 +333,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_quant: bool = False,
):
"""
Initialize the fused ffn1 Linear layer with given parameters.
Initialize the fused up_gate_proj Linear layer with given parameters.
Args:
fd_config (FDConfig): Inference-related parameters.
@@ -443,7 +443,7 @@ class QKVParallelLinear(ColumnParallelLinear):
q_tensor = get_tensor(state_dict.pop(q_weight_key))
k_tensor = get_tensor(state_dict.pop(k_weight_key))
v_tensor = get_tensor(state_dict.pop(v_weight_key))
if self.kv_num_heads < self.nranks:
sharedkv_index = (self.fd_config.parallel_config.tensor_parallel_rank * self.kv_num_heads) // self.nranks
sharedkv_start = sharedkv_index * self.head_dim
@@ -462,7 +462,7 @@ class QKVParallelLinear(ColumnParallelLinear):
if self.fd_config.quant_config:
self.quant_method.process_loaded_weights(self, weight_tensor)
else:
self.linear_weight.set_value(weight_tensor)
self.weight.set_value(weight_tensor)
def load_state_dict(self, state_dict: dict):
"""
@@ -485,7 +485,7 @@ class QKVParallelLinear(ColumnParallelLinear):
if self.bias_key in state_dict.keys():
bias_tensor = paddle.to_tensor(
get_tensor(state_dict.pop(self.bias_key)))
self.linear_bias.set_value(bias_tensor)
self.bias.set_value(bias_tensor)
else:
q_bias_key = self.bias_key.replace("qkv_proj", "q_proj")
k_bias_key = self.bias_key.replace("qkv_proj", "k_proj")
@@ -494,7 +494,7 @@ class QKVParallelLinear(ColumnParallelLinear):
k_bias = get_tensor(state_dict.pop(k_bias_key))
v_bias = get_tensor(state_dict.pop(v_bias_key))
qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1)
self.linear_bias.set_value(qkv_bias)
self.bias.set_value(qkv_bias)
class RowParallelLinear(LinearBase):
@@ -554,7 +554,7 @@ class RowParallelLinear(LinearBase):
self.input_size = divide(input_size, self.nranks)
self.output_size = output_size
self.linear_weight_shape = [
self.weight_shape = [
self.input_size,
self.output_size,
]
@@ -574,16 +574,16 @@ class RowParallelLinear(LinearBase):
if self.skip_quant:
self.weight_dtype = self._dtype
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
self.weight = self.create_parameter(
shape=self.weight_shape,
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
self.linear_bias = None
self.bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
self.bias = self.create_parameter(
shape=[self.hidden_size],
dtype=self._dtype,
is_bias=True,
@@ -591,7 +591,7 @@ class RowParallelLinear(LinearBase):
if self.nranks > 0:
# row parallel
_set_var_distributed(self.linear_weight, split_axis=0)
_set_var_distributed(self.weight, split_axis=0)
# smooth quant
self.linear_shift = None
@@ -601,7 +601,7 @@ class RowParallelLinear(LinearBase):
if self.fd_config.quant_config:
out = self.quant_method.apply(self, x)
else:
out = paddle.matmul(x, self.linear_weight)
out = paddle.matmul(x, self.weight)
if self.reduce_results and self.nranks > 1:
tensor_model_parallel_all_reduce(out)
@@ -52,11 +52,11 @@ class ParallelLMHead(nn.Layer):
with_bias (bool): whether to have bias. Default: False.
"""
super(ParallelLMHead, self).__init__()
self.linear_weight_key: str = prefix + ".weight"
self.weight_key: str = prefix + ".weight"
if with_bias:
self.linear_bias_key: Optional[str] = prefix + ".bias"
self.bias_key: Optional[str] = prefix + ".bias"
else:
self.linear_bias_key: Optional[str] = None
self.bias_key: Optional[str] = None
self.use_ep: bool = fd_config.parallel_config.use_ep
self.column_cut = True
@@ -74,26 +74,26 @@ class ParallelLMHead(nn.Layer):
else:
if self.column_cut:
need_gather = True
self.out_linear = ColumnParallelLinear(
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False, # False diff更小
)
else:
self.out_linear = RowParallelLinear(
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False, # False diff更小
)
@@ -109,25 +109,25 @@ class ParallelLMHead(nn.Layer):
if self.use_ep:
self.weight.set_value(
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
get_tensor(state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype()))
else:
if self.tie_word_embeddings:
self.out_linear.weight.set_value(
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
self.linear.weight.set_value(
get_tensor(state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype()).transpose([1, 0]))
else:
weight_tensor = get_tensor(
state_dict.pop(self.linear_weight_key)).astype(
state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype())
if self.out_linear.weight.shape != weight_tensor.shape:
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.out_linear.weight.set_value(weight_tensor)
self.linear.weight.set_value(weight_tensor)
if self.linear_bias_key is not None:
bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype(
if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(
paddle.get_default_dtype())
self.out_linear.bias.set_value(bias)
self.linear.bias.set_value(bias)
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
"""
@@ -143,5 +143,5 @@ class ParallelLMHead(nn.Layer):
if self.use_ep:
logits = paddle.matmul(logits, self.weight)
else:
logits = self.out_linear(logits)
logits = self.linear(logits)
return logits
@@ -34,9 +34,9 @@ class MoEMethodBase(QuantMethodBase):
self.moe_quant_type = "w16a16"
else:
self.quant_config = quant_config
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
self.pack_num = 1
@@ -63,14 +63,14 @@ class MoEMethodBase(QuantMethodBase):
"""
pass
def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
"""
check layer is valid for this method
"""
assert ffn1_weights[0].shape == [
assert up_gate_proj_weights[0].shape == [
layer.hidden_size // self.pack_num, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size // self.pack_num, layer.hidden_size
]
@@ -31,7 +31,8 @@ if current_platform.is_cuda() and not current_platform.is_dcu():
from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch,
moe_expert_reduce, noaux_tc)
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import moe_expert_dispatch, moe_expert_reduce
from fastdeploy.model_executor.ops.iluvatar import (moe_expert_dispatch,
moe_expert_reduce)
# used for deepseek_v3
@@ -65,11 +66,11 @@ class CutlassMoEMethod(MoEMethodBase):
Paddle cutlass create weight process.
"""
# bf16
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
for idx, weight_tensor in enumerate(
[stacked_ffn1_weights, stacked_ffn2_weights]):
[stacked_up_gate_proj_weights, stacked_down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
setattr(
layer, weight_name,
@@ -95,15 +96,15 @@ class CutlassMoEMethod(MoEMethodBase):
return fastdeploy.model_executor.ops.iluvatar.moe_expert_ffn(
permute_input,
token_nums_per_expert,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
layer.up_gate_proj_weight,
layer.down_proj_weight,
None,
(layer.moe_ffn1_weight_scale if hasattr(
layer, "moe_ffn1_weight_scale") else None),
(layer.moe_ffn2_weight_scale if hasattr(
layer, "moe_ffn2_weight_scale") else None),
(layer.moe_ffn2_in_scale
if hasattr(layer, "moe_ffn2_in_scale") else None),
(layer.up_gate_proj_weight_scale if hasattr(
layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(
layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale
if hasattr(layer, "down_proj_in_scale") else None),
expert_idx_per_token,
self.moe_quant_type,
used_in_ep_low_latency,
@@ -111,15 +112,15 @@ class CutlassMoEMethod(MoEMethodBase):
return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
permute_input,
token_nums_per_expert,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
layer.up_gate_proj_weight,
layer.down_proj_weight,
None,
(layer.moe_ffn1_weight_scale
if hasattr(layer, "moe_ffn1_weight_scale") else None),
(layer.moe_ffn2_weight_scale
if hasattr(layer, "moe_ffn2_weight_scale") else None),
(layer.moe_ffn2_in_scale
if hasattr(layer, "moe_ffn2_in_scale") else None),
(layer.up_gate_proj_weight_scale
if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale
if hasattr(layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale
if hasattr(layer, "down_proj_in_scale") else None),
expert_idx_per_token,
self.moe_quant_type,
used_in_ep_low_latency,
@@ -163,8 +164,8 @@ class CutlassMoEMethod(MoEMethodBase):
recv_x,
recv_topk_idx,
recv_topk_weights,
(self.moe_ffn1_in_scale
if hasattr(self, "moe_ffn1_in_scale") else None),
(self.up_gate_proj_in_scale
if hasattr(self, "up_gate_proj_in_scale") else None),
recv_num_tokens_per_expert_list,
token_all_num,
self.moe_quant_type,
@@ -186,7 +187,7 @@ class CutlassMoEMethod(MoEMethodBase):
dst_weights,
permute_indices_per_token,
dst_indices,
None, # moe_ffn2_bias,
None, # down_proj_bias,
False, # norm_topk_prob
1.0,
)[0]
@@ -256,7 +257,7 @@ class CutlassMoEMethod(MoEMethodBase):
x,
gate_out,
None, # Use layer.gate_correction_bias in get_moe_scores.
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale")
else None), # if set, permute_input will be int8_t
layer.top_k,
False,
@@ -274,7 +275,7 @@ class CutlassMoEMethod(MoEMethodBase):
x,
gate_out,
layer.gate_correction_bias,
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale")
else None), # if set, permute_input will be int8_t
layer.top_k,
False,
@@ -323,9 +324,9 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
"""
Paddle cutlass create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
weight_list = []
for i in range(layer.num_local_experts):
@@ -366,26 +367,26 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
create_and_set_parameter(layer, name, processed_weight_scale)
# 1. Init scale containers and maps
moe_ffn1_weight_scales = []
moe_ffn2_weight_scales = []
moe_ffn1_in_scales = []
moe_ffn2_in_scales = []
up_gate_proj_weight_scales = []
down_proj_weight_scales = []
up_gate_proj_in_scales = []
down_proj_in_scales = []
scale_weight_map = {
"moe_ffn1_weight_scale": moe_ffn1_weight_scales,
"moe_ffn2_weight_scale": moe_ffn2_weight_scales,
"moe_ffn1_in_scale": moe_ffn1_in_scales,
"moe_ffn2_in_scale": moe_ffn2_in_scales,
"up_gate_proj_weight_scale": up_gate_proj_weight_scales,
"down_proj_weight_scale": down_proj_weight_scales,
"up_gate_proj_in_scale": up_gate_proj_in_scales,
"down_proj_in_scale": down_proj_in_scales,
}
scale_key_map = {
"moe_ffn1_weight_scale":
weight_key_map.get("ffn1_expert_weight_scale_key", None),
"moe_ffn2_weight_scale":
weight_key_map.get("ffn2_expert_weight_scale_key", None),
"moe_ffn1_in_scale":
weight_key_map.get("ffn1_expert_in_scale_key", None),
"moe_ffn2_in_scale":
weight_key_map.get("ffn2_expert_in_scale_key", None),
"up_gate_proj_weight_scale":
weight_key_map.get("up_gate_proj_expert_weight_scale_key", None),
"down_proj_weight_scale":
weight_key_map.get("down_proj_expert_weight_scale_key", None),
"up_gate_proj_in_scale":
weight_key_map.get("up_gate_proj_expert_in_scale_key", None),
"down_proj_in_scale":
weight_key_map.get("down_proj_expert_in_scale_key", None),
}
for name, value in scale_key_map.items():
if value is None:
@@ -404,13 +405,13 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
# 3. Process scale tensor and set to layer
in_scales = []
for in_scale_name in ["moe_ffn1_in_scale", "moe_ffn2_in_scale"]:
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
in_scales.append(
_process_in_scale(in_scale_name,
scale_weight_map[in_scale_name]))
for i, weight_scale_name in enumerate(
["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]):
["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
_process_weight_scale(weight_scale_name,
scale_weight_map[weight_scale_name],
in_scales[i])
@@ -431,41 +432,41 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
"""
Paddle cutlass process prequanted weights.
"""
ffn1_expert_weight_key = layer.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = layer.weight_key_map.get(
"ffn2_expert_weight_key", None)
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
"ffn1_expert_weight_scale_key", None)
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
"ffn2_expert_weight_scale_key", None)
up_gate_proj_expert_weight_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = layer.weight_key_map.get(
"down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get(
"down_proj_expert_weight_scale_key", None)
ffn1_weights, ffn2_weights = layer.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
# self.check(layer, ffn1_weights, ffn2_weights)
ffn1_weight_scale = []
ffn2_weight_scale = []
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
# self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
for i in range(layer.num_local_experts):
expert_idx = layer.expert_id_offset + i
ffn1_weight_scale.append(
up_gate_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_scale_key.format(expert_idx))))
ffn2_weight_scale.append(
up_gate_proj_expert_weight_scale_key.format(expert_idx))))
down_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_scale_key.format(expert_idx))))
down_proj_expert_weight_scale_key.format(expert_idx))))
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
name_tensor_map = {
"moe_ffn1_weight": ffn1_weight,
"moe_ffn2_weight": ffn2_weight,
"moe_ffn1_weight_scale": ffn1_weight_scale,
"moe_ffn2_weight_scale": ffn2_weight_scale
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
@@ -474,10 +475,10 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
"""
Paddle cutlass create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
|
@@ -39,11 +39,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
deepgemm create weight process.
|
||||
"""
|
||||
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
|
||||
self.check(layer, ffn1_weights, ffn2_weights)
|
||||
self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||
|
||||
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
|
||||
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
scale_name = self.added_scale_attrs[idx]
|
||||
|
||||
@@ -70,41 +70,41 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
"""
|
||||
Paddle cutlass process prequanted weights.
|
||||
"""
|
||||
ffn1_expert_weight_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_weight_key", None)
|
||||
ffn2_expert_weight_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_weight_key", None)
|
||||
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_weight_scale_key", None)
|
||||
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_weight_scale_key", None)
|
||||
up_gate_proj_expert_weight_key = layer.weight_key_map.get(
|
||||
"up_gate_proj_expert_weight_key", None)
|
||||
down_proj_expert_weight_key = layer.weight_key_map.get(
|
||||
"down_proj_expert_weight_key", None)
|
||||
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"up_gate_proj_expert_weight_scale_key", None)
|
||||
down_proj_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"down_proj_expert_weight_scale_key", None)
|
||||
|
||||
ffn1_weights, ffn2_weights = layer.load_experts_weight(
|
||||
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
|
||||
# self.check(layer, ffn1_weights, ffn2_weights)
|
||||
ffn1_weight_scale = []
|
||||
ffn2_weight_scale = []
|
||||
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
|
||||
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
|
||||
# self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||
up_gate_proj_weight_scale = []
|
||||
down_proj_weight_scale = []
|
||||
for i in range(layer.num_local_experts):
|
||||
expert_idx = layer.expert_id_offset + i
|
||||
ffn1_weight_scale.append(
|
||||
up_gate_proj_weight_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_weight_scale_key.format(expert_idx))))
|
||||
ffn2_weight_scale.append(
|
||||
up_gate_proj_expert_weight_scale_key.format(expert_idx))))
|
||||
down_proj_weight_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_weight_scale_key.format(expert_idx))))
|
||||
down_proj_expert_weight_scale_key.format(expert_idx))))
|
||||
|
||||
ffn1_weight = paddle.stack(ffn1_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
|
||||
ffn2_weight = paddle.stack(ffn2_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
|
||||
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
|
||||
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
|
||||
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
|
||||
down_proj_weight = paddle.stack(down_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
|
||||
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
|
||||
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
|
||||
|
||||
name_tensor_map = {
|
||||
"moe_ffn1_weight": ffn1_weight,
|
||||
"moe_ffn2_weight": ffn2_weight,
|
||||
"moe_ffn1_weight_scale": ffn1_weight_scale,
|
||||
"moe_ffn2_weight_scale": ffn2_weight_scale
|
||||
"up_gate_proj_weight": up_gate_proj_weight,
|
||||
"down_proj_weight": down_proj_weight,
|
||||
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
|
||||
"down_proj_weight_scale": down_proj_weight_scale
|
||||
}
|
||||
for name, tensor in name_tensor_map.items():
|
||||
create_and_set_parameter(layer, name, tensor)
|
||||
@@ -143,10 +143,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
if token_all_num > 0:
logger.info(f"token_all_num {token_all_num}")
(recv_x, recv_x_scale) = recv_x
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
token_nums_this_rank_padded = sum(token_nums_this_rank[1].numpy().tolist())
(
permute_input,
permute_scale,
@@ -171,21 +171,21 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
permute_scale = permute_scale.transpose([1, 0]).contiguous()
permute_scale = permute_scale.transpose([1, 0])
# ffn1
# up_gate_proj
ffn_out = paddle.empty(
(permute_input.shape[0], layer.moe_ffn1_weight.shape[1]),
(permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
dtype=paddle.bfloat16,
)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(permute_input, permute_scale),
(layer.moe_ffn1_weight, layer.moe_ffn1_weight_scale),
(layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
ffn_out,
m_indices,
)
# swiglu
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None)
# ffn2
# down_proj
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, self.quant_config.weight_block_size[0])
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose(
@@ -193,11 +193,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
ffn_out = paddle.empty(
(ffn_out.shape[0], layer.moe_ffn2_weight.shape[1]),
(ffn_out.shape[0], layer.down_proj_weight.shape[1]),
dtype=paddle.bfloat16)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(ffn_in_x, ffn_in_x_scale_tensor),
(layer.moe_ffn2_weight, layer.moe_ffn2_weight_scale),
(layer.down_proj_weight, layer.down_proj_weight_scale),
ffn_out,
m_indices,
)
@@ -207,7 +207,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
dst_weights,
permute_indices_per_token,
dst_indices,
None, # moe_ffn2_bias
None, # down_proj_bias
False, # norm_topk_prob
1.0,
)[0]
@@ -237,7 +237,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
# 3. Compute ffn
assert isinstance(permute_input, tuple)
ffn1_out = paddle.empty(
up_gate_proj_out = paddle.empty(
[
layer.num_local_experts,
layer.ep_size *
@@ -261,16 +261,16 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
permute_input,
(
layer.moe_ffn1_weight,
layer.moe_ffn1_weight_scale,
layer.up_gate_proj_weight,
layer.up_gate_proj_weight_scale,
),
ffn1_out,
up_gate_proj_out,
token_nums_per_expert,
expected_m,
)
act_out = fastdeploy.model_executor.ops.gpu.group_swiglu_with_masked(
ffn1_out, token_nums_per_expert)
up_gate_proj_out, token_nums_per_expert)
act_out_fp8, scale = fastdeploy.model_executor.ops.gpu.masked_per_token_quant(
act_out, token_nums_per_expert,
@@ -279,8 +279,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
(act_out_fp8, scale),
(
layer.moe_ffn2_weight,
layer.moe_ffn2_weight_scale,
layer.down_proj_weight,
layer.down_proj_weight_scale,
),
ffn_out,
token_nums_per_expert,
@@ -339,21 +339,21 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
permute_scale = permute_scale.transpose([1, 0]).contiguous()
permute_scale = permute_scale.transpose([1, 0])
# ffn1
# up_gate_proj
ffn_out = paddle.empty(
(permute_input.shape[0], layer.moe_ffn1_weight.shape[1]),
(permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
dtype=paddle.bfloat16,
)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(permute_input, permute_scale),
(layer.moe_ffn1_weight, layer.moe_ffn1_weight_scale),
(layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
ffn_out,
m_indices,
)
# swiglu
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out)
# ffn2
# down_proj
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, self.quant_config.weight_block_size[0])
@@ -362,11 +362,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
ffn_out = paddle.empty(
(ffn_out.shape[0], layer.moe_ffn2_weight.shape[1]),
(ffn_out.shape[0], layer.down_proj_weight.shape[1]),
dtype=paddle.bfloat16)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(ffn_in_x, ffn_in_x_scale_tensor),
(layer.moe_ffn2_weight, layer.moe_ffn2_weight_scale),
(layer.down_proj_weight, layer.down_proj_weight_scale),
ffn_out,
m_indices,
)
@@ -103,9 +103,9 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
Marlin Group Gemm to compute Fused MoE.
"""
self.quant_method = quant_method
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
self.added_zeros_attrs = ["zeros0", "zeros1"]
@@ -113,22 +113,22 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
"""
Marlin MoE create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts
assert ffn1_weights[0].shape == [
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
assert up_gate_proj_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_tensor = paddle.stack(down_proj_weights, axis=0)
max_bound = 7
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
@@ -221,8 +221,8 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
ffn_out = MoeWna16MarlinGemmApi(
x,
c_or_none=None,
b_q_weight=layer.moe_ffn1_weight,
b_scales=layer.moe_ffn1_weight_scale,
b_q_weight=layer.up_gate_proj_weight,
b_scales=layer.up_gate_proj_weight_scale,
global_scale_or_none=None,
b_zeros_or_none=None,
g_idx_or_none=None,
@@ -250,8 +250,8 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
ffn_out = MoeWna16MarlinGemmApi(
swiglu_out,
c_or_none=None,
b_q_weight=layer.moe_ffn2_weight,
b_scales=layer.moe_ffn2_weight_scale,
b_q_weight=layer.down_proj_weight,
b_scales=layer.down_proj_weight_scale,
global_scale_or_none=None,
b_zeros_or_none=None,
g_idx_or_none=None,
@@ -30,7 +30,7 @@ try:
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
from .triton_moe_kernels import fused_moe_kernel_paddle
except:
except ImportError:
pass
@@ -44,9 +44,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
Triton Group Gemm to compute Fused MoE.
"""
self.quant_config = quant_config
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
@@ -57,30 +57,30 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
"""
|
||||
Triton MoE create weight process.
|
||||
"""
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert len(ffn1_weights) == layer.num_local_experts
|
||||
assert len(ffn2_weights) == layer.num_local_experts
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert len(up_gate_proj_weights) == layer.num_local_experts
|
||||
assert len(down_proj_weights) == layer.num_local_experts
|
||||
|
||||
algo = layer.quant_method.quant_config.name()
|
||||
|
||||
assert algo == "wint8"
|
||||
|
||||
assert ffn1_weights[0].shape == [
|
||||
assert up_gate_proj_weights[0].shape == [
|
||||
layer.hidden_size, layer.moe_intermediate_size * 2
|
||||
]
|
||||
assert ffn2_weights[0].shape == [
|
||||
assert down_proj_weights[0].shape == [
|
||||
layer.moe_intermediate_size, layer.hidden_size
|
||||
]
|
||||
|
||||
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
|
||||
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
|
||||
up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
|
||||
down_proj_tensor = paddle.stack(down_proj_weights, axis=0)
|
||||
|
||||
if algo == "wint8":
|
||||
max_bound = 127
|
||||
elif algo == "wint4":
|
||||
max_bound = 7
|
||||
|
||||
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
|
||||
for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
scale_name = self.added_scale_attrs[idx]
|
||||
|
||||
@@ -130,7 +130,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
True, # apply_norm_weight,
|
||||
False,
|
||||
)
|
||||
ffn1_out = paddle.empty(
|
||||
up_gate_proj_out = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
)
|
||||
@@ -150,10 +150,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x,
|
||||
layer.moe_ffn1_weight,
|
||||
ffn1_out,
|
||||
layer.up_gate_proj_weight,
|
||||
up_gate_proj_out,
|
||||
None,
|
||||
layer.moe_ffn1_weight_scale,
|
||||
layer.up_gate_proj_weight_scale,
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@@ -164,17 +164,17 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
K=hidden_size,
|
||||
stride_am=x.strides[0],
|
||||
stride_ak=x.strides[1],
|
||||
stride_be=layer.moe_ffn1_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn1_weight.strides[1],
|
||||
stride_bn=layer.moe_ffn1_weight.strides[2],
|
||||
stride_cm=ffn1_out.strides[0],
|
||||
stride_cn=ffn1_out.strides[1],
|
||||
stride_be=layer.up_gate_proj_weight.strides[0],
|
||||
stride_bk=layer.up_gate_proj_weight.strides[1],
|
||||
stride_bn=layer.up_gate_proj_weight.strides[2],
|
||||
stride_cm=up_gate_proj_out.strides[0],
|
||||
stride_cn=up_gate_proj_out.strides[1],
|
||||
#
|
||||
stride_asm=-1,
|
||||
stride_ask=-1,
|
||||
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
|
||||
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
|
||||
stride_bsk=-1,
|
||||
stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
|
||||
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
|
||||
group_n=-1,
|
||||
group_k=-1,
|
||||
# Meta-parameters
|
||||
@@ -190,10 +190,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
ffn2_input = paddle.incubate.nn.functional.swiglu(
|
||||
ffn1_out)
|
||||
down_proj_input = paddle.incubate.nn.functional.swiglu(
|
||||
up_gate_proj_out)
|
||||
|
||||
ffn2_out = paddle.empty(
|
||||
down_proj_out = paddle.empty(
|
||||
(token_num * top_k, hidden_size),
|
||||
dtype=x.dtype,
|
||||
)
|
||||
@@ -202,11 +202,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
|
||||
fused_moe_kernel_paddle[grid](
|
||||
ffn2_input,
|
||||
layer.moe_ffn2_weight,
|
||||
ffn2_out,
|
||||
down_proj_input,
|
||||
layer.down_proj_weight,
|
||||
down_proj_out,
|
||||
None,
|
||||
layer.moe_ffn2_weight_scale,
|
||||
layer.down_proj_weight_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@@ -215,18 +215,18 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
token_num * top_k,
|
||||
N=hidden_size,
|
||||
K=moe_intermediate_size,
|
||||
stride_am=ffn2_input.strides[0],
|
||||
stride_ak=ffn2_input.strides[1],
|
||||
stride_be=layer.moe_ffn2_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn2_weight.strides[1],
|
||||
stride_bn=layer.moe_ffn2_weight.strides[2],
|
||||
stride_cm=ffn2_out.strides[0],
|
||||
stride_cn=ffn2_out.strides[1],
|
||||
stride_am=down_proj_input.strides[0],
|
||||
stride_ak=down_proj_input.strides[1],
|
||||
stride_be=layer.down_proj_weight.strides[0],
|
||||
stride_bk=layer.down_proj_weight.strides[1],
|
||||
stride_bn=layer.down_proj_weight.strides[2],
|
||||
stride_cm=down_proj_out.strides[0],
|
||||
stride_cn=down_proj_out.strides[1],
|
||||
stride_asm=-1,
|
||||
stride_ask=-1,
|
||||
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
|
||||
stride_bse=layer.down_proj_weight_scale.strides[0],
|
||||
stride_bsk=-1,
|
||||
stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
|
||||
stride_bsn=layer.down_proj_weight_scale.strides[1],
|
||||
group_n=-1,
|
||||
group_k=-1,
|
||||
# Meta-parameters
|
||||
@@ -242,8 +242,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
ffn2_out.reshape_([token_num, top_k, hidden_size])
|
||||
out = ffn2_out.sum(axis=1)
|
||||
down_proj_out.reshape_([token_num, top_k, hidden_size])
|
||||
out = down_proj_out.sum(axis=1)
|
||||
return out
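The wint8/wint4 branches above only show where the quantized weights and scales are attached (max_bound 127 or 7); the quantization itself is outside this hunk. A rough sketch of a per-channel abs-max scheme consistent with max_bound = 127 — an assumption for illustration, not the repository's exact kernel:

import paddle

def quantize_wint8(weight, max_bound=127.0):
    # weight: [hidden_size, out_features]; one scale per output channel (assumed layout).
    weight_scale = weight.abs().max(axis=0).clip(min=1e-8) / max_bound
    quanted = paddle.round(weight / weight_scale).clip(-max_bound, max_bound)
    return quanted.astype("int8"), weight_scale

w_int8, w_scale = quantize_wint8(paddle.randn([64, 256]))
# Dequantize for a reference matmul: w_int8.astype("float32") * w_scale
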
|
||||
|
||||
|
||||
@@ -261,20 +261,20 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
|
||||
"""process_prequanted_weights"""
|
||||
|
||||
ffn1_tensor, ffn2_tensor = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert ffn1_tensor[0].shape == [
|
||||
up_gate_proj_tensor, down_proj_tensor = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert up_gate_proj_tensor[0].shape == [
|
||||
layer.hidden_size, layer.moe_intermediate_size * 2
|
||||
]
|
||||
assert ffn2_tensor[0].shape == [
|
||||
assert down_proj_tensor[0].shape == [
|
||||
layer.moe_intermediate_size, layer.hidden_size
|
||||
]
|
||||
|
||||
ffn1_tensor = paddle.stack(ffn1_tensor, axis=0).view(paddle.float8_e4m3fn)
|
||||
ffn2_tensor = paddle.stack(ffn2_tensor, axis=0).view(paddle.float8_e4m3fn)
|
||||
up_gate_proj_tensor = paddle.stack(up_gate_proj_tensor, axis=0).view(paddle.float8_e4m3fn)
|
||||
down_proj_tensor = paddle.stack(down_proj_tensor, axis=0).view(paddle.float8_e4m3fn)
|
||||
|
||||
added_wfp8afp8_attrs = [
|
||||
"moe_ffn1_weight", "moe_ffn2_weight", "moe_ffn1_weight_scale",
|
||||
"moe_ffn2_weight_scale", "moe_ffn1_in_scale", "moe_ffn2_in_scale"
|
||||
"up_gate_proj_weight", "down_proj_weight", "up_gate_proj_weight_scale",
|
||||
"down_proj_weight_scale", "up_gate_proj_in_scale", "down_proj_in_scale"
|
||||
]
|
||||
|
||||
def _extract_scale_tensor(key_template):
|
||||
@@ -285,18 +285,18 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
return paddle.concat(result).cast("float32")
|
||||
|
||||
weight_key_map = layer.weight_key_map
|
||||
moe_ffn1_weight_scale = _extract_scale_tensor(
|
||||
weight_key_map["ffn1_expert_weight_scale_key"])
|
||||
moe_ffn2_weight_scale = _extract_scale_tensor(
|
||||
weight_key_map["ffn2_expert_weight_scale_key"])
|
||||
moe_ffn1_in_scale = _extract_scale_tensor(
|
||||
weight_key_map["ffn1_expert_in_scale_key"])
|
||||
moe_ffn2_in_scale = _extract_scale_tensor(
|
||||
weight_key_map["ffn2_expert_in_scale_key"])
|
||||
up_gate_proj_weight_scale = _extract_scale_tensor(
|
||||
weight_key_map["up_gate_proj_expert_weight_scale_key"])
|
||||
down_proj_weight_scale = _extract_scale_tensor(
|
||||
weight_key_map["down_proj_expert_weight_scale_key"])
|
||||
up_gate_proj_in_scale = _extract_scale_tensor(
|
||||
weight_key_map["up_gate_proj_expert_in_scale_key"])
|
||||
down_proj_in_scale = _extract_scale_tensor(
|
||||
weight_key_map["down_proj_expert_in_scale_key"])
|
||||
|
||||
for idx, weight_tensor in enumerate([
|
||||
ffn1_tensor, ffn2_tensor, moe_ffn1_weight_scale,
|
||||
moe_ffn2_weight_scale, moe_ffn1_in_scale, moe_ffn2_in_scale
|
||||
up_gate_proj_tensor, down_proj_tensor, up_gate_proj_weight_scale,
|
||||
down_proj_weight_scale, up_gate_proj_in_scale, down_proj_in_scale
|
||||
]):
|
||||
name = added_wfp8afp8_attrs[idx]
|
||||
setattr(
|
||||
@@ -341,12 +341,12 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
False,
|
||||
)
|
||||
|
||||
ffn1_out = paddle.empty(
|
||||
up_gate_proj_out = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
config_ffn1 = {
|
||||
config_up_gate_proj = {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
@@ -354,15 +354,15 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
}
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
|
||||
topk_ids, num_local_experts, config_ffn1["BLOCK_SIZE_M"])
|
||||
topk_ids, num_local_experts, config_up_gate_proj["BLOCK_SIZE_M"])
|
||||
max_possible_num_post_padded = sorted_token_ids.shape[0]
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config_ffn1["BLOCK_SIZE_M"]) *
|
||||
ceil_div(moe_intermediate_size * 2, config_ffn1["BLOCK_SIZE_N"]), )
|
||||
ceil_div(max_possible_num_post_padded, config_up_gate_proj["BLOCK_SIZE_M"]) *
|
||||
ceil_div(moe_intermediate_size * 2, config_up_gate_proj["BLOCK_SIZE_N"]), )
|
||||
|
||||
permute_x = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
|
||||
x,
|
||||
scale=layer.moe_ffn1_in_scale,
|
||||
scale=layer.up_gate_proj_in_scale,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
intermediate_size=hidden_size,
|
||||
@@ -370,10 +370,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
permute_x,
|
||||
layer.moe_ffn1_weight,
|
||||
ffn1_out,
|
||||
layer.moe_ffn1_in_scale,
|
||||
layer.moe_ffn1_weight_scale,
|
||||
layer.up_gate_proj_weight,
|
||||
up_gate_proj_out,
|
||||
layer.up_gate_proj_in_scale,
|
||||
layer.up_gate_proj_weight_scale,
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@@ -384,11 +384,11 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
K=hidden_size,
|
||||
stride_am=x.strides[0],
|
||||
stride_ak=x.strides[1],
|
||||
stride_be=layer.moe_ffn1_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn1_weight.strides[1],
|
||||
stride_bn=layer.moe_ffn1_weight.strides[2],
|
||||
stride_cm=ffn1_out.strides[0],
|
||||
stride_cn=ffn1_out.strides[1],
|
||||
stride_be=layer.up_gate_proj_weight.strides[0],
|
||||
stride_bk=layer.up_gate_proj_weight.strides[1],
|
||||
stride_bn=layer.up_gate_proj_weight.strides[2],
|
||||
stride_cm=up_gate_proj_out.strides[0],
|
||||
stride_cn=up_gate_proj_out.strides[1],
|
||||
#
|
||||
stride_asm=-1, # only used in blockwise fp8
|
||||
stride_ask=-1, # only used in blockwise fp8
|
||||
@@ -398,51 +398,51 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
group_n=-1,
|
||||
group_k=-1,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M=config_ffn1["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config_ffn1["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config_ffn1["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config_ffn1["GROUP_SIZE_M"],
|
||||
BLOCK_SIZE_M=config_up_gate_proj["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config_up_gate_proj["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config_up_gate_proj["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config_up_gate_proj["GROUP_SIZE_M"],
|
||||
MUL_ROUTED_WEIGHT=False,
|
||||
top_k=1,
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=True,
|
||||
use_int8_w8a16=False,
|
||||
even_Ks=hidden_size % config_ffn1["BLOCK_SIZE_K"] == 0,
|
||||
even_Ks=hidden_size % config_up_gate_proj["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
ffn2_input = paddle.incubate.nn.functional.swiglu(
|
||||
ffn1_out)
|
||||
down_proj_input = paddle.incubate.nn.functional.swiglu(
|
||||
up_gate_proj_out)
|
||||
|
||||
ffn2_input = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
|
||||
ffn2_input,
|
||||
scale=layer.moe_ffn2_in_scale,
|
||||
down_proj_input = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
|
||||
down_proj_input,
|
||||
scale=layer.down_proj_in_scale,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
intermediate_size=moe_intermediate_size,
|
||||
tiled=True)
|
||||
|
||||
config_ffn2 = {
|
||||
config_down_proj = {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
|
||||
ffn2_out = paddle.empty(
|
||||
down_proj_out = paddle.empty(
|
||||
(token_num * top_k, hidden_size),
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config_ffn2["BLOCK_SIZE_M"]) *
|
||||
ceil_div(hidden_size, config_ffn2["BLOCK_SIZE_N"]), )
|
||||
ceil_div(max_possible_num_post_padded, config_down_proj["BLOCK_SIZE_M"]) *
|
||||
ceil_div(hidden_size, config_down_proj["BLOCK_SIZE_N"]), )
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
ffn2_input,
|
||||
layer.moe_ffn2_weight,
|
||||
ffn2_out,
|
||||
layer.moe_ffn2_in_scale,
|
||||
layer.moe_ffn2_weight_scale,
|
||||
down_proj_input,
|
||||
layer.down_proj_weight,
|
||||
down_proj_out,
|
||||
layer.down_proj_in_scale,
|
||||
layer.down_proj_weight_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@@ -451,13 +451,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
token_num * top_k,
|
||||
N=hidden_size,
|
||||
K=moe_intermediate_size,
|
||||
stride_am=ffn2_input.strides[0],
|
||||
stride_ak=ffn2_input.strides[1],
|
||||
stride_be=layer.moe_ffn2_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn2_weight.strides[1],
|
||||
stride_bn=layer.moe_ffn2_weight.strides[2],
|
||||
stride_cm=ffn2_out.strides[0],
|
||||
stride_cn=ffn2_out.strides[1],
|
||||
stride_am=down_proj_input.strides[0],
|
||||
stride_ak=down_proj_input.strides[1],
|
||||
stride_be=layer.down_proj_weight.strides[0],
|
||||
stride_bk=layer.down_proj_weight.strides[1],
|
||||
stride_bn=layer.down_proj_weight.strides[2],
|
||||
stride_cm=down_proj_out.strides[0],
|
||||
stride_cn=down_proj_out.strides[1],
|
||||
stride_asm=-1,
|
||||
stride_ask=-1,
|
||||
stride_bse=-1,
|
||||
@@ -466,20 +466,20 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
group_n=-1,
|
||||
group_k=-1,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M=config_ffn2["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config_ffn2["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config_ffn2["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config_ffn2["GROUP_SIZE_M"],
|
||||
BLOCK_SIZE_M=config_down_proj["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config_down_proj["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config_down_proj["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config_down_proj["GROUP_SIZE_M"],
|
||||
MUL_ROUTED_WEIGHT=True,
|
||||
top_k=1,
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=True,
|
||||
use_int8_w8a16=False,
|
||||
even_Ks=moe_intermediate_size % config_ffn2["BLOCK_SIZE_K"] == 0,
|
||||
even_Ks=moe_intermediate_size % config_down_proj["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
ffn2_out.reshape_([token_num, top_k, hidden_size])
|
||||
out = ffn2_out.sum(axis=1)
|
||||
down_proj_out.reshape_([token_num, top_k, hidden_size])
|
||||
out = down_proj_out.sum(axis=1)
|
||||
|
||||
if layer.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(out)
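When tensor parallelism splits the expert weights, each rank produces a partial down_proj result, and the all-reduce above sums those partials so every rank ends up with the full output. A local stand-in for what the collective computes (tp_size and shapes are assumptions):

import paddle

tp_size = 2
partial_outputs = [paddle.full([4, 8], float(rank + 1)) for rank in range(tp_size)]  # per-rank partial sums
out = paddle.add_n(partial_outputs)  # the value every rank holds after the all-reduce
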
|
||||
@@ -496,9 +496,9 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
self.quant_config = quant_config
|
||||
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
|
||||
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
|
||||
self.added_scale_attrs = [
|
||||
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
|
||||
"up_gate_proj_weight_scale", "down_proj_weight_scale"
|
||||
]
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
|
||||
@@ -510,11 +510,11 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
"""
|
||||
Triton MoE create weight process.
|
||||
"""
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
|
||||
self.check(layer, ffn1_weights, ffn2_weights)
|
||||
self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||
|
||||
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
|
||||
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
scale_name = self.added_scale_attrs[idx]
|
||||
|
||||
@@ -537,14 +537,14 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
[0, 2, 1]).contiguous()
|
||||
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
|
||||
|
||||
def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
|
||||
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
|
||||
"""
|
||||
check layer is valid for this method
|
||||
"""
|
||||
assert ffn1_weights[0].shape == [
|
||||
assert up_gate_proj_weights[0].shape == [
|
||||
layer.hidden_size, layer.moe_intermediate_size * 2
|
||||
]
|
||||
assert ffn2_weights[0].shape == [
|
||||
assert down_proj_weights[0].shape == [
|
||||
layer.moe_intermediate_size, layer.hidden_size
|
||||
]
|
||||
|
||||
@@ -563,8 +563,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
num_local_experts = layer.num_local_experts
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
E, N1, _ = layer.moe_ffn1_weight.shape
|
||||
N2 = layer.moe_ffn2_weight.shape[1]
|
||||
E, N1, _ = layer.up_gate_proj_weight.shape
|
||||
N2 = layer.down_proj_weight.shape[1]
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
@@ -605,10 +605,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x_q,
|
||||
layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
|
||||
layer.up_gate_proj_weight.view(paddle.float8_e4m3fn),
|
||||
intermediate_cache1,
|
||||
x_scale,
|
||||
layer.moe_ffn1_weight_scale,
|
||||
layer.up_gate_proj_weight_scale,
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@@ -619,17 +619,17 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
K=hidden_size,
|
||||
stride_am=x_q.strides[0],
|
||||
stride_ak=x_q.strides[1],
|
||||
stride_be=layer.moe_ffn1_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn1_weight.strides[2],
|
||||
stride_bn=layer.moe_ffn1_weight.strides[1],
|
||||
stride_be=layer.up_gate_proj_weight.strides[0],
|
||||
stride_bk=layer.up_gate_proj_weight.strides[2],
|
||||
stride_bn=layer.up_gate_proj_weight.strides[1],
|
||||
stride_cm=intermediate_cache1.strides[0],
|
||||
stride_cn=intermediate_cache1.strides[1],
|
||||
#
|
||||
stride_asm=x_scale.strides[0], # only used in blockwise fp8
|
||||
stride_ask=x_scale.strides[1], # only used in blockwise fp8
|
||||
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
|
||||
stride_bsk=layer.moe_ffn1_weight_scale.strides[2],
|
||||
stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
|
||||
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
|
||||
stride_bsk=layer.up_gate_proj_weight_scale.strides[2],
|
||||
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
|
||||
group_n=self.quant_config.weight_block_size[1],
|
||||
group_k=self.quant_config.weight_block_size[0],
|
||||
# Meta-parameters
|
||||
@@ -656,10 +656,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x_q,
|
||||
layer.moe_ffn2_weight.view(paddle.float8_e4m3fn),
|
||||
layer.down_proj_weight.view(paddle.float8_e4m3fn),
|
||||
intermediate_cache3,
|
||||
x_scale,
|
||||
layer.moe_ffn2_weight_scale,
|
||||
layer.down_proj_weight_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@@ -670,16 +670,16 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
K=moe_intermediate_size,
|
||||
stride_am=x_q.strides[0],
|
||||
stride_ak=x_q.strides[1],
|
||||
stride_be=layer.moe_ffn2_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn2_weight.strides[2],
|
||||
stride_bn=layer.moe_ffn2_weight.strides[1],
|
||||
stride_be=layer.down_proj_weight.strides[0],
|
||||
stride_bk=layer.down_proj_weight.strides[2],
|
||||
stride_bn=layer.down_proj_weight.strides[1],
|
||||
stride_cm=intermediate_cache3.strides[0],
|
||||
stride_cn=intermediate_cache3.strides[1],
|
||||
stride_asm=x_scale.strides[0], # only used in blockwise fp8
|
||||
stride_ask=x_scale.strides[1], # only used in blockwise fp8
|
||||
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
|
||||
stride_bsk=layer.moe_ffn2_weight_scale.strides[2],
|
||||
stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
|
||||
stride_bse=layer.down_proj_weight_scale.strides[0],
|
||||
stride_bsk=layer.down_proj_weight_scale.strides[2],
|
||||
stride_bsn=layer.down_proj_weight_scale.strides[1],
|
||||
group_n=self.quant_config.weight_block_size[1],
|
||||
group_k=self.quant_config.weight_block_size[0],
|
||||
# Meta-parameters
|
||||
|
@@ -41,16 +41,16 @@ class Wint2MoeMethod(QuantMethodBase):
|
||||
"""
|
||||
pass
|
||||
|
||||
def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
|
||||
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
|
||||
"""
|
||||
check layer is valid for this method
|
||||
"""
|
||||
assert len(
|
||||
ffn1_weights
|
||||
) == layer.num_local_experts, "ffn1_weights length should be equal to num_local_experts."
|
||||
up_gate_proj_weights
|
||||
) == layer.num_local_experts, "up_gate_proj_weights length should be equal to num_local_experts."
|
||||
assert len(
|
||||
ffn2_weights
|
||||
) == layer.num_local_experts, "ffn2_weights length should be equal to num_local_experts."
|
||||
down_proj_weights
|
||||
) == layer.num_local_experts, "down_proj_weights length should be equal to num_local_experts."
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
@@ -78,96 +78,96 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
"""
|
||||
Paddle cutlass process prequanted weights.
|
||||
"""
|
||||
ffn1_expert_weight_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_weight_key", None)
|
||||
ffn2_expert_weight_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_weight_key", None)
|
||||
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_weight_scale_key", None)
|
||||
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_weight_scale_key", None)
|
||||
ffn1_expert_super_scales_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_super_scales_key", None)
|
||||
ffn2_expert_super_scales_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_super_scales_key", None)
|
||||
ffn1_expert_code_scale_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_code_scale_key", None)
|
||||
ffn2_expert_code_scale_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_code_scale_key", None)
|
||||
ffn1_expert_code_zp_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_code_zp_key", None)
|
||||
ffn2_expert_code_zp_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_code_zp_key", None)
|
||||
up_gate_proj_expert_weight_key = layer.weight_key_map.get(
|
||||
"up_gate_proj_expert_weight_key", None)
|
||||
down_proj_expert_weight_key = layer.weight_key_map.get(
|
||||
"down_proj_expert_weight_key", None)
|
||||
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"up_gate_proj_expert_weight_scale_key", None)
|
||||
down_proj_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"down_proj_expert_weight_scale_key", None)
|
||||
up_gate_proj_expert_super_scales_key = layer.weight_key_map.get(
|
||||
"up_gate_proj_expert_super_scales_key", None)
|
||||
down_proj_expert_super_scales_key = layer.weight_key_map.get(
|
||||
"down_proj_expert_super_scales_key", None)
|
||||
up_gate_proj_expert_code_scale_key = layer.weight_key_map.get(
|
||||
"up_gate_proj_expert_code_scale_key", None)
|
||||
down_proj_expert_code_scale_key = layer.weight_key_map.get(
|
||||
"down_proj_expert_code_scale_key", None)
|
||||
up_gate_proj_expert_code_zp_key = layer.weight_key_map.get(
|
||||
"up_gate_proj_expert_code_zp_key", None)
|
||||
down_proj_expert_code_zp_key = layer.weight_key_map.get(
|
||||
"down_proj_expert_code_zp_key", None)
|
||||
|
||||
ffn1_weights, ffn2_weights = layer.load_experts_weight(
|
||||
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
|
||||
# self.check(layer, ffn1_weights, ffn2_weights)
|
||||
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
|
||||
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
|
||||
# self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||
|
||||
ffn1_weight_scale = []
|
||||
ffn2_weight_scale = []
|
||||
ffn1_super_scales = []
|
||||
ffn2_super_scales = []
|
||||
ffn1_code_scale = []
|
||||
ffn2_code_scale = []
|
||||
ffn1_code_zp = []
|
||||
ffn2_code_zp = []
|
||||
up_gate_proj_weight_scale = []
|
||||
down_proj_weight_scale = []
|
||||
up_gate_proj_super_scales = []
|
||||
down_proj_super_scales = []
|
||||
up_gate_proj_code_scale = []
|
||||
down_proj_code_scale = []
|
||||
up_gate_proj_code_zp = []
|
||||
down_proj_code_zp = []
|
||||
for i in range(layer.num_experts):
|
||||
expert_idx = layer.expert_id_offset + i
|
||||
ffn1_weight_scale.append(
|
||||
up_gate_proj_weight_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_weight_scale_key.format(expert_idx))))
|
||||
ffn2_weight_scale.append(
|
||||
up_gate_proj_expert_weight_scale_key.format(expert_idx))))
|
||||
down_proj_weight_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_weight_scale_key.format(expert_idx))))
|
||||
ffn1_super_scales.append(
|
||||
down_proj_expert_weight_scale_key.format(expert_idx))))
|
||||
up_gate_proj_super_scales.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_super_scales_key.format(expert_idx))))
|
||||
ffn2_super_scales.append(
|
||||
up_gate_proj_expert_super_scales_key.format(expert_idx))))
|
||||
down_proj_super_scales.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_super_scales_key.format(expert_idx))))
|
||||
ffn1_code_scale.append(
|
||||
down_proj_expert_super_scales_key.format(expert_idx))))
|
||||
up_gate_proj_code_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_code_scale_key.format(expert_idx))))
|
||||
ffn2_code_scale.append(
|
||||
up_gate_proj_expert_code_scale_key.format(expert_idx))))
|
||||
down_proj_code_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_code_scale_key.format(expert_idx))))
|
||||
ffn1_code_zp.append(
|
||||
down_proj_expert_code_scale_key.format(expert_idx))))
|
||||
up_gate_proj_code_zp.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_code_zp_key.format(expert_idx))))
|
||||
ffn2_code_zp.append(
|
||||
up_gate_proj_expert_code_zp_key.format(expert_idx))))
|
||||
down_proj_code_zp.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_code_zp_key.format(expert_idx))))
|
||||
down_proj_expert_code_zp_key.format(expert_idx))))
|
||||
|
||||
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
|
||||
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
|
||||
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
|
||||
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
|
||||
ffn1_super_scales = paddle.stack(ffn1_super_scales, axis=0)
|
||||
ffn2_super_scales = paddle.stack(ffn2_super_scales, axis=0)
|
||||
ffn1_code_scale = paddle.stack(ffn1_code_scale, axis=0)
|
||||
ffn2_code_scale = paddle.stack(ffn2_code_scale, axis=0)
|
||||
ffn1_code_zp = paddle.stack(ffn1_code_zp, axis=0)
|
||||
ffn2_code_zp = paddle.stack(ffn2_code_zp, axis=0)
|
||||
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
|
||||
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
|
||||
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
|
||||
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
|
||||
up_gate_proj_super_scales = paddle.stack(up_gate_proj_super_scales, axis=0)
|
||||
down_proj_super_scales = paddle.stack(down_proj_super_scales, axis=0)
|
||||
up_gate_proj_code_scale = paddle.stack(up_gate_proj_code_scale, axis=0)
|
||||
down_proj_code_scale = paddle.stack(down_proj_code_scale, axis=0)
|
||||
up_gate_proj_code_zp = paddle.stack(up_gate_proj_code_zp, axis=0)
|
||||
down_proj_code_zp = paddle.stack(down_proj_code_zp, axis=0)
|
||||
|
||||
name_tensor_map = {
|
||||
"moe_ffn1_weight": ffn1_weight,
|
||||
"moe_ffn2_weight": ffn2_weight,
|
||||
"moe_ffn1_weight_scale": ffn1_weight_scale,
|
||||
"moe_ffn2_weight_scale": ffn2_weight_scale,
|
||||
"moe_ffn1_super_scales": ffn1_super_scales,
|
||||
"moe_ffn2_super_scales": ffn2_super_scales,
|
||||
"moe_ffn1_code_scale": ffn1_code_scale,
|
||||
"moe_ffn2_code_scale": ffn2_code_scale,
|
||||
"moe_ffn1_code_zp": ffn1_code_zp,
|
||||
"moe_ffn2_code_zp": ffn2_code_zp
|
||||
"up_gate_proj_weight": up_gate_proj_weight,
|
||||
"down_proj_weight": down_proj_weight,
|
||||
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
|
||||
"down_proj_weight_scale": down_proj_weight_scale,
|
||||
"up_gate_proj_super_scales": up_gate_proj_super_scales,
|
||||
"down_proj_super_scales": down_proj_super_scales,
|
||||
"up_gate_proj_code_scale": up_gate_proj_code_scale,
|
||||
"down_proj_code_scale": down_proj_code_scale,
|
||||
"up_gate_proj_code_zp": up_gate_proj_code_zp,
|
||||
"down_proj_code_zp": down_proj_code_zp
|
||||
}
|
||||
for name, tensor in name_tensor_map.items():
|
||||
create_and_set_parameter(layer, name, tensor)
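The *_key entries read above come from layer.weight_key_map and are formatted with each expert index before being popped from the state dict. A hypothetical map for a few of the keys — the template strings are invented for illustration, only the dict keys mirror the code:

weight_key_map = {
    "up_gate_proj_expert_weight_key": "layers.0.mlp.experts.{}.up_gate_proj.quant_weight",
    "up_gate_proj_expert_weight_scale_key": "layers.0.mlp.experts.{}.up_gate_proj.weight_scale",
    "down_proj_expert_weight_key": "layers.0.mlp.experts.{}.down_proj.quant_weight",
    "down_proj_expert_weight_scale_key": "layers.0.mlp.experts.{}.down_proj.weight_scale",
}

expert_idx = 3  # layer.expert_id_offset + i
key = weight_key_map["up_gate_proj_expert_weight_key"].format(expert_idx)
# key == "layers.0.mlp.experts.3.up_gate_proj.quant_weight", the entry popped from state_dict
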
|
||||
@@ -200,7 +200,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
x,
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
|
||||
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale")
|
||||
else None), # if set, permute_input will be int8_t
|
||||
layer.top_k,
|
||||
False,
|
||||
@@ -210,17 +210,17 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn_wint2(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
layer.moe_ffn1_weight,
|
||||
layer.moe_ffn2_weight,
|
||||
layer.up_gate_proj_weight,
|
||||
layer.down_proj_weight,
|
||||
None,
|
||||
layer.moe_ffn1_super_scales,
|
||||
layer.moe_ffn2_super_scales,
|
||||
layer.moe_ffn1_weight_scale,
|
||||
layer.moe_ffn1_code_scale,
|
||||
layer.moe_ffn1_code_zp,
|
||||
layer.moe_ffn2_weight_scale,
|
||||
layer.moe_ffn2_code_scale,
|
||||
layer.moe_ffn2_code_zp,
|
||||
layer.up_gate_proj_super_scales,
|
||||
layer.down_proj_super_scales,
|
||||
layer.up_gate_proj_weight_scale,
|
||||
layer.up_gate_proj_code_scale,
|
||||
layer.up_gate_proj_code_zp,
|
||||
layer.down_proj_weight_scale,
|
||||
layer.down_proj_code_scale,
|
||||
layer.down_proj_code_zp,
|
||||
False,
|
||||
)
|
||||
|
||||
@@ -271,7 +271,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
|
||||
)
|
||||
|
||||
num_tokens, K = x.shape
|
||||
E, _, N = layer.moe_ffn1_weight.shape
|
||||
E, _, N = layer.up_gate_proj_weight.shape
|
||||
M = num_tokens
|
||||
|
||||
top_k = topk_ids.shape[1]
|
||||
@@ -308,12 +308,12 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
|
||||
|
||||
moe_wint2_ffn_kernel[grid](
|
||||
x,
|
||||
layer.moe_ffn1_weight,
|
||||
layer.up_gate_proj_weight,
|
||||
intermediate_cache1,
|
||||
layer.moe_ffn1_weight_scale,
|
||||
layer.moe_ffn1_super_scales,
|
||||
layer.moe_ffn1_code_scale,
|
||||
layer.moe_ffn1_code_zp,
|
||||
layer.up_gate_proj_weight_scale,
|
||||
layer.up_gate_proj_super_scales,
|
||||
layer.up_gate_proj_code_scale,
|
||||
layer.up_gate_proj_code_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@@ -321,7 +321,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
|
||||
num_valid_tokens,
|
||||
max_possible_num_post_padded,
|
||||
# Matrix dimensions
|
||||
N=layer.moe_ffn1_weight.shape[-1],
|
||||
N=layer.up_gate_proj_weight.shape[-1],
|
||||
K=x.shape[-1],
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
@@ -329,15 +329,15 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
|
||||
# (A has M rows).
|
||||
stride_am=x.strides[0],
|
||||
stride_ak=x.strides[1],
|
||||
stride_be=layer.moe_ffn1_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn1_weight.strides[1],
|
||||
stride_be=layer.up_gate_proj_weight.strides[0],
|
||||
stride_bk=layer.up_gate_proj_weight.strides[1],
|
||||
stride_bn=1,
|
||||
stride_cm=intermediate_cache1.strides[-2],
|
||||
stride_cn=1,
|
||||
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
|
||||
stride_bsk=layer.moe_ffn1_weight_scale.strides[1],
|
||||
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
|
||||
stride_bsk=layer.up_gate_proj_weight_scale.strides[1],
|
||||
stride_bsn=1,
|
||||
stride_bce=layer.moe_ffn1_code_scale.strides[0],
|
||||
stride_bce=layer.up_gate_proj_code_scale.strides[0],
|
||||
stride_bck=1,
|
||||
stride_bcn=1,
|
||||
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
|
||||
@@ -361,17 +361,17 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
|
||||
}
|
||||
|
||||
grid = (ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(layer.moe_ffn2_weight.shape[-1], config["BLOCK_SIZE_N"]), )
|
||||
ceil_div(layer.down_proj_weight.shape[-1], config["BLOCK_SIZE_N"]), )
|
||||
|
||||
|
||||
moe_wint2_ffn_kernel[grid](
|
||||
intermediate_cache2,
|
||||
layer.moe_ffn2_weight,
|
||||
layer.down_proj_weight,
|
||||
intermediate_cache3,
|
||||
layer.moe_ffn2_weight_scale,
|
||||
layer.moe_ffn2_super_scales,
|
||||
layer.moe_ffn2_code_scale,
|
||||
layer.moe_ffn2_code_zp,
|
||||
layer.down_proj_weight_scale,
|
||||
layer.down_proj_super_scales,
|
||||
layer.down_proj_code_scale,
|
||||
layer.down_proj_code_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@@ -379,7 +379,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
|
||||
num_valid_tokens,
|
||||
max_possible_num_post_padded,
|
||||
# Matrix dimensions
|
||||
N=layer.moe_ffn2_weight.shape[-1],
|
||||
N=layer.down_proj_weight.shape[-1],
|
||||
K=intermediate_cache2.shape[-1],
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
@@ -387,15 +387,15 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
|
||||
# (A has M rows).
|
||||
stride_am=intermediate_cache2.strides[0],
|
||||
stride_ak=1,
|
||||
stride_be=layer.moe_ffn2_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn2_weight.strides[1],
|
||||
stride_be=layer.down_proj_weight.strides[0],
|
||||
stride_bk=layer.down_proj_weight.strides[1],
|
||||
stride_bn=1,
|
||||
stride_cm=intermediate_cache3.strides[-2],
|
||||
stride_cn=1,
|
||||
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
|
||||
stride_bsk=layer.moe_ffn2_weight_scale.strides[1],
|
||||
stride_bse=layer.down_proj_weight_scale.strides[0],
|
||||
stride_bsk=layer.down_proj_weight_scale.strides[1],
|
||||
stride_bsn=1,
|
||||
stride_bce=layer.moe_ffn2_code_scale.strides[0],
|
||||
stride_bce=layer.down_proj_code_scale.strides[0],
|
||||
stride_bck=1,
|
||||
stride_bcn=1,
|
||||
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
|
||||
|
@@ -38,14 +38,14 @@ class XPUMoEMethod(MoEMethodBase):
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
# bf16
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
for weights in [ffn1_weights, ffn2_weights]:
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
for weights in [up_gate_proj_weights, down_proj_weights]:
|
||||
for idx, weight in enumerate(weights):
|
||||
weights[idx] = weight.transpose([1, 0])
|
||||
stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
|
||||
stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
|
||||
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
|
||||
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
|
||||
for idx, weight_tensor in enumerate(
|
||||
[stacked_ffn1_weights, stacked_ffn2_weights]):
|
||||
[stacked_up_gate_proj_weights, stacked_down_proj_weights]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
setattr(
|
||||
layer, weight_name,
|
||||
@@ -71,13 +71,13 @@ class XPUMoEMethod(MoEMethodBase):
|
||||
x,
|
||||
layer.gate_weight.transpose([1, 0]),
|
||||
layer.gate_correction_bias,
|
||||
layer.moe_ffn1_weight,
|
||||
layer.moe_ffn2_weight,
|
||||
None, # ffn1 bias
|
||||
None, # ffn2 bias
|
||||
None, # ffn1 scale
|
||||
None, # ffn2 scale
|
||||
None, # ffn1_in_scale
|
||||
layer.up_gate_proj_weight,
|
||||
layer.down_proj_weight,
|
||||
None, # up_gate_proj bias
|
||||
None, # down_proj bias
|
||||
None, # up_gate_proj scale
|
||||
None, # down_proj scale
|
||||
None, # up_gate_proj_in_scale
|
||||
"", # moe_quant_type
|
||||
layer.top_k,
|
||||
False, # moe group, used in deepseek
|
||||
@@ -129,20 +129,20 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
|
||||
"""
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert len(ffn1_weights) == layer.num_local_experts
|
||||
assert len(ffn2_weights) == layer.num_local_experts
|
||||
assert ffn1_weights[0].shape == [
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert len(up_gate_proj_weights) == layer.num_local_experts
|
||||
assert len(down_proj_weights) == layer.num_local_experts
|
||||
assert up_gate_proj_weights[0].shape == [
|
||||
layer.hidden_size, layer.moe_intermediate_size * 2
|
||||
]
|
||||
assert ffn2_weights[0].shape == [
|
||||
assert down_proj_weights[0].shape == [
|
||||
layer.moe_intermediate_size, layer.hidden_size
|
||||
]
|
||||
|
||||
added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
|
||||
added_scale_attrs = ["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]
|
||||
added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
|
||||
added_scale_attrs = ["up_gate_proj_weight_scale", "down_proj_weight_scale"]
|
||||
|
||||
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
|
||||
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
|
||||
weight_name = added_weight_attrs[idx]
|
||||
scale_name = added_scale_attrs[idx]
|
||||
|
||||
@@ -189,16 +189,16 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
|
||||
x,
|
||||
layer.gate_weight.transpose([1, 0]),
|
||||
layer.gate_correction_bias,
|
||||
layer.moe_ffn1_weight,
|
||||
layer.moe_ffn2_weight,
|
||||
None, # ffn1 bias
|
||||
None, # ffn2 bias
|
||||
(layer.moe_ffn1_weight_scale
|
||||
if hasattr(layer, "moe_ffn1_weight_scale") else None),
|
||||
(layer.moe_ffn2_weight_scale
|
||||
if hasattr(layer, "moe_ffn2_weight_scale") else None),
|
||||
(layer.moe_ffn2_in_scale
|
||||
if hasattr(layer, "moe_ffn2_in_scale") else None),
|
||||
layer.up_gate_proj_weight,
|
||||
layer.down_proj_weight,
|
||||
None, # up_gate_proj bias
|
||||
None, # down_proj bias
|
||||
(layer.up_gate_proj_weight_scale
|
||||
if hasattr(layer, "up_gate_proj_weight_scale") else None),
|
||||
(layer.down_proj_weight_scale
|
||||
if hasattr(layer, "down_proj_weight_scale") else None),
|
||||
(layer.down_proj_in_scale
|
||||
if hasattr(layer, "down_proj_in_scale") else None),
|
||||
self.moe_quant_type,
|
||||
layer.top_k,
|
||||
False, # moe group, used in deepseek
|
||||
|
@@ -145,13 +145,13 @@ class FusedMoE(nn.Layer):
|
||||
shape=gate_correction_bias_shape,
|
||||
dtype="float32",
|
||||
)
|
||||
ffn1_output_dim = self.moe_intermediate_size * 2
|
||||
up_gate_proj_output_dim = self.moe_intermediate_size * 2
|
||||
if self.moe_quant_type in ["fp8", "wint8"]:
|
||||
ffn1_weight_shape = [self.num_local_experts, ffn1_output_dim, self.hidden_size]
|
||||
ffn2_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size]
|
||||
up_gate_proj_weight_shape = [self.num_local_experts, up_gate_proj_output_dim, self.hidden_size]
|
||||
down_proj_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size]
|
||||
else:
|
||||
ffn1_weight_shape = [self.num_local_experts, self.hidden_size, ffn1_output_dim]
|
||||
ffn2_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size]
|
||||
up_gate_proj_weight_shape = [self.num_local_experts, self.hidden_size, up_gate_proj_output_dim]
|
||||
down_proj_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size]
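A worked example of the two layouts above, with illustrative sizes (num_local_experts=8, hidden_size=4096, moe_intermediate_size=1536):

num_local_experts, hidden_size, moe_intermediate_size = 8, 4096, 1536
up_gate_proj_output_dim = moe_intermediate_size * 2      # 3072

# fp8 / wint8: output dimension comes first
up_gate_proj_weight_shape = [num_local_experts, up_gate_proj_output_dim, hidden_size]  # [8, 3072, 4096]
down_proj_weight_shape = [num_local_experts, hidden_size, moe_intermediate_size]       # [8, 4096, 1536]

# other quant types: input dimension comes first (transposed layout)
up_gate_proj_weight_shape = [num_local_experts, hidden_size, up_gate_proj_output_dim]  # [8, 4096, 3072]
down_proj_weight_shape = [num_local_experts, moe_intermediate_size, hidden_size]       # [8, 1536, 4096]
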
|
||||
|
||||
# Create parameters
|
||||
if self.moe_quant_type == "fp8":
|
||||
@@ -161,15 +161,15 @@ class FusedMoE(nn.Layer):
|
||||
self.weight_dtype = "int8"
|
||||
self.init_weight_only_scale()
|
||||
|
||||
# FFN1 parameters
|
||||
self.moe_ffn1_weight = self.create_parameter(
|
||||
shape=ffn1_weight_shape,
|
||||
# up_gate_proj parameters
|
||||
self.up_gate_proj_weight = self.create_parameter(
|
||||
shape=up_gate_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
# FFN2 parameters
|
||||
self.moe_ffn2_weight = self.create_parameter(
|
||||
shape=ffn2_weight_shape,
|
||||
# down_proj parameters
|
||||
self.down_proj_weight = self.create_parameter(
|
||||
shape=down_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
@@ -178,44 +178,44 @@ class FusedMoE(nn.Layer):
|
||||
"""
|
||||
Initialize the weight scale.
|
||||
"""
|
||||
self.moe_ffn1_weight_scale = self.create_parameter(
|
||||
self.up_gate_proj_weight_scale = self.create_parameter(
|
||||
shape=[self.num_local_experts, self.moe_intermediate_size * 2],
|
||||
dtype=self._dtype,
|
||||
)
|
||||
self.moe_ffn2_weight_scale = self.create_parameter(
|
||||
self.down_proj_weight_scale = self.create_parameter(
|
||||
shape=[self.num_local_experts, self.hidden_size],
|
||||
dtype=self._dtype,
|
||||
)
|
||||
|
||||
def load_experts_weight(self, state_dict: dict,
|
||||
ffn1_expert_weight_key: str,
|
||||
ffn2_expert_weight_key: str):
|
||||
up_gate_proj_expert_weight_key: str,
|
||||
down_proj_expert_weight_key: str):
|
||||
"""
|
||||
Load experts weight from state_dict.
|
||||
Args:
|
||||
state_dict (dict): The state_dict of model.
|
||||
ffn1_expert_weight_key (str): The key of ffn1 expert weight.
|
||||
ffn2_expert_weight_key (str): The key of ffn2 expert weight.
|
||||
up_gate_proj_expert_weight_key (str): The key of up_gate_proj expert weight.
|
||||
down_proj_expert_weight_key (str): The key of down_proj expert weight.
|
||||
"""
|
||||
ffn1_weights = []
|
||||
ffn2_weights = []
|
||||
is_ffn_merged = ffn1_expert_weight_key.format(
|
||||
up_gate_proj_weights = []
|
||||
down_proj_weights = []
|
||||
is_ffn_merged = up_gate_proj_expert_weight_key.format(
|
||||
self.expert_id_offset) in state_dict
|
||||
if is_ffn_merged:
|
||||
for i in range(self.num_local_experts):
|
||||
expert_idx = self.expert_id_offset + i
|
||||
ffn1_weights.append(
|
||||
up_gate_proj_weights.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_weight_key.format(expert_idx))))
|
||||
ffn2_weights.append(
|
||||
up_gate_proj_expert_weight_key.format(expert_idx))))
|
||||
down_proj_weights.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_weight_key.format(expert_idx))))
|
||||
down_proj_expert_weight_key.format(expert_idx))))
|
||||
else:
|
||||
gate_expert_weight_key = ffn1_expert_weight_key.replace(
|
||||
gate_expert_weight_key = up_gate_proj_expert_weight_key.replace(
|
||||
"up_gate_proj", "gate_proj")
|
||||
up_expert_weight_key = ffn1_expert_weight_key.replace(
|
||||
up_expert_weight_key = up_gate_proj_expert_weight_key.replace(
|
||||
"up_gate_proj", "up_proj")
|
||||
for j in range(self.num_local_experts):
|
||||
expert_idx = self.expert_id_offset + j
|
||||
@@ -223,12 +223,12 @@ class FusedMoE(nn.Layer):
|
||||
state_dict.pop(gate_expert_weight_key.format(expert_idx)))
|
||||
up = get_tensor(
|
||||
state_dict.pop(up_expert_weight_key.format(expert_idx)))
|
||||
ffn1_weights.append(paddle.concat([gate, up], axis=-1))
|
||||
ffn2_weights.append(
|
||||
up_gate_proj_weights.append(paddle.concat([gate, up], axis=-1))
|
||||
down_proj_weights.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_weight_key.format(expert_idx))))
|
||||
return ffn1_weights, ffn2_weights
|
||||
down_proj_expert_weight_key.format(expert_idx))))
|
||||
return up_gate_proj_weights, down_proj_weights
|
||||
|
||||
def extract_moe_ffn_weights(self, state_dict: dict):
|
||||
"""
|
||||
@@ -239,30 +239,30 @@ class FusedMoE(nn.Layer):
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing two lists:
|
||||
- ffn1_weights: List of tensors for first FFN layer weights
|
||||
- ffn2_weights: List of tensors for second FFN layer weights
|
||||
- up_gate_proj_weights: List of tensors for first FFN layer weights
|
||||
- down_proj_weights: List of tensors for second FFN layer weights
|
||||
|
||||
Raises:
|
||||
AssertionError: If required weight keys are missing or number of weights
|
||||
doesn't match number of local experts.
|
||||
"""
|
||||
ffn1_expert_weight_key = self.weight_key_map.get(
|
||||
"ffn1_expert_weight_key", None)
|
||||
ffn2_expert_weight_key = self.weight_key_map.get(
|
||||
"ffn2_expert_weight_key", None)
|
||||
assert ffn1_expert_weight_key is not None, "ffn1_expert_weight_key should not be none."
|
||||
assert ffn2_expert_weight_key is not None, "ffn2_expert_weight_key should not be none."
|
||||
up_gate_proj_expert_weight_key = self.weight_key_map.get(
|
||||
"up_gate_proj_expert_weight_key", None)
|
||||
down_proj_expert_weight_key = self.weight_key_map.get(
|
||||
"down_proj_expert_weight_key", None)
|
||||
assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
|
||||
assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
|
||||
|
||||
ffn1_weights, ffn2_weights = self.load_experts_weight(
|
||||
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
|
||||
up_gate_proj_weights, down_proj_weights = self.load_experts_weight(
|
||||
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
|
||||
assert len(
|
||||
ffn1_weights
|
||||
) == self.num_local_experts, "ffn1_weights length should be equal to num_local_experts."
|
||||
up_gate_proj_weights
|
||||
) == self.num_local_experts, "up_gate_proj_weights length should be equal to num_local_experts."
|
||||
assert len(
|
||||
ffn2_weights
|
||||
) == self.num_local_experts, "ffn2_weights length should be equal to num_local_experts."
|
||||
down_proj_weights
|
||||
) == self.num_local_experts, "down_proj_weights length should be equal to num_local_experts."
|
||||
|
||||
return ffn1_weights, ffn2_weights
|
||||
return up_gate_proj_weights, down_proj_weights
|
||||
|
||||
def extract_gate_correction_bias(self, gate_correction_bias_key,
|
||||
state_dict):
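In the un-merged branch of load_experts_weight above, the gate_proj and up_proj checkpoint keys are derived from the up_gate_proj template and the two tensors are concatenated along the last axis. A stand-alone sketch of that step, with an assumed key template and shapes:

import paddle

up_gate_proj_expert_weight_key = "experts.{}.up_gate_proj.weight"  # hypothetical template
gate_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "gate_proj")
up_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "up_proj")

hidden_size, moe_intermediate_size = 64, 128
gate = paddle.randn([hidden_size, moe_intermediate_size])  # stands in for state_dict[gate_key.format(i)]
up = paddle.randn([hidden_size, moe_intermediate_size])    # stands in for state_dict[up_key.format(i)]
up_gate_proj = paddle.concat([gate, up], axis=-1)          # [hidden_size, moe_intermediate_size * 2]
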
@@ -46,11 +46,11 @@ class ParallelEHProjection(nn.Layer):
prefix (str): full name of the layer in the state dict
"""
super(ParallelEHProjection, self).__init__()
self.linear_weight_key = prefix + ".weight"
self.weight_key = prefix + ".weight"
if with_bias:
self.linear_bias_key = prefix + ".bias"
self.bias_key = prefix + ".bias"
else:
self.linear_bias_key = None
self.bias_key = None
self.use_ep = fd_config.parallel_config.use_ep
self.column_cut = True

@@ -66,26 +66,26 @@ class ParallelEHProjection(nn.Layer):
else:
if self.column_cut:
need_gather = True
self.out_linear = ColumnParallelLinear(
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False,  # False gives a smaller numerical diff
)
else:
self.out_linear = RowParallelLinear(
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False,  # False gives a smaller numerical diff
)
@@ -100,20 +100,20 @@ class ParallelEHProjection(nn.Layer):

if self.use_ep:
self.weight.set_value(
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
get_tensor(state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype()))
else:
weight_tensor = get_tensor(
state_dict.pop(self.linear_weight_key)).astype(
state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype())
if self.out_linear.weight.shape != weight_tensor.shape:
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.out_linear.weight.set_value(weight_tensor)
self.linear.weight.set_value(weight_tensor)

if self.linear_bias_key is not None:
bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype(
if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(
paddle.get_default_dtype())
self.out_linear.bias.set_value(bias)
self.linear.bias.set_value(bias)

def forward(self, input):
"""
@@ -129,5 +129,5 @@ class ParallelEHProjection(nn.Layer):
if self.use_ep:
logits = paddle.matmul(logits, self.weight)
else:
logits = self.out_linear(logits)
logits = self.linear(logits)
return logits

@@ -43,7 +43,7 @@ class RMSNorm(nn.Layer):
hidden_size: int,
eps: float = 1e-5,
prefix: str = "",
linear_bias: paddle.Tensor = None,
bias: paddle.Tensor = None,
quant_scale: float = None,
begin_norm_axis: int = 1,
) -> None:
@@ -57,7 +57,7 @@ class RMSNorm(nn.Layer):
hidden_size (int) : size of hidden state.
eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
prefix(str,optional):The name of current layer. Defaults to "".
linear_bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None.
bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None.
quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
begin_norm_axis (int, optional): The axis along which to perform normalization. Defaults to 1.
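For reference, an unfused formulation of what the fused rms-norm path computes from these arguments; residual/bias handling and the quantization hooks are simplified, so treat it as a sketch rather than the exact fused-op semantics:

import paddle

def rms_norm_reference(x, weight, eps=1e-5, residual=None, bias=None):
    if residual is not None:   # fused_add_rms_norm folds this residual add into the op
        x = x + residual
    if bias is not None:       # bias of the preceding linear layer (see the docstring above)
        x = x + bias
    variance = (x * x).mean(axis=-1, keepdim=True)
    return x * paddle.rsqrt(variance + eps) * weight

y = rms_norm_reference(paddle.randn([2, 8]), paddle.ones([8]))
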
|
||||
|
||||
@@ -78,7 +78,7 @@ class RMSNorm(nn.Layer):
|
||||
self.norm_func: Callable = fused_add_rms_norm
|
||||
else:
|
||||
self.norm_func: Callable = fused_rms_norm
|
||||
self.linear_bias: Optional[paddle.Tensor] = linear_bias
|
||||
self.bias: Optional[paddle.Tensor] = bias
|
||||
self.quant_scale: Optional[float] = quant_scale
|
||||
self._dtype: str = self._helper.get_default_dtype()
|
||||
self._norm_weight_dtype: str = self._dtype
|
||||
@@ -94,9 +94,9 @@ class RMSNorm(nn.Layer):
|
||||
Initialize the weights and biases.
|
||||
"""
|
||||
|
||||
self.ln_weight = None
|
||||
self.weight = None
|
||||
if self.with_weight:
|
||||
self.ln_weight = self.create_parameter(
|
||||
self.weight = self.create_parameter(
|
||||
shape=[self.hidden_size],
|
||||
default_initializer=nn.initializer.Constant(value=1.0),
|
||||
dtype=self._norm_weight_dtype,
|
||||
@@ -115,7 +115,7 @@ class RMSNorm(nn.Layer):
|
||||
weight_tensor = paddle.cast(
|
||||
get_tensor(state_dict.pop(self.weight_key)),
|
||||
self._norm_weight_dtype)
|
||||
self.ln_weight.set_value(weight_tensor)
|
||||
self.weight.set_value(weight_tensor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -139,18 +139,18 @@ class RMSNorm(nn.Layer):
|
||||
"""
|
||||
if current_platform.is_gcu():
|
||||
if residual_input is None:
|
||||
return rms_norm(x, self.ln_weight, self.eps)
|
||||
return rms_norm(x, self.weight, self.eps)
|
||||
norm_out = self.norm_func(
|
||||
x, residual_input, self.ln_weight, self.eps
|
||||
x, residual_input, self.weight, self.eps
|
||||
)
|
||||
else:
|
||||
norm_out = self.norm_func(
|
||||
x,
|
||||
norm_weight=self.ln_weight,
|
||||
norm_weight=self.weight,
|
||||
norm_bias=None,
|
||||
epsilon=self.eps,
|
||||
begin_norm_axis=self.begin_norm_axis,
|
||||
bias=self.linear_bias,
|
||||
bias=self.bias,
|
||||
residual=residual_input,
|
||||
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
|
||||
quant_round_type=self.quant_round_type,
|
||||
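Both the fused_rms_norm and fused_add_rms_norm paths above compute standard RMSNorm; a rough unfused reference of the same semantics (rms_norm_reference is an illustrative helper, not the fused kernel):

import paddle

def rms_norm_reference(x, weight, bias=None, residual=None, eps=1e-5):
    if residual is not None:
        x = x + residual  # the fused_add_rms_norm variant folds this add in
    var = paddle.mean(paddle.square(x.astype("float32")), axis=-1, keepdim=True)
    out = (x.astype("float32") * paddle.rsqrt(var + eps)).astype(x.dtype) * weight
    return out + bias if bias is not None else out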
@@ -174,7 +174,7 @@ class LayerNorm(nn.Layer):
hidden_size: int,
eps: float = 1e-5,
prefix="",
linear_bias: paddle.Tensor = None,
bias: paddle.Tensor = None,
quant_scale: float = None,
with_bias: bool = False,
):
@@ -189,7 +189,7 @@ class LayerNorm(nn.Layer):
eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
prefix (str): Unique name of the layer, used for naming internal attributes,
you can give it any name you like.
linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
with_bias (bool):Whether to include bias or not. Defaults to False.
Raises:
@@ -212,7 +212,7 @@ class LayerNorm(nn.Layer):
self.norm_func: Callable = paddle.nn.functional.layer_norm
else:
self.norm_func: Callable = fused_layer_norm
self.linear_bias: Optional[paddle.Tensor] = linear_bias
self.bias: Optional[paddle.Tensor] = bias
self._dtype: str = self._helper.get_default_dtype()
self._norm_weight_dtype: str = "float32"

@@ -227,16 +227,16 @@ class LayerNorm(nn.Layer):
Initialize the weights and biases.
"""

self.ln_weight = None
self.weight = None
if self.with_weight:
self.ln_weight = self.create_parameter(
self.weight = self.create_parameter(
shape=[self.hidden_size],
default_initializer=nn.initializer.Constant(value=1.0),
dtype=self._norm_weight_dtype,
)
self.ln_bias = None
self.bias = None
if self.with_bias:
self.ln_bias = self.create_parameter(
self.bias = self.create_parameter(
shape=[self.hidden_size],
is_bias=True,
dtype=self._norm_weight_dtype,
@@ -255,14 +255,14 @@ class LayerNorm(nn.Layer):
weight_tensor = paddle.cast(
get_tensor(state_dict.pop(self.weight_key)),
self._norm_weight_dtype)
self.ln_weight.set_value(weight_tensor)
self.weight.set_value(weight_tensor)

# bias
if self.with_bias:
bias_tensor = paddle.cast(
get_tensor(state_dict.pop(self.bias_key)),
self._norm_weight_dtype)
self.ln_bias.set_value(bias_tensor)
self.bias.set_value(bias_tensor)

def forward(
self,
@@ -285,10 +285,10 @@ class LayerNorm(nn.Layer):
operations (like linear transformation) on the `residual_input`.
"""
if current_platform.is_iluvatar():
if self.ln_weight is None and self.ln_bias is None:
if self.weight is None and self.bias is None:
out = x
if self.linear_bias is not None:
out += self.linear_bias
if self.bias is not None:
out += self.bias
if residual_input is not None:
out += residual_input
return out, out
@@ -303,8 +303,8 @@ class LayerNorm(nn.Layer):
out = self.norm_func(
x=y,
normalized_shape=y.shape[1:],
weight=self.ln_weight,
bias=self.linear_bias,
weight=self.weight,
bias=self.bias,
epsilon=self.eps,
)
return out, y
@@ -312,19 +312,19 @@ class LayerNorm(nn.Layer):
out = self.norm_func(
x=x,
normalized_shape=x.shape[1:],
weight=self.ln_weight,
bias=self.linear_bias,
weight=self.weight,
bias=self.bias,
epsilon=self.eps,
)
return out
else:
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
norm_bias=self.ln_bias,
norm_weight=self.weight,
norm_bias=self.bias,
epsilon=self.eps,
begin_norm_axis=1,
bias=self.linear_bias,
bias=self.bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,
@@ -78,8 +78,8 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
self.quant_config = quant_config

def create_weights(self, layer):
layer.linear_weight_shape.reverse()
layer.linear_weight_scale = layer.create_parameter(
layer.weight_shape.reverse()
layer.weight_scale = layer.create_parameter(
shape=[
(layer.output_size + self.quant_config.weight_block_size[0] -
1) // self.quant_config.weight_block_size[0],
@@ -95,8 +95,8 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
weight_tensor = weights.transpose([1, 0])
quanted_weight_tensor, weight_block_scale_tensor = (
per_block_cast_to_fp8(weight_tensor))
layer.linear_weight.copy_(quanted_weight_tensor, False)
layer.linear_weight_scale.set_value(weight_block_scale_tensor)
layer.weight.copy_(quanted_weight_tensor, False)
layer.weight_scale.set_value(weight_block_scale_tensor)

def process_prequanted_weights(self, layer, state_dict):
"""
@@ -106,10 +106,10 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))

quant_weight = quant_weight.transpose([1, 0]).contiguous()
layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)
layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False)

weight_scale = weight_scale.transpose([1, 0])
layer.linear_weight_scale.set_value(weight_scale)
layer.weight_scale.set_value(weight_scale)

def apply(self, layer, x):
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
@@ -119,9 +119,9 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
deep_gemm.gemm_fp8_fp8_bf16_nt(
(x, x_scale_tensor),
(layer.linear_weight, layer.linear_weight_scale),
(layer.weight, layer.weight_scale),
linear_out,
)
if layer.with_bias:
linear_out = paddle.add(linear_out, layer.linear_bias)
linear_out = paddle.add(linear_out, layer.bias)
return linear_out
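Block-wise FP8 keeps one scale per (block x block) tile of the transposed weight, which is why the scale parameter above has ceil(output_size / weight_block_size[0]) rows. A rough numpy sketch of how such scales are derived; the block size of 128 and the helper name are illustrative, and the real per_block_cast_to_fp8 additionally pads the weight and emits a true float8_e4m3fn tensor:

import numpy as np

def per_block_scales(weight, block=128, fp8_max=448.0):
    out_dim, in_dim = weight.shape
    rows = (out_dim + block - 1) // block
    cols = (in_dim + block - 1) // block
    scales = np.empty((rows, cols), dtype=np.float32)
    for i in range(rows):
        for j in range(cols):
            tile = weight[i * block:(i + 1) * block, j * block:(j + 1) * block]
            scales[i, j] = max(np.abs(tile).max() / fp8_max, 1e-12)  # one scale per tile
    # dequantization: tile ~= fp8_tile.astype(np.float32) * scales[i, j]
    return scales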
@@ -96,7 +96,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
act_scale = get_tensor(state_dict.pop(layer.act_scale_key))

quant_weight = quant_weight.transpose([1, 0]).contiguous()
layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)
layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False)

self.act_scale = act_scale.item()
self.total_scale = (act_scale * weight_scale).item()
@@ -118,7 +118,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):

linear_out = cutlass_fp8_fp8_half_gemm_fused(
fp8_x,
layer.linear_weight,
layer.weight,
transpose_x=False,
transpose_y=True,
bias=None,
@@ -63,8 +63,8 @@ class W4AFP8LinearMethod(QuantMethodBase):
self.quant_config = quant_config

def create_weights(self, layer):
layer.linear_weight_shape.reverse()
layer.linear_weight_shape[0] //= 2
layer.weight_shape.reverse()
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
pass

@@ -77,16 +77,16 @@ class W4AFP8LinearMethod(QuantMethodBase):
scale_dtype="float16",
))
weight_scale_tensor = paddle.view(weight_scale_tensor, layer._dtype)
layer.linear_weight.set_value(quanted_weight_tensor)
layer.linear_weight_scale.set_value(weight_scale_tensor)
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor)

def apply(self, layer, x):
linear_out = fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16(
x,
layer.linear_weight,
layer.linear_weight_scale,
layer.weight,
layer.weight_scale,
zero_points=None,
bias=layer.linear_bias if layer.add_bias else None,
bias=layer.bias if layer.add_bias else None,
out_scale=self.quant_config.weight_scale_dict.get(layer.prefix +
".weight_scale")
/ (self.quant_config.act_scale_dict.get(layer.prefix +
@@ -69,7 +69,7 @@ class W8A8LinearMethod(QuantMethodBase):
self.smooth_quant_method = SmoothQuantLinearMethod(quant_config)

def create_weights(self, layer):
layer.linear_weight_shape.reverse()
layer.weight_shape.reverse()
layer.weight_dtype = "int8"
if self.quant_config.use_smooth_quant:
self.smooth_quant_method.create_weights(layer)
@@ -101,21 +101,21 @@ class W8A8LinearMethod(QuantMethodBase):
if self.skip_quant:
logger.debug(f"{layer.prefix} skip quant")
weight_tensor = weights.cast(layer._dtype)
layer.linear_weight.set_value(weight_tensor)
layer.weight.set_value(weight_tensor)
else:
weight_tensor = weights.transpose([1, 0])
weight_tensor = paddle.cast(weight_tensor, "int8")
layer.linear_weight.set_value(weight_tensor)
layer.weight.set_value(weight_tensor)

def apply(self, layer, x):
if self.skip_quant:
linear_out = paddle.matmul(x, layer.linear_weight, False, True)
linear_out = paddle.matmul(x, layer.weight, False, True)
return linear_out
if self.quant_config.use_gemm_dequant:
linear_out = fastdeploy.model_executor.ops.gpu.gemm_dequant(
x, layer.linear_weight, layer.linear_out_scale, layer._dtype)
x, layer.weight, layer.linear_out_scale, layer._dtype)
else:
linear_out = paddle.matmul(x, layer.linear_weight, False, True)
linear_out = paddle.matmul(x, layer.weight, False, True)
linear_out = fastdeploy.model_executor.ops.gpu.dequant_int8(
linear_out, layer.linear_out_scale, layer._dtype)
return linear_out
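In the W8A8 path above the matmul runs in int8 and the result is rescaled back to the layer dtype, either fused (gemm_dequant) or as a separate dequant_int8 step. A hedged sketch of that flow with illustrative names; the real kernels accumulate in int32, float32 is used here only to keep the sketch runnable with plain paddle.matmul:

import paddle

def w8a8_matmul_reference(x_int8, weight_int8, out_scale, out_dtype="bfloat16"):
    # weight is stored as [out_dim, in_dim], hence transpose_y=True as in apply()
    acc = paddle.matmul(x_int8.astype("float32"), weight_int8.astype("float32"), False, True)
    # rescale back to floating point, as gemm_dequant / dequant_int8 do with linear_out_scale
    return (acc * out_scale.astype("float32")).astype(out_dtype)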
@@ -77,12 +77,12 @@ class WeightOnlyConfig(QuantConfigBase):
return GCUWeightOnlyLinearMethod(self)
elif current_platform.is_dcu():
if isinstance(layer, FusedMoE):
from fastdeploy.model_executor.layers.backends import (
DCUTritonWeightOnlyMoEMethod)
from fastdeploy.model_executor.layers.backends import \
DCUTritonWeightOnlyMoEMethod
return DCUTritonWeightOnlyMoEMethod(self)
else:
from fastdeploy.model_executor.layers.backends import (
DCUWeightOnlyLinearMethod)
from fastdeploy.model_executor.layers.backends import \
DCUWeightOnlyLinearMethod
return DCUWeightOnlyLinearMethod(self)
else:
if isinstance(layer, FusedMoE):
@@ -152,14 +152,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
def create_weights(self, layer):

# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
weight_scale_shape = [layer.weight_shape[1]]

layer.linear_weight_shape.reverse()
layer.weight_shape.reverse()
if self.quant_config.name() == "wint4":
layer.linear_weight_shape[0] //= 2
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype=layer._dtype,
is_bias=False,
)
@@ -171,9 +171,9 @@ class WeightOnlyLinearMethod(QuantMethodBase):
def apply(self, layer, x):
linear_out = weight_only_linear(
x,
weight=layer.linear_weight,
bias=layer.linear_bias if layer.add_bias else None,
weight_scale=layer.linear_weight_scale,
weight=layer.weight,
bias=layer.bias if layer.add_bias else None,
weight_scale=layer.weight_scale,
weight_dtype="int8"
if self.quant_config.name() == "wint8" else "int4",
arch=self.quant_config.weight_only_linear_arch,
@@ -204,8 +204,8 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
layer.linear_weight.set_value(quant_weight)
layer.linear_weight_scale.set_value(
layer.weight.set_value(quant_weight)
layer.weight_scale.set_value(
weight_scale.astype(paddle.get_default_dtype()))

def process_loaded_weights(self, layer, weight) -> None:
@@ -216,6 +216,6 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
arch=self.quant_config.weight_only_linear_arch,
)

layer.linear_weight.set_value(quanted_weight_tensor)
layer.linear_weight_scale.set_value(
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(
weight_scale_tensor.astype(paddle.get_default_dtype()))
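The create_weights comment above ("scale shape should be equal to the output dim") describes per-channel weight-only quantization: one scale per output channel of the weight. A hedged numpy sketch of the wint8 case; paddle's weight_quantize / weight_only_linear implement the real packing and layout:

import numpy as np

def quantize_wint8_per_channel(weight):
    # weight: [in_dim, out_dim]; one scale per output channel (column)
    scales = np.maximum(np.abs(weight).max(axis=0) / 127.0, 1e-8)
    qweight = np.clip(np.round(weight / scales), -127, 127).astype(np.int8)
    return qweight, scales.astype(np.float32)  # dequant: weight ~= qweight * scales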
@@ -70,11 +70,11 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
def create_weights(self, layer):
"""
"""
layer.linear_weight_shape.reverse()
layer.weight_shape.reverse()
layer.weight_dtype = "float8_e4m3fn"
# TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
self.skip_quant = False
layer.linear_weight_scale = layer.create_parameter(
layer.weight_scale = layer.create_parameter(
shape=[1],
dtype="float32",
is_bias=False,
@@ -86,7 +86,7 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
"""
if self.skip_quant:
weight_tensor = weights.cast(layer._dtype)
layer.linear_weight.set_value(weight_tensor)
layer.weight.set_value(weight_tensor)
return
if weights.dtype != paddle.float8_e4m3fn:
self.use_per_token_if_dynamic = True
@@ -95,22 +95,22 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
weight_tensor,
use_per_token_if_dynamic=False,
)
layer.linear_weight.copy_(qweight, False)
layer.linear_weight_scale.set_value(weight_scale)
layer.weight.copy_(qweight, False)
layer.weight_scale.set_value(weight_scale)

def apply(self, layer, x):
"""
"""
if self.skip_quant:
linear_out = paddle.matmul(x, layer.linear_weight, False, True)
linear_out = paddle.matmul(x, layer.weight, False, True)
return linear_out
if self.use_per_token_if_dynamic:
out_type = x.dtype
a_q, a_scales = scaled_fp8_quant(
x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)
linear_out = cutlass_scaled_mm(a_q, layer.linear_weight, a_scales,
layer.linear_weight_scale, out_type,
layer.linear_bias)
linear_out = cutlass_scaled_mm(a_q, layer.weight, a_scales,
layer.weight_scale, out_type,
layer.bias)
else:
raise NotImplementedError
return linear_out
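WFP8AFP8 above quantizes the weight once with a single tensor-wide scale (the shape-[1] parameter from create_weights) and quantizes activations on the fly, per token when use_per_token_if_dynamic is set; cutlass_scaled_mm then folds both scales back into the output. A rough numpy sketch of that scaling relationship, with illustrative names:

import numpy as np

def scaled_fp8_matmul_reference(x_q, w_q, x_scales, w_scale, bias=None):
    # x_q: [tokens, in_dim] quantized activations; w_q: [out_dim, in_dim] quantized weight
    # x_scales: scalar or [tokens, 1] column; w_scale: scalar
    out = (x_q.astype(np.float32) @ w_q.astype(np.float32).T) * x_scales * w_scale
    return out + bias if bias is not None else out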