diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index f62052c79..182a59148 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -151,7 +151,7 @@ std::vector MoeExpertDispatch( const paddle::Tensor &input, const paddle::Tensor &gating_output, const paddle::optional &gating_correction_bias, const paddle::optional &w4a8_in_scale, const int moe_topk, - const bool group_moe, const bool topk_only_mode); + const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode); std::vector MoETopKSelectKernel(const paddle::Tensor &gating_logits, @@ -912,7 +912,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("moe_expert_dispatch", &MoeExpertDispatch, py::arg("input"), py::arg("gating_output"), py::arg("gating_correction_bias"), py::arg("w4a8_in_scale"), py::arg("moe_topk"), py::arg("group_moe"), - py::arg("topk_only_mode"), "moe export dispatch function"); + py::arg("moe_quant_type"), py::arg("topk_only_mode"), "moe export dispatch function"); /** * moe/fused_moe/ep_moe_prefill_func.cu diff --git a/custom_ops/gpu_ops/moe/fused_moe_op.h b/custom_ops/gpu_ops/moe/fused_moe_op.h index 3764509ff..2abaae5dd 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_op.h +++ b/custom_ops/gpu_ops/moe/fused_moe_op.h @@ -1296,6 +1296,18 @@ __global__ void initialize_moe_routing_kernel( dest_vec[j] = static_cast(round(quant_value)); } Store(dest_vec, &dest_row_ptr[tid]); + } else if constexpr (std::is_same::value) { + using StoreT = AlignedVector; + StoreT dest_vec; + const float max_bound = 448.f; + const float min_bound = -448.f; + for (int j = 0; j < VecSize; j++) { + float quant_value = max_bound * scale * static_cast(src_vec[j]); + quant_value = quant_value > max_bound ? max_bound : quant_value; + quant_value = quant_value < min_bound ? min_bound : quant_value; + dest_vec[j] = static_cast(quant_value); + } + Store(dest_vec, &dest_row_ptr[tid]); } else { Store(src_vec, &dest_row_ptr[tid]); } diff --git a/custom_ops/gpu_ops/moe/moe_dispatch.cu b/custom_ops/gpu_ops/moe/moe_dispatch.cu index 7ae20e0ae..85bad95cd 100644 --- a/custom_ops/gpu_ops/moe/moe_dispatch.cu +++ b/custom_ops/gpu_ops/moe/moe_dispatch.cu @@ -113,11 +113,20 @@ void MoeDispatchKernel( permuted_rows_, moe_topk * num_rows, false, stream); if (w4a8_in_scale) { - initialize_moe_routing_kernelLauncher::run( + if (permute_input->dtype() == paddle::DataType::INT8) { + initialize_moe_routing_kernelLauncher::run( input.data(), permute_input->data(), permuted_rows_, expert_idx_per_token->data(), w4a8_in_scale->data(), permute_indices_per_token->data(), num_rows, num_rows, hidden_size, moe_topk, stream); + } else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) { + initialize_moe_routing_kernelLauncher::run( + input.data(), permute_input->data(), + permuted_rows_, expert_idx_per_token->data(), + w4a8_in_scale->data(), + permute_indices_per_token->data(), num_rows, num_rows, + hidden_size, moe_topk, stream); + } } else { initialize_moe_routing_kernelLauncher::run( input.data(), permute_input->data(), permuted_rows_, @@ -135,7 +144,7 @@ std::vector MoeExpertDispatch( const paddle::Tensor &input, const paddle::Tensor &gating_output, const paddle::optional &gating_correction_bias, const paddle::optional &w4a8_in_scale, const int moe_topk, - const bool group_moe, const bool topk_only_mode) { + const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode) { const auto input_type = input.dtype(); auto place = input.place(); int token_rows = 0; @@ -151,8 +160,14 @@ std::vector MoeExpertDispatch( const int num_rows = token_rows; const int hidden_size = input.dims()[input_dims.size() - 1]; - auto permute_input_dtype = - w4a8_in_scale ? paddle::DataType::INT8 : input_type; + auto permute_input_dtype = input_type; + if (w4a8_in_scale) { + if (moe_quant_type == "w4a8") { + permute_input_dtype = paddle::DataType::INT8; + } else if (moe_quant_type == "w4afp8") { + permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN; + } + } auto permute_input = GetEmptyTensor({moe_topk * num_rows, hidden_size}, permute_input_dtype, place); @@ -285,7 +300,7 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch) .Outputs({"permute_input", "tokens_expert_prefix_sum", "permute_indices_per_token", "topk_weight", "topk_idx", "expert_idx_per_token"}) - .Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"}) + .Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"}) .SetKernelFn(PD_KERNEL(MoeExpertDispatch)) .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype)); diff --git a/custom_ops/gpu_ops/moe/moe_ffn.cu b/custom_ops/gpu_ops/moe/moe_ffn.cu index 6a748b4a7..117f1c63e 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn.cu @@ -204,7 +204,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, ->data(), reinterpret_cast(fc1_out), used_in_ep_low_latency ? num_max_tokens_per_expert : 0, - num_max_tokens_per_expert, + used_in_ep_low_latency ? num_max_tokens_per_expert : permute_input.dims()[0], num_experts, inter_size, hidden_size, @@ -369,7 +369,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, ->data(), reinterpret_cast(ffn_out_data), used_in_ep_low_latency ? num_max_tokens_per_expert : 0, - num_max_tokens_per_expert, + used_in_ep_low_latency ? num_max_tokens_per_expert : act_out_tensor.dims()[0], num_experts, hidden_size, inter_size / 2, diff --git a/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h b/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h index 97cbf64d7..cb46397d5 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h @@ -225,20 +225,9 @@ struct CollectiveMainloopFwd { const int actual_token, const int bidn) const { - auto g_offset = local_tile( - mB(_, _, 0), - cute::make_shape(1, size<1>(mB)), - make_coord(pre_fix_token, _0{})); - - auto g_tensor = make_tensor( - g_offset.data(), - make_layout( - cute::make_shape(actual_token, size<2>(mB)), - g_offset.stride() - )); + auto g_tensor = domain_offset(make_coord(pre_fix_token, _0{}), mB(_, _, 0)); Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _)); - return gB; } diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp index 55bf92757..01a8dd114 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp @@ -222,7 +222,7 @@ void run_gemm(const InputType * A, const InputType * B, OutputType * C, const fl static_cast(A), get_gmem_layout(M, K / 2), static_cast(B), - get_gmem_layout(TokenPackSize == 0 ? max_tokens * Batch : TokenPackSize, K), + get_gmem_layout(TokenPackSize == 0 ? max_tokens: TokenPackSize, K), static_cast(C), get_gmem_layout(M, TokenPackSize == 0 ? max_tokens : TokenPackSize), weight_scale, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index d111a0426..fa08b4a78 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -276,6 +276,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): ), # if set, permute_input will be int8_t layer.top_k, False, + self.moe_quant_type, topk_only_mode=True, ) else: @@ -295,6 +296,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): ), # if set, permute_input will be int8_t layer.top_k, False, + self.moe_quant_type, topk_only_mode=False, ) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py index f592afc68..f9f717d31 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py @@ -284,6 +284,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod): ), # if set, permute_input will be int8_t layer.top_k, False, + self.moe_quant_type, topk_only_mode=False, )