diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 671427005..ca4223ffe 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -192,7 +192,8 @@ paddle::Tensor MoeExpertFFNFunc(
     const paddle::optional<paddle::Tensor>& down_proj_scale,
     const paddle::optional<paddle::Tensor>& down_proj_in_scale,
     const paddle::optional<paddle::Tensor>& expert_idx_per_token,
-    const std::string& quant_method, const bool used_in_ep_low_latency);
+    const std::string& quant_method, const bool used_in_ep_low_latency,
+    const int estimate_total_token_nums);
 
 paddle::Tensor MoeExpertFFNWint2Func(
     const paddle::Tensor& permute_input,
diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h
index 468aff1fc..8256d43cd 100644
--- a/custom_ops/gpu_ops/helper.h
+++ b/custom_ops/gpu_ops/helper.h
@@ -193,6 +193,12 @@ public:
   typedef uint8_t data_t;
 };
 
+template <> class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
+public:
+  typedef __nv_fp8_e4m3 DataType;
+  typedef paddle::float8_e4m3fn data_t;
+};
+
 template <typename T, int Size> struct alignas(sizeof(T) * Size) AlignedVector {
   T val[Size];
diff --git a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu
index d677b360c..93ba97ef9 100644
--- a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu
+++ b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu
@@ -314,7 +314,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
 }
 
-template
+template
 __global__ void permute_x_kernel(const T *src_x,
                                  const int64_t *topk_idx,
                                  const float *topk_weights,
@@ -330,9 +330,9 @@ __global__ void permute_x_kernel(const T *src_x,
                                  int *dst_indices,
                                  int *cumsum_idx_gpu,
                                  int64_t *token_nums_per_expert_cumsum,
-                                 int64_t *expert_idx_per_token,
+                                 int64_t *expert_idx_per_token, // [num_rows, moe_topk]
                                  float max_bound = 127.0,
-                                 float min_bound = -127.0) { // [num_rows, moe_topk]
+                                 float min_bound = -127.0) {
   const int src_token_idx = blockIdx.x;
   const int tid = threadIdx.x;
   constexpr int vec_size = sizeof(int4) / sizeof(T);
@@ -375,10 +375,17 @@ __global__ void permute_x_kernel(const T *src_x,
       if (up_gate_proj_in_scale) {
         for (int i = 0; i < vec_size; i++) {
           float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast<float>(src_vec[i]);
-          if (RoundType == 0) {
-            res_vec[i] = static_cast(ClipFunc(rint(quant_value), min_bound, max_bound));
+          if constexpr (std::is_same::value) {
+            // w4aint8
+            if (RoundType == 0) {
+              res_vec[i] = static_cast(ClipFunc(rint(quant_value), min_bound, max_bound));
+            } else {
+              res_vec[i] = static_cast(ClipFunc(round(quant_value), min_bound, max_bound));
+            }
           } else {
-            res_vec[i] = static_cast(round(quant_value));
+            // w4afp8
+            float value = ClipFunc(quant_value, min_bound, max_bound);
+            res_vec[i] = static_cast(value);
           }
         }
       } else {
@@ -418,6 +425,10 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;
 
+  typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
+  typedef typename traits_fp8::DataType DataType_fp8;
+  typedef typename traits_fp8::data_t data_t_fp8;
+
   auto stream = input.stream();
   auto place = input.place();
   const int gridx = min(132 * 8, num_rows);
@@ -465,6 +476,50 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
           -127.0
       );
     }
+  } else if (moe_quant_type == "w4afp8") {
+    if (num_experts_per_rank == 8) {
+      permute_x_kernel<<>>(
+          input.data(),
+          topk_ids.data(),
+          topk_weights.data(),
+          token_nums_per_expert.data(),
+          up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr,
+          moe_topk,
+          num_rows,
+          token_nums_this_rank,
+          hidden_size,
+          permute_input->data(),
+          permute_indices_per_token->data(),
+          dst_weights->data(),
+          dst_indices->data(),
+          cumsum_idx_gpu->data(),
+          token_nums_per_expert_cumsum->data(),
+          expert_idx_per_token->data(),
+          448.0f,
+          -448.0f
+      );
+    } else if (num_experts_per_rank == 16) {
+      permute_x_kernel<<>>(
+          input.data(),
+          topk_ids.data(),
+          topk_weights.data(),
+          token_nums_per_expert.data(),
+          up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr,
+          moe_topk,
+          num_rows,
+          token_nums_this_rank,
+          hidden_size,
+          permute_input->data(),
+          permute_indices_per_token->data(),
+          dst_weights->data(),
+          dst_indices->data(),
+          cumsum_idx_gpu->data(),
+          token_nums_per_expert_cumsum->data(),
+          expert_idx_per_token->data(),
+          448.0f,
+          -448.0f
+      );
+    }
   } else {
     if (num_experts_per_rank == 8) {
       permute_x_kernel<<>>(
@@ -538,7 +593,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
   auto permute_input = GetEmptyTensor(
       {token_nums_this_rank, hidden_size},
-      moe_quant_type == "w4a8" ? paddle::DataType::INT8 : input_type,
+      moe_quant_type == "w4a8" ? paddle::DataType::INT8 : moe_quant_type == "w4afp8" ? paddle::DataType::FLOAT8_E4M3FN : input_type,
       place);
   auto num_experts_per_rank_tensor = GetEmptyTensor(
       {num_experts_per_rank},
diff --git a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu
index 7bf46f0f4..63b45b743 100644
--- a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu
+++ b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu
@@ -88,7 +88,7 @@ struct nv_type_traits {
     constexpr int kLogN = 7;                                                   \
     __VA_ARGS__                                                                \
   } else {                                                                     \
-    PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupport!", logN)); \
+    PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupported!", logN)); \
   }
 
 #define DISPATCH_SP_VS(vec_size, VEC_SIZE, ...)                                \
@@ -108,7 +108,7 @@ struct nv_type_traits {
     constexpr int VEC_SIZE = 1;                                                \
     __VA_ARGS__                                                                \
   } else {                                                                     \
-    PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupport!", vec_size)); \
+    PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupported!", vec_size)); \
   }
 
 #define DISPATCH_logN(logN, kLogN, ...)                                        \
@@ -605,26 +605,6 @@ void moe_fast_hardamard_kernel(const T *x,
       exchange_smem_pre(x_vals, smem_exchange);
     }
     if constexpr (kNChunks > 1) {
-// T x_vals_transposed[VecSize][kNChunks] = {init_value};
-// #pragma unroll
-// for (int c = 0; c < kNChunks; ++c) {
-// #pragma unroll
-// for (int i = 0; i < VecSize; ++i) { x_vals_transposed[i][c] = x_vals[c][i]; }
-// }
-// if constexpr (kNChunks == 28) {
-// hadamard_mult_thread_chunk_28(x_vals_transposed);
-// } else if constexpr (kNChunks == 36) {
-// hadamard_mult_thread_chunk_36(x_vals_transposed);
-// } else {
-// constexpr int kLogNChunks = cilog2(kNChunks);
-// static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
-// hadamard_mult_thread(x_vals_transposed);
-// }
-// #pragma unroll
-// for (int c = 0; c < kNChunks; ++c) {
-// #pragma unroll
-// for (int i = 0; i < VecSize; ++i) { x_vals[c][i] = x_vals_transposed[i][c]; }
-// }
       if constexpr (kNChunks == 28) {
         hadamard_mult_thread_28_transpose(x_vals);
       } else if constexpr (kNChunks == 36) {
diff --git a/custom_ops/gpu_ops/moe/fused_moe_op.h b/custom_ops/gpu_ops/moe/fused_moe_op.h
index efe5b26bc..3764509ff 100644
--- a/custom_ops/gpu_ops/moe/fused_moe_op.h
+++ b/custom_ops/gpu_ops/moe/fused_moe_op.h
@@ -72,6 +72,287 @@ __host__ __device__ constexpr static U arrayConvert(T const& input)
     return u;
 }
 
+struct uint8 {
+    uint4 u;
+    uint4 v;
+};
+
+template struct BytesToType {};
+
+template<>
+struct BytesToType<32> {
+    using Type = uint8;
+    static_assert(sizeof(Type) == 32);
+};
+
+template<> struct BytesToType<16> {
+    using Type = uint4;
+    static_assert(sizeof(Type) == 16);
+};
+
+template<> struct BytesToType<8> {
+    using Type = uint64_t;
+    static_assert(sizeof(Type) == 8);
+};
+
+template<> struct BytesToType<4> {
+    using Type = uint32_t;
+    static_assert(sizeof(Type) == 4);
+};
+
+template<> struct BytesToType<2> {
+    using Type = uint16_t;
+    static_assert(sizeof(Type) == 2);
+};
+
+template<> struct BytesToType<1> {
+    using Type = uint8_t;
+    static_assert(sizeof(Type) == 1);
+};
+
+template