Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Bug fix] Fix compile bug when SM < 89 (#2738)
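In short: the FP8 GEMM declarations and their pybind11 registrations are now wrapped in #ifdef ENABLE_FP8, so the custom-op extension still builds for GPUs below SM 8.9, where the FP8 kernels are not compiled. Along the way the Triton MoE backends are tidied up: the opaque intermediate_cache1/2/3 buffers become ffn1_out, ffn2_input, and ffn2_out, an unused pre-allocation for the SwiGLU output is removed, and ffn2_out is allocated only just before the FFN2 kernel launch.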
@@ -468,6 +468,7 @@ std::vector<paddle::Tensor> NoauxTc(
     int topk,
     float routed_scaling_factor);

+#ifdef ENABLE_FP8
 paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
     const paddle::Tensor& x,
     const paddle::Tensor& y,
@@ -489,6 +490,7 @@ paddle::Tensor MoeFusedHadamardQuantFp8Func(
 paddle::Tensor FusedHadamardQuantFp8Func(
     const paddle::Tensor &input,
     const float scale);
+#endif

 PYBIND11_MODULE(fastdeploy_ops, m) {

@@ -769,6 +771,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

   m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");

+#ifdef ENABLE_FP8
   m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
         py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
         py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
@@ -780,4 +783,5 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

   m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
         py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
+#endif
 }
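Why this fixes the build: the FP8 GEMM and Hadamard-quant kernels are only compiled when ENABLE_FP8 is defined, but the declarations and m.def registrations above were previously unconditional, so builds targeting GPUs below SM 8.9 broke (the referenced symbols are never compiled there). Guarding both sites keeps the module self-consistent. A minimal sketch of how a build script might decide whether to define the macro (hypothetical fp8_supported helper, not FastDeploy's actual setup logic, which may key off the CUDA arch list instead):

import paddle

def fp8_supported() -> bool:
    # FP8 tensor cores require SM 8.9 (Ada) or newer (e.g. SM 9.0 Hopper).
    major, minor = paddle.device.cuda.get_device_capability()
    return (major, minor) >= (8, 9)

# Pass -DENABLE_FP8 to the C++ extension build only on capable hardware.
define_macros = [("ENABLE_FP8", None)] if fp8_supported() else []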
@@ -129,18 +129,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             True,  # apply_norm_weight,
             False,
         )
-        intermediate_cache1 = paddle.empty(
+        ffn1_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
         )
-        intermediate_cache2 = paddle.empty(
-            (token_num * top_k, moe_intermediate_size),
-            dtype=x.dtype,
-        )
-        intermediate_cache3 = paddle.empty(
-            (token_num * top_k, hidden_size),
-            dtype=x.dtype,
-        )

         config = {
             "BLOCK_SIZE_M": 32,
@@ -158,7 +150,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         fused_moe_kernel_paddle[grid](
             x,
             layer.moe_ffn1_weight,
-            intermediate_cache1,
+            ffn1_out,
             None,
             layer.moe_ffn1_weight_scale,
             None,
@@ -174,8 +166,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             stride_be=layer.moe_ffn1_weight.strides[0],
             stride_bk=layer.moe_ffn1_weight.strides[1],
             stride_bn=layer.moe_ffn1_weight.strides[2],
-            stride_cm=intermediate_cache1.strides[0],
-            stride_cn=intermediate_cache1.strides[1],
+            stride_cm=ffn1_out.strides[0],
+            stride_cn=ffn1_out.strides[1],
             #
             stride_asm=-1,
             stride_ask=-1,
@@ -197,16 +189,21 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
         )

-        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
-            intermediate_cache1)
+        ffn2_input = paddle.incubate.nn.functional.swiglu(
+            ffn1_out)

+        ffn2_out = paddle.empty(
+            (token_num * top_k, hidden_size),
+            dtype=x.dtype,
+        )
+
         grid = (
             ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
             ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
         fused_moe_kernel_paddle[grid](
-            intermediate_cache2,
+            ffn2_input,
             layer.moe_ffn2_weight,
-            intermediate_cache3,
+            ffn2_out,
             None,
             layer.moe_ffn2_weight_scale,
             topk_weights,
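The grid in the hunk above is a 1-D launch: one Triton program instance per (M-block, N-block) tile of the FFN2 output. ceil_div is assumed to be the usual ceiling-division helper:

def ceil_div(a: int, b: int) -> int:
    # Smallest integer >= a / b.
    return (a + b - 1) // b

# Example: 1024 padded rows with BLOCK_SIZE_M=32 and hidden_size=2048 with
# BLOCK_SIZE_N=64 give 32 * 32 = 1024 program instances in the launch grid.
assert ceil_div(1024, 32) * ceil_div(2048, 64) == 1024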
@@ -217,13 +214,13 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             token_num * top_k,
             N=hidden_size,
             K=moe_intermediate_size,
-            stride_am=intermediate_cache2.strides[0],
-            stride_ak=intermediate_cache2.strides[1],
+            stride_am=ffn2_input.strides[0],
+            stride_ak=ffn2_input.strides[1],
             stride_be=layer.moe_ffn2_weight.strides[0],
             stride_bk=layer.moe_ffn2_weight.strides[1],
             stride_bn=layer.moe_ffn2_weight.strides[2],
-            stride_cm=intermediate_cache3.strides[0],
-            stride_cn=intermediate_cache3.strides[1],
+            stride_cm=ffn2_out.strides[0],
+            stride_cn=ffn2_out.strides[1],
             stride_asm=-1,
             stride_ask=-1,
             stride_bse=layer.moe_ffn2_weight_scale.strides[0],
@@ -244,8 +241,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
         )

-        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
-        out = intermediate_cache3.sum(axis=1)
+        ffn2_out.reshape_([token_num, top_k, hidden_size])
+        out = ffn2_out.sum(axis=1)
         return out

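End to end, the renamed buffers make the dataflow easier to follow and drop a dead allocation: swiglu allocates its own output, so the old intermediate_cache2 pre-allocation was overwritten without ever being used, and ffn2_out now comes into existence only right before the second GEMM needs it. A condensed, runnable sketch with toy shapes (the grouped-GEMM launches are elided):

import paddle

# Toy stand-ins for the method's runtime values.
token_num, top_k, moe_intermediate_size, hidden_size = 4, 2, 8, 16
x = paddle.randn([token_num, hidden_size], dtype="float32")

ffn1_out = paddle.empty([token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype)
# ... first fused_moe_kernel_paddle launch writes the FFN1 GEMM into ffn1_out ...
ffn2_input = paddle.incubate.nn.functional.swiglu(ffn1_out)  # allocates its own output
ffn2_out = paddle.empty((token_num * top_k, hidden_size), dtype=x.dtype)
# ... second launch writes the FFN2 GEMM into ffn2_out ...
ffn2_out.reshape_([token_num, top_k, hidden_size])
out = ffn2_out.sum(axis=1)  # merge the top_k expert contributions per token
assert out.shape == [token_num, hidden_size]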
@@ -343,18 +340,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             False,
         )

-        intermediate_cache1 = paddle.empty(
+        ffn1_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
         )
-        intermediate_cache2 = paddle.empty(
-            (token_num * top_k, moe_intermediate_size),
-            dtype=x.dtype,
-        )
-        intermediate_cache3 = paddle.empty(
-            (token_num * top_k, hidden_size),
-            dtype=x.dtype,
-        )

         config_ffn1 = {
             "BLOCK_SIZE_M": 32,
@@ -381,7 +370,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
         fused_moe_kernel_paddle[grid](
             permute_x,
             layer.moe_ffn1_weight,
-            intermediate_cache1,
+            ffn1_out,
             layer.moe_ffn1_in_scale,
             layer.moe_ffn1_weight_scale,
             None,
@@ -397,8 +386,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             stride_be=layer.moe_ffn1_weight.strides[0],
             stride_bk=layer.moe_ffn1_weight.strides[1],
             stride_bn=layer.moe_ffn1_weight.strides[2],
-            stride_cm=intermediate_cache1.strides[0],
-            stride_cn=intermediate_cache1.strides[1],
+            stride_cm=ffn1_out.strides[0],
+            stride_cn=ffn1_out.strides[1],
             #
             stride_asm=-1,  # only used in blockwise fp8
             stride_ask=-1,  # only used in blockwise fp8
@@ -420,11 +409,11 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             even_Ks=hidden_size % config_ffn1["BLOCK_SIZE_K"] == 0,
         )

-        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
-            intermediate_cache1)
+        ffn2_input = paddle.incubate.nn.functional.swiglu(
+            ffn1_out)

-        intermediate_cache2 = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
-            intermediate_cache2,
+        ffn2_input = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
+            ffn2_input,
             scale=layer.moe_ffn2_in_scale,
             topk_ids=topk_ids,
             top_k=top_k,
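Before the second GEMM on this path, the half-precision SwiGLU output is re-quantized to FP8 by moe_fused_hadamard_quant_fp8, which, judging from its arguments, uses topk_ids to pick the per-expert input scale. A toy illustration of the scale-then-clip math behind e4m3 quantization (illustration only; the real op additionally fuses a Hadamard transform and runs as a GPU kernel):

import paddle

def quant_e4m3_values(x: paddle.Tensor, scale: float) -> paddle.Tensor:
    # e4m3 FP8 has a finite range of +/-448, so scaled values are clipped
    # before the final cast to the FP8 storage type.
    return paddle.clip(x / scale, -448.0, 448.0)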
@@ -438,14 +427,19 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             "GROUP_SIZE_M": 1,
         }

+        ffn2_out = paddle.empty(
+            (token_num * top_k, hidden_size),
+            dtype=x.dtype,
+        )
+
         grid = (
             ceil_div(max_possible_num_post_padded, config_ffn2["BLOCK_SIZE_M"]) *
             ceil_div(hidden_size, config_ffn2["BLOCK_SIZE_N"]), )

         fused_moe_kernel_paddle[grid](
-            intermediate_cache2,
+            ffn2_input,
             layer.moe_ffn2_weight,
-            intermediate_cache3,
+            ffn2_out,
             layer.moe_ffn2_in_scale,
             layer.moe_ffn2_weight_scale,
             topk_weights,
@@ -456,13 +450,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             token_num * top_k,
             N=hidden_size,
             K=moe_intermediate_size,
-            stride_am=intermediate_cache2.strides[0],
-            stride_ak=intermediate_cache2.strides[1],
+            stride_am=ffn2_input.strides[0],
+            stride_ak=ffn2_input.strides[1],
             stride_be=layer.moe_ffn2_weight.strides[0],
             stride_bk=layer.moe_ffn2_weight.strides[1],
             stride_bn=layer.moe_ffn2_weight.strides[2],
-            stride_cm=intermediate_cache3.strides[0],
-            stride_cn=intermediate_cache3.strides[1],
+            stride_cm=ffn2_out.strides[0],
+            stride_cn=ffn2_out.strides[1],
             stride_asm=-1,
             stride_ask=-1,
             stride_bse=-1,
@@ -483,8 +477,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             even_Ks=moe_intermediate_size % config_ffn2["BLOCK_SIZE_K"] == 0,
         )

-        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
-        out = intermediate_cache3.sum(axis=1)
+        ffn2_out.reshape_([token_num, top_k, hidden_size])
+        out = ffn2_out.sum(axis=1)

         if layer.tp_size > 1:
             tensor_model_parallel_all_reduce(out)
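A note on the tail of this path: when the expert weights are sharded across tensor-parallel ranks (layer.tp_size > 1), each rank computes only a partial sum of the FFN2 output, so tensor_model_parallel_all_reduce sums the partials into the full result before it is returned.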