From 2513cd929bffa93f17daf2cb33118ea5d35add89 Mon Sep 17 00:00:00 2001
From: Yuan Xiaolan <845594810@qq.com>
Date: Wed, 13 Aug 2025 21:41:34 +0800
Subject: [PATCH] support w4afp8 EP inference (#3382)

---
 custom_ops/gpu_ops/cpp_extensions.cc            |   3 +-
 custom_ops/gpu_ops/helper.h                     |   6 +
 custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu   |  69 ++++-
 .../gpu_ops/moe/fast_hardamard_kernel.cu        |  24 +-
 custom_ops/gpu_ops/moe/fused_moe_op.h           | 279 ++++++++++++++++++
 custom_ops/gpu_ops/moe/moe_ffn.cu               | 147 +++++++--
 custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h   |  10 +-
 custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu   | 174 +++++++++--
 custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h    |  47 +++
 .../w4afp8_gemm/w4afp8_gemm_kernel.hpp          |  10 +-
 .../utils/auto_gen_w4afp8_gemm_kernel.py        |  10 +-
 .../model_executor/layers/moe/__init__.py       |   7 +-
 fastdeploy/model_executor/layers/moe/ep.py      |   3 +-
 .../layers/moe/fused_moe_cutlass_backend.py     | 224 +++++++++++++-
 .../layers/quantization/w4afp8.py               |  17 +-
 .../model_executor/models/ernie4_5_moe.py       |   2 +-
 test/operators/test_w4afp8_gemm.py              |  12 +-
 17 files changed, 944 insertions(+), 100 deletions(-)
 create mode 100644 custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h

diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 266d50599..fc6177285 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -188,7 +188,8 @@ paddle::Tensor MoeExpertFFNFunc(
     const paddle::optional<paddle::Tensor>& down_proj_scale,
     const paddle::optional<paddle::Tensor>& down_proj_in_scale,
     const paddle::optional<paddle::Tensor>& expert_idx_per_token,
-    const std::string& quant_method, const bool used_in_ep_low_latency);
+    const std::string& quant_method, const bool used_in_ep_low_latency,
+    const int estimate_total_token_nums);
 
 paddle::Tensor MoeExpertFFNWint2Func(
     const paddle::Tensor& permute_input,
diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h
index ed4efe927..dd65b8a4b 100644
--- a/custom_ops/gpu_ops/helper.h
+++ b/custom_ops/gpu_ops/helper.h
@@ -193,6 +193,12 @@ public:
   typedef uint8_t data_t;
 };
 
+template <> class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
+public:
+  typedef __nv_fp8_e4m3 DataType;
+  typedef paddle::float8_e4m3fn data_t;
+};
+
 template <typename T, int Size> struct alignas(sizeof(T) * Size) AlignedVector {
   T val[Size];
diff --git a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu
index 60ae7d1fc..d521347e3 100644
--- a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu
+++ b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu
@@ -269,7 +269,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
 }
 
 
-template 
+template 
 __global__ void permute_x_kernel(const T *src_x,
                                  const int64_t *topk_idx,
                                  const float *topk_weights,
@@ -285,9 +285,9 @@ __global__ void permute_x_kernel(const T *src_x,
                                  int *dst_indices,
                                  int *cumsum_idx_gpu,
                                  int64_t *token_nums_per_expert_cumsum,
-                                 int64_t *expert_idx_per_token,
+                                 int64_t *expert_idx_per_token, // [num_rows, moe_topk]
                                  float max_bound = 127.0,
-                                 float min_bound = -127.0) { // [num_rows, moe_topk]
+                                 float min_bound = -127.0) {
     const int src_token_idx = blockIdx.x;
     const int tid = threadIdx.x;
     constexpr int vec_size = sizeof(int4) / sizeof(T);
@@ -330,10 +330,17 @@ __global__ void permute_x_kernel(const T *src_x,
         if (up_gate_proj_in_scale) {
             for (int i = 0; i < vec_size; i++) {
                 float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast<float>(src_vec[i]);
-                if (RoundType == 0) {
-                    res_vec[i] = static_cast<OutT>(ClipFunc(rint(quant_value), min_bound, max_bound));
-                } else {
-                    res_vec[i] = static_cast<OutT>(round(quant_value));
-                }
+                if constexpr (std::is_same<OutT, int8_t>::value) {
+                    // w4aint8
+                    if (RoundType == 0) {
+                        res_vec[i] = static_cast<OutT>(ClipFunc(rint(quant_value), min_bound, max_bound));
+                    } else {
+                        res_vec[i] = static_cast<OutT>(ClipFunc(round(quant_value), min_bound, max_bound));
+                    }
+                } else {
+                    // w4afp8
+                    float value = ClipFunc(quant_value, min_bound, max_bound);
+                    res_vec[i] = static_cast<OutT>(value);
+                }
             }
         } else {
@@ -373,6 +380,10 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
     typedef typename traits_::DataType DataType_;
     typedef typename traits_::data_t data_t;
 
+    typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
+    typedef typename traits_fp8::DataType DataType_fp8;
+    typedef typename traits_fp8::data_t data_t_fp8;
+
     auto stream = input.stream();
     auto place = input.place();
     const int gridx = min(132 * 8, num_rows);
@@ -420,6 +431,50 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
                 -127.0
             );
         }
+    } else if (moe_quant_type == "w4afp8") {
+        if (num_experts_per_rank == 8) {
+            permute_x_kernel<<>>(
+                input.data<data_t>(),
+                topk_ids.data<int64_t>(),
+                topk_weights.data<float>(),
+                token_nums_per_expert.data<int>(),
+                up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
+                moe_topk,
+                num_rows,
+                token_nums_this_rank,
+                hidden_size,
+                permute_input->data<data_t_fp8>(),
+                permute_indices_per_token->data<int>(),
+                dst_weights->data<float>(),
+                dst_indices->data<int>(),
+                cumsum_idx_gpu->data<int>(),
+                token_nums_per_expert_cumsum->data<int64_t>(),
+                expert_idx_per_token->data<int64_t>(),
+                448.0f,
+                -448.0f
+            );
+        } else if (num_experts_per_rank == 16) {
+            permute_x_kernel<<>>(
+                input.data<data_t>(),
+                topk_ids.data<int64_t>(),
+                topk_weights.data<float>(),
+                token_nums_per_expert.data<int>(),
+                up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
+                moe_topk,
+                num_rows,
+                token_nums_this_rank,
+                hidden_size,
+                permute_input->data<data_t_fp8>(),
+                permute_indices_per_token->data<int>(),
+                dst_weights->data<float>(),
+                dst_indices->data<int>(),
+                cumsum_idx_gpu->data<int>(),
+                token_nums_per_expert_cumsum->data<int64_t>(),
+                expert_idx_per_token->data<int64_t>(),
+                448.0f,
+                -448.0f
+            );
+        }
     } else {
         if (num_experts_per_rank == 8) {
             permute_x_kernel<<>>(
@@ -493,7 +548,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
 
     auto permute_input = GetEmptyTensor(
         {token_nums_this_rank, hidden_size},
-        moe_quant_type == "w4a8" ? paddle::DataType::INT8 : input_type,
+        moe_quant_type == "w4a8" ? paddle::DataType::INT8 : moe_quant_type == "w4afp8" ? paddle::DataType::FLOAT8_E4M3FN : input_type,
         place);
     auto num_experts_per_rank_tensor = GetEmptyTensor(
         {num_experts_per_rank},
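Note on the dispatch change above: the new w4afp8 branch clips the scaled activation to [-448, 448], the finite range of FP8 E4M3, and lets the float-to-FP8 cast handle rounding, instead of rounding onto an integer grid as the w4a8 path does. A minimal standalone sketch of that per-element step (the helper name is hypothetical, not part of this patch):

    #include <cuda_fp8.h>

    // Mirrors the w4afp8 branch: quant = max_bound * in_scale * x,
    // explicitly saturated to the E4M3 finite maximum of +/-448
    // (the equivalent of ClipFunc in the kernel) before the cast.
    __device__ __forceinline__ __nv_fp8_e4m3
    quant_activation_to_fp8(float x, float in_scale) {
        const float kE4M3Max = 448.0f;
        float q = kE4M3Max * in_scale * x;
        q = fminf(fmaxf(q, -kE4M3Max), kE4M3Max);
        return static_cast<__nv_fp8_e4m3>(q);  // round-to-nearest on conversion
    }
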
diff --git a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu
index 7bf46f0f4..63b45b743 100644
--- a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu
+++ b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu
@@ -88,7 +88,7 @@ struct nv_type_traits {
     constexpr int kLogN = 7;                                                   \
     __VA_ARGS__                                                                \
   } else {                                                                     \
-    PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupport!", logN)); \
+    PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupported!", logN)); \
   }
 
 #define DISPATCH_SP_VS(vec_size, VEC_SIZE, ...)                                \
@@ -108,7 +108,7 @@ struct nv_type_traits {
     constexpr int VEC_SIZE = 1;                                                \
     __VA_ARGS__                                                                \
   } else {                                                                     \
-    PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupport!", vec_size)); \
+    PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupported!", vec_size)); \
   }
 
 #define DISPATCH_logN(logN, kLogN, ...)                                        \
@@ -605,26 +605,6 @@ void moe_fast_hardamard_kernel(const T *x,
         exchange_smem_pre(x_vals, smem_exchange);
     }
     if constexpr (kNChunks > 1) {
-//         T x_vals_transposed[VecSize][kNChunks] = {init_value};
-// #pragma unroll
-//         for (int c = 0; c < kNChunks; ++c) {
-// #pragma unroll
-//             for (int i = 0; i < VecSize; ++i) { x_vals_transposed[i][c] = x_vals[c][i]; }
-//         }
-//         if constexpr (kNChunks == 28) {
-//             hadamard_mult_thread_chunk_28(x_vals_transposed);
-//         } else if constexpr (kNChunks == 36) {
-//             hadamard_mult_thread_chunk_36(x_vals_transposed);
-//         } else {
-//             constexpr int kLogNChunks = cilog2(kNChunks);
-//             static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
-//             hadamard_mult_thread(x_vals_transposed);
-//         }
-// #pragma unroll
-//         for (int c = 0; c < kNChunks; ++c) {
-// #pragma unroll
-//             for (int i = 0; i < VecSize; ++i) { x_vals[c][i] = x_vals_transposed[i][c]; }
-//         }
         if constexpr (kNChunks == 28) {
             hadamard_mult_thread_28_transpose(x_vals);
         } else if constexpr (kNChunks == 36) {
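Note on the DISPATCH_* macros touched above: they follow the usual runtime-to-compile-time dispatch pattern, where a runtime value selects a branch that rebinds the same value as a constexpr so the kernel body can use it as a template argument. A condensed sketch of the idiom (the macro name DISPATCH_VS and the two cases are illustrative; the real macros cover more values):

    // Condensed form of the DISPATCH_SP_VS idiom in fast_hardamard_kernel.cu:
    // each supported runtime value gets its own constexpr binding, so the
    // body in __VA_ARGS__ is compiled once per case.
    #define DISPATCH_VS(vec_size, VEC_SIZE, ...)                     \
      if (vec_size == 8) {                                           \
        constexpr int VEC_SIZE = 8;                                  \
        __VA_ARGS__                                                  \
      } else if (vec_size == 4) {                                    \
        constexpr int VEC_SIZE = 4;                                  \
        __VA_ARGS__                                                  \
      } else {                                                       \
        PADDLE_THROW(phi::errors::Unimplemented(                     \
            "vec_size = %d is unsupported!", vec_size));             \
      }

    // Usage: VEC_SIZE is a compile-time constant inside the body, e.g.
    // DISPATCH_VS(vec_size, kVecSize, { my_kernel<T, kVecSize><<<grid, block, 0, stream>>>(args); })
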
diff --git a/custom_ops/gpu_ops/moe/fused_moe_op.h b/custom_ops/gpu_ops/moe/fused_moe_op.h
index 09d705d41..3f82bec82 100644
--- a/custom_ops/gpu_ops/moe/fused_moe_op.h
+++ b/custom_ops/gpu_ops/moe/fused_moe_op.h
@@ -72,6 +72,285 @@ __host__ __device__ constexpr static U arrayConvert(T const& input)
     return u;
 }
 
+struct uint8 {
+    uint4 u;
+    uint4 v;
+};
+
+template <int BYTES> struct BytesToType {};
+
+template<>
+struct BytesToType<32> {
+    using Type = uint8;
+    static_assert(sizeof(Type) == 32);
+};
+
+template<> struct BytesToType<16> {
+    using Type = uint4;
+    static_assert(sizeof(Type) == 16);
+};
+
+template<> struct BytesToType<8> {
+    using Type = uint64_t;
+    static_assert(sizeof(Type) == 8);
+};
+
+template<> struct BytesToType<4> {
+    using Type = uint32_t;
+    static_assert(sizeof(Type) == 4);
+};
+
+template<> struct BytesToType<2> {
+    using Type = uint16_t;
+    static_assert(sizeof(Type) == 2);
+};
+
+template<> struct BytesToType<1> {
+    using Type = uint8_t;
+    static_assert(sizeof(Type) == 1);
+};
+
+template
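Note on the BytesToType trait added above: it maps a byte width to a machine type of exactly that size, so a kernel can move Size elements of T in one memory transaction by reinterpreting through BytesToType<sizeof(T) * Size>::Type. A small usage sketch (the copy helper is illustrative, not part of this patch; it assumes sizeof(T) * Size matches one of the specializations, i.e. 1 through 32 bytes):

    // Illustrative use of BytesToType: one wide load/store instead of
    // Size scalar ones. Pointers must be aligned to sizeof(Vec).
    template <typename T, int Size>
    __device__ __forceinline__ void copy_vec(const T* src, T* dst) {
        using Vec = typename BytesToType<sizeof(T) * Size>::Type;
        *reinterpret_cast<Vec*>(dst) = *reinterpret_cast<const Vec*>(src);
    }
    // e.g. copy_vec<__nv_bfloat16, 8>(src, dst) issues a single 16-byte
    // (uint4) transaction.
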