diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 97c212691..b4eeb1fc7 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -468,6 +468,7 @@ std::vector<paddle::Tensor> NoauxTc(
     int topk,
     float routed_scaling_factor);
 
+#ifdef ENABLE_FP8
 paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
     const paddle::Tensor& x,
     const paddle::Tensor& y,
@@ -489,6 +490,7 @@ paddle::Tensor MoeFusedHadamardQuantFp8Func(
 
 paddle::Tensor FusedHadamardQuantFp8Func(
     const paddle::Tensor &input, const float scale);
+#endif
 
 PYBIND11_MODULE(fastdeploy_ops, m) {
@@ -769,6 +771,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
 
   m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
 
+#ifdef ENABLE_FP8
   m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
         py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
         py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
@@ -780,4 +783,5 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
         py::arg("input"), py::arg("scale"),
         "fused_hadamard_quant_fp8 function");
+#endif
 }
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index 801953087..caf395c5d 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -129,18 +129,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             True,  # apply_norm_weight,
             False,
         )
-        intermediate_cache1 = paddle.empty(
+        ffn1_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
         )
-        intermediate_cache2 = paddle.empty(
-            (token_num * top_k, moe_intermediate_size),
-            dtype=x.dtype,
-        )
-        intermediate_cache3 = paddle.empty(
-            (token_num * top_k, hidden_size),
-            dtype=x.dtype,
-        )
 
         config = {
             "BLOCK_SIZE_M": 32,
@@ -158,7 +150,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         fused_moe_kernel_paddle[grid](
             x,
             layer.moe_ffn1_weight,
-            intermediate_cache1,
+            ffn1_out,
             None,
             layer.moe_ffn1_weight_scale,
             None,
@@ -174,8 +166,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             stride_be=layer.moe_ffn1_weight.strides[0],
             stride_bk=layer.moe_ffn1_weight.strides[1],
             stride_bn=layer.moe_ffn1_weight.strides[2],
-            stride_cm=intermediate_cache1.strides[0],
-            stride_cn=intermediate_cache1.strides[1],
+            stride_cm=ffn1_out.strides[0],
+            stride_cn=ffn1_out.strides[1],
             #
             stride_asm=-1,
             stride_ask=-1,
@@ -197,16 +189,21 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
         )
 
-        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
-            intermediate_cache1)
+        ffn2_input = paddle.incubate.nn.functional.swiglu(
+            ffn1_out)
+
+        ffn2_out = paddle.empty(
+            (token_num * top_k, hidden_size),
+            dtype=x.dtype,
+        )
 
         grid = (
             ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
             ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
         fused_moe_kernel_paddle[grid](
-            intermediate_cache2,
+            ffn2_input,
             layer.moe_ffn2_weight,
-            intermediate_cache3,
+            ffn2_out,
             None,
             layer.moe_ffn2_weight_scale,
             topk_weights,
@@ -217,13 +214,13 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             token_num * top_k,
             N=hidden_size,
             K=moe_intermediate_size,
-            stride_am=intermediate_cache2.strides[0],
-            stride_ak=intermediate_cache2.strides[1],
+            stride_am=ffn2_input.strides[0],
+            stride_ak=ffn2_input.strides[1],
             stride_be=layer.moe_ffn2_weight.strides[0],
             stride_bk=layer.moe_ffn2_weight.strides[1],
             stride_bn=layer.moe_ffn2_weight.strides[2],
-            stride_cm=intermediate_cache3.strides[0],
-            stride_cn=intermediate_cache3.strides[1],
+            stride_cm=ffn2_out.strides[0],
+            stride_cn=ffn2_out.strides[1],
             stride_asm=-1,
             stride_ask=-1,
             stride_bse=layer.moe_ffn2_weight_scale.strides[0],
@@ -244,8 +241,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             even_Ks=moe_intermediate_size %
             config["BLOCK_SIZE_K"] == 0,
         )
 
-        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
-        out = intermediate_cache3.sum(axis=1)
+        ffn2_out.reshape_([token_num, top_k, hidden_size])
+        out = ffn2_out.sum(axis=1)
 
         return out
@@ -343,18 +340,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             False,
         )
 
-        intermediate_cache1 = paddle.empty(
+        ffn1_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
         )
-        intermediate_cache2 = paddle.empty(
-            (token_num * top_k, moe_intermediate_size),
-            dtype=x.dtype,
-        )
-        intermediate_cache3 = paddle.empty(
-            (token_num * top_k, hidden_size),
-            dtype=x.dtype,
-        )
 
         config_ffn1 = {
             "BLOCK_SIZE_M": 32,
@@ -381,7 +370,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
         fused_moe_kernel_paddle[grid](
             permute_x,
             layer.moe_ffn1_weight,
-            intermediate_cache1,
+            ffn1_out,
             layer.moe_ffn1_in_scale,
             layer.moe_ffn1_weight_scale,
             None,
@@ -397,8 +386,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             stride_be=layer.moe_ffn1_weight.strides[0],
             stride_bk=layer.moe_ffn1_weight.strides[1],
             stride_bn=layer.moe_ffn1_weight.strides[2],
-            stride_cm=intermediate_cache1.strides[0],
-            stride_cn=intermediate_cache1.strides[1],
+            stride_cm=ffn1_out.strides[0],
+            stride_cn=ffn1_out.strides[1],
             #
             stride_asm=-1,  # only used in blockwise fp8
             stride_ask=-1,  # only used in blockwise fp8
@@ -420,11 +409,11 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             even_Ks=hidden_size % config_ffn1["BLOCK_SIZE_K"] == 0,
         )
 
-        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
-            intermediate_cache1)
+        ffn2_input = paddle.incubate.nn.functional.swiglu(
+            ffn1_out)
 
-        intermediate_cache2 = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
-            intermediate_cache2,
+        ffn2_input = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
+            ffn2_input,
             scale=layer.moe_ffn2_in_scale,
             topk_ids=topk_ids,
             top_k=top_k,
@@ -438,14 +427,19 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             "GROUP_SIZE_M": 1,
         }
 
+        ffn2_out = paddle.empty(
+            (token_num * top_k, hidden_size),
+            dtype=x.dtype,
+        )
+
         grid = (
             ceil_div(max_possible_num_post_padded,
                      config_ffn2["BLOCK_SIZE_M"]) *
             ceil_div(hidden_size, config_ffn2["BLOCK_SIZE_N"]), )
         fused_moe_kernel_paddle[grid](
-            intermediate_cache2,
+            ffn2_input,
             layer.moe_ffn2_weight,
-            intermediate_cache3,
+            ffn2_out,
             layer.moe_ffn2_in_scale,
             layer.moe_ffn2_weight_scale,
             topk_weights,
@@ -456,13 +450,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             token_num * top_k,
             N=hidden_size,
             K=moe_intermediate_size,
-            stride_am=intermediate_cache2.strides[0],
-            stride_ak=intermediate_cache2.strides[1],
+            stride_am=ffn2_input.strides[0],
+            stride_ak=ffn2_input.strides[1],
             stride_be=layer.moe_ffn2_weight.strides[0],
             stride_bk=layer.moe_ffn2_weight.strides[1],
             stride_bn=layer.moe_ffn2_weight.strides[2],
-            stride_cm=intermediate_cache3.strides[0],
-            stride_cn=intermediate_cache3.strides[1],
+            stride_cm=ffn2_out.strides[0],
+            stride_cn=ffn2_out.strides[1],
             stride_asm=-1,
             stride_ask=-1,
             stride_bse=-1,
@@ -483,8 +477,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             even_Ks=moe_intermediate_size %
             config_ffn2["BLOCK_SIZE_K"] == 0,
         )
-        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
-        out = intermediate_cache3.sum(axis=1)
+        ffn2_out.reshape_([token_num, top_k, hidden_size])
+        out = ffn2_out.sum(axis=1)
 
         if layer.tp_size > 1:
             tensor_model_parallel_all_reduce(out)
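
Note on the intent of the two changes: the C++ hunks fence the FP8-only declarations and their pybind registrations behind `#ifdef ENABLE_FP8`, so builds without FP8 support no longer reference those symbols; the Python hunks rename the `intermediate_cache*` temporaries after their role (`ffn1_out`, `ffn2_input`, `ffn2_out`), drop the pre-allocation of `intermediate_cache2` (dead memory, since `swiglu` allocates its own output), and defer allocating the FFN2 output until just before the second kernel launch. A minimal sketch of the resulting buffer lifecycle (kernel launches elided; `moe_ffn_buffer_flow` is a hypothetical wrapper for illustration, while the tensor names and shapes come from the diff above):

```python
import paddle

def moe_ffn_buffer_flow(x, token_num, top_k, moe_intermediate_size, hidden_size):
    # FFN1 output: the only buffer that must exist before the first launch.
    ffn1_out = paddle.empty(
        [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype)
    # ... first fused_moe_kernel_paddle launch writes ffn1_out ...

    # swiglu returns a fresh [token_num * top_k, moe_intermediate_size]
    # tensor, so the old pre-allocated intermediate_cache2 was never used.
    ffn2_input = paddle.incubate.nn.functional.swiglu(ffn1_out)

    # FFN2 output is allocated only now, just before the second launch,
    # so it is not live while FFN1 and swiglu run.
    ffn2_out = paddle.empty(
        (token_num * top_k, hidden_size), dtype=x.dtype)
    # ... second fused_moe_kernel_paddle launch writes ffn2_out ...

    # Combine the top_k expert outputs for each token.
    ffn2_out.reshape_([token_num, top_k, hidden_size])
    return ffn2_out.sum(axis=1)
```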