Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 16:22:57 +08:00)
[Bug fix] fix compile bug when sm < 89 (#2738)
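For context: FP8 tensor-core kernels require compute capability 8.9 (Ada) or newer, which is why builds targeting sm < 89 fail to compile the FP8 paths. The fix itself lives in the CUDA build guards rather than in the Python shown in the diff below; the following is only a sketch of the kind of runtime check that can gate an FP8 path, with the helper name fp8_supported being hypothetical:

    import paddle

    def fp8_supported() -> bool:
        # FP8 (float8_e4m3) tensor cores first appear on SM 8.9 (Ada).
        major, minor = paddle.device.cuda.get_device_capability()
        return (major, minor) >= (8, 9)

    # Sketch: pick a non-FP8 fallback on pre-Ada GPUs.
    method_cls = TensorWiseFP8MoEMethod if fp8_supported() else TritonWeightOnlyMoEMethod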
@@ -129,18 +129,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             True,  # apply_norm_weight,
             False,
         )
-        intermediate_cache1 = paddle.empty(
+        ffn1_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
         )
-        intermediate_cache2 = paddle.empty(
-            (token_num * top_k, moe_intermediate_size),
-            dtype=x.dtype,
-        )
-        intermediate_cache3 = paddle.empty(
-            (token_num * top_k, hidden_size),
-            dtype=x.dtype,
-        )
 
         config = {
             "BLOCK_SIZE_M": 32,
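The hunk above keeps a single eagerly allocated buffer, now named ffn1_out for its role, and drops the up-front allocation of the two later caches. Note the factor of 2 in the second dimension: ffn1 packs the gate and up projections into one GEMM. A quick shape check under toy sizes (all numbers here are illustrative assumptions):

    import paddle

    token_num, top_k = 8, 2
    moe_intermediate_size = 1024

    # Gate and up projections packed together -> factor of 2.
    ffn1_out = paddle.empty([token_num * top_k, moe_intermediate_size * 2],
                            dtype="bfloat16")
    print(ffn1_out.shape)  # [16, 2048]; contents stay uninitialized until
                           # the Triton kernel writes them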
@@ -158,7 +150,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         fused_moe_kernel_paddle[grid](
             x,
             layer.moe_ffn1_weight,
-            intermediate_cache1,
+            ffn1_out,
             None,
             layer.moe_ffn1_weight_scale,
             None,
@@ -174,8 +166,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             stride_be=layer.moe_ffn1_weight.strides[0],
             stride_bk=layer.moe_ffn1_weight.strides[1],
             stride_bn=layer.moe_ffn1_weight.strides[2],
-            stride_cm=intermediate_cache1.strides[0],
-            stride_cn=intermediate_cache1.strides[1],
+            stride_cm=ffn1_out.strides[0],
+            stride_cn=ffn1_out.strides[1],
             #
             stride_asm=-1,
             stride_ask=-1,
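The stride_* arguments describe, in elements, how far apart rows and columns sit in memory so the Triton kernel can address each tensor; renaming the buffer means the strides must now be read from ffn1_out as well. A minimal illustration of the convention, assuming a contiguous row-major tensor:

    import paddle

    a = paddle.empty([16, 2048], dtype="bfloat16")
    # For a contiguous 2-D tensor, advancing one row skips a full row of
    # elements and advancing one column skips a single element.
    print(a.strides)  # expected: [2048, 1]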
@@ -197,16 +189,21 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
         )
 
-        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
-            intermediate_cache1)
+        ffn2_input = paddle.incubate.nn.functional.swiglu(
+            ffn1_out)
+
+        ffn2_out = paddle.empty(
+            (token_num * top_k, hidden_size),
+            dtype=x.dtype,
+        )
 
         grid = (
             ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
             ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
         fused_moe_kernel_paddle[grid](
-            intermediate_cache2,
+            ffn2_input,
             layer.moe_ffn2_weight,
-            intermediate_cache3,
+            ffn2_out,
             None,
             layer.moe_ffn2_weight_scale,
             topk_weights,
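Both kernel launches size a 1-D grid as the product of tile counts along M and N, so each (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile maps to one program instance. The arithmetic, with assumed illustrative sizes:

    def ceil_div(a: int, b: int) -> int:
        # Round up to whole tiles, same as the launch code's helper.
        return (a + b - 1) // b

    max_possible_num_post_padded, hidden_size = 4096, 512  # assumed values
    BLOCK_SIZE_M, BLOCK_SIZE_N = 32, 64

    grid = (ceil_div(max_possible_num_post_padded, BLOCK_SIZE_M)
            * ceil_div(hidden_size, BLOCK_SIZE_N), )
    print(grid)  # (1024,) = 128 M-tiles * 8 N-tiles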
@@ -217,13 +214,13 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             token_num * top_k,
             N=hidden_size,
             K=moe_intermediate_size,
-            stride_am=intermediate_cache2.strides[0],
-            stride_ak=intermediate_cache2.strides[1],
+            stride_am=ffn2_input.strides[0],
+            stride_ak=ffn2_input.strides[1],
             stride_be=layer.moe_ffn2_weight.strides[0],
             stride_bk=layer.moe_ffn2_weight.strides[1],
             stride_bn=layer.moe_ffn2_weight.strides[2],
-            stride_cm=intermediate_cache3.strides[0],
-            stride_cn=intermediate_cache3.strides[1],
+            stride_cm=ffn2_out.strides[0],
+            stride_cn=ffn2_out.strides[1],
             stride_asm=-1,
             stride_ask=-1,
             stride_bse=layer.moe_ffn2_weight_scale.strides[0],
@@ -244,8 +241,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
         )
 
-        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
-        out = intermediate_cache3.sum(axis=1)
+        ffn2_out.reshape_([token_num, top_k, hidden_size])
+        out = ffn2_out.sum(axis=1)
         return out
 
 
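Read end to end, the renamed buffers make the weight-only path's dataflow explicit: ffn1 GEMM, SwiGLU, ffn2 GEMM, then a per-token sum over the top-k expert outputs. A kernel-free sketch of that flow, with dense matmuls standing in for the grouped Triton GEMMs and w1, w2, topk_weights as illustrative placeholders:

    import paddle

    def moe_ffn_dataflow_sketch(x, w1, w2, topk_weights,
                                token_num, top_k, hidden_size):
        # ffn1: [token_num * top_k, 2 * moe_intermediate_size].
        ffn1_out = paddle.matmul(x, w1)
        # SwiGLU halves the last dim: [token_num * top_k, moe_intermediate_size].
        ffn2_input = paddle.incubate.nn.functional.swiglu(ffn1_out)
        # The real ffn2 launch passes topk_weights into the kernel; here the
        # routing weights are applied explicitly instead.
        ffn2_out = paddle.matmul(ffn2_input, w2) * topk_weights.reshape([-1, 1])
        # Collapse the top-k expert rows back to one row per token.
        return ffn2_out.reshape([token_num, top_k, hidden_size]).sum(axis=1)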
@@ -343,18 +340,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             False,
         )
 
-        intermediate_cache1 = paddle.empty(
+        ffn1_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
         )
-        intermediate_cache2 = paddle.empty(
-            (token_num * top_k, moe_intermediate_size),
-            dtype=x.dtype,
-        )
-        intermediate_cache3 = paddle.empty(
-            (token_num * top_k, hidden_size),
-            dtype=x.dtype,
-        )
 
         config_ffn1 = {
             "BLOCK_SIZE_M": 32,
@@ -381,7 +370,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
         fused_moe_kernel_paddle[grid](
             permute_x,
             layer.moe_ffn1_weight,
-            intermediate_cache1,
+            ffn1_out,
             layer.moe_ffn1_in_scale,
             layer.moe_ffn1_weight_scale,
             None,
@@ -397,8 +386,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             stride_be=layer.moe_ffn1_weight.strides[0],
             stride_bk=layer.moe_ffn1_weight.strides[1],
             stride_bn=layer.moe_ffn1_weight.strides[2],
-            stride_cm=intermediate_cache1.strides[0],
-            stride_cn=intermediate_cache1.strides[1],
+            stride_cm=ffn1_out.strides[0],
+            stride_cn=ffn1_out.strides[1],
             #
             stride_asm=-1,  # only used in blockwise fp8
             stride_ask=-1,  # only used in blockwise fp8
@@ -420,11 +409,11 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             even_Ks=hidden_size % config_ffn1["BLOCK_SIZE_K"] == 0,
         )
 
-        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
-            intermediate_cache1)
+        ffn2_input = paddle.incubate.nn.functional.swiglu(
+            ffn1_out)
 
-        intermediate_cache2 = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
-            intermediate_cache2,
+        ffn2_input = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
+            ffn2_input,
             scale=layer.moe_ffn2_in_scale,
             topk_ids=topk_ids,
             top_k=top_k,
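Between the two FP8 GEMMs the SwiGLU output is back in higher precision, so it has to be re-quantized before feeding ffn2; the fused op above also folds in a Hadamard transform before scaling. As a rough idea of the per-tensor scaling step alone (not the fused op's actual semantics; the ±448 clip assumes the float8_e4m3 format's dynamic range):

    import paddle

    def per_tensor_fp8_scale_sketch(x, scale):
        # Scale into range and clip at the e4m3 limits; the cast to FP8
        # itself happens inside the fused kernel.
        return paddle.clip(x / scale, -448.0, 448.0)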
@@ -438,14 +427,19 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             "GROUP_SIZE_M": 1,
         }
 
+        ffn2_out = paddle.empty(
+            (token_num * top_k, hidden_size),
+            dtype=x.dtype,
+        )
+
         grid = (
             ceil_div(max_possible_num_post_padded, config_ffn2["BLOCK_SIZE_M"]) *
             ceil_div(hidden_size, config_ffn2["BLOCK_SIZE_N"]), )
 
         fused_moe_kernel_paddle[grid](
-            intermediate_cache2,
+            ffn2_input,
             layer.moe_ffn2_weight,
-            intermediate_cache3,
+            ffn2_out,
             layer.moe_ffn2_in_scale,
             layer.moe_ffn2_weight_scale,
             topk_weights,
@@ -456,13 +450,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             token_num * top_k,
             N=hidden_size,
             K=moe_intermediate_size,
-            stride_am=intermediate_cache2.strides[0],
-            stride_ak=intermediate_cache2.strides[1],
+            stride_am=ffn2_input.strides[0],
+            stride_ak=ffn2_input.strides[1],
             stride_be=layer.moe_ffn2_weight.strides[0],
             stride_bk=layer.moe_ffn2_weight.strides[1],
             stride_bn=layer.moe_ffn2_weight.strides[2],
-            stride_cm=intermediate_cache3.strides[0],
-            stride_cn=intermediate_cache3.strides[1],
+            stride_cm=ffn2_out.strides[0],
+            stride_cn=ffn2_out.strides[1],
             stride_asm=-1,
             stride_ask=-1,
             stride_bse=-1,
@@ -483,8 +477,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             even_Ks=moe_intermediate_size % config_ffn2["BLOCK_SIZE_K"] == 0,
         )
 
-        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
-        out = intermediate_cache3.sum(axis=1)
+        ffn2_out.reshape_([token_num, top_k, hidden_size])
+        out = ffn2_out.sum(axis=1)
 
         if layer.tp_size > 1:
             tensor_model_parallel_all_reduce(out)
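The FP8 path ends the same way as the weight-only one, plus a tensor-parallel reduction: with tp_size > 1 each rank holds only a shard of the ffn2 weights, so each rank computes a partial sum of the full-width output, and the all_reduce completes it. Schematically, as the patch leaves the tail of the method:

    ffn2_out.reshape_([token_num, top_k, hidden_size])  # split per-expert rows
    out = ffn2_out.sum(axis=1)                          # combine top-k experts

    if layer.tp_size > 1:
        # Sum the partial ffn2 outputs held by each tensor-parallel rank.
        tensor_model_parallel_all_reduce(out)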