[Optimize] Optimize tensorwise fp8 performance (#2729)

* [Optimize] Optimize tensorwise fp8 performance
Author: ming1753
Date: 2025-07-07 20:06:28 +08:00 (committed by GitHub)
Parent: 1b54a2831e
Commit: ef6649a577
6 changed files with 318 additions and 88 deletions


@@ -468,6 +468,28 @@ std::vector<paddle::Tensor> NoauxTc(
int topk,
float routed_scaling_factor);
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
const paddle::Tensor& x,
const paddle::Tensor& y,
const paddle::optional<paddle::Tensor>& bias,
bool trans_x,
bool trans_y,
float scale,  // only supports per-tensor quantization
std::string output_dtype,
std::string activation_type);
paddle::Tensor MoeFusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const paddle::Tensor &scale,
const paddle::Tensor &topk_ids,
const int top_k,
const int intermediate_size,
const bool tiled);
paddle::Tensor FusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const float scale);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
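A minimal usage sketch of the per-tensor FP8 GEMM declared above, as it would be called from Python once the module is built. The import path, shapes, fp8 cast, and values are assumptions; only the argument names are taken from the pybind registration later in this diff.

```python
import paddle
from fastdeploy_ops import cutlass_fp8_fp8_half_gemm_fused  # assumed import path

# fp8 inputs (assumption: paddle's float8_e4m3fn cast is available)
x = paddle.randn([16, 256]).astype("float8_e4m3fn")   # activations [m, k]
w = paddle.randn([512, 256]).astype("float8_e4m3fn")  # weights [n, k]

out = cutlass_fp8_fp8_half_gemm_fused(
    x, w,
    bias=None,               # optional<Tensor> in the C++ signature
    transpose_x=False,
    transpose_y=True,        # assumed [n, k] weight layout
    scale=0.02,              # single tensorwise dequantization scale
    output_dtype="bfloat16",
    activation_type="",      # assumption: empty string means no fused activation
)
```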
@@ -697,38 +719,21 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
"text_image_gather_scatter function");
m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func);
m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);
m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
py::arg("a"),
py::arg("c_or_none"),
py::arg("b_q_weight"),
py::arg("b_scales"),
py::arg("global_scale_or_none"),
py::arg("b_zeros_or_none"),
py::arg("g_idx_or_none"),
py::arg("perm_or_none"),
py::arg("workspace"),
py::arg("sorted_token_ids"),
py::arg("expert_ids"),
py::arg("num_tokens_post_padded"),
py::arg("topk_weights"),
py::arg("moe_block_size"),
py::arg("top_k"),
py::arg("mul_topk_weights"),
py::arg("is_ep"),
py::arg("b_q_type_str"),
py::arg("size_m"),
py::arg("size_n"),
py::arg("size_k"),
py::arg("is_k_full"),
py::arg("use_atomic_add"),
py::arg("use_fp32_reduce"),
py::arg("is_zp_float"));
py::arg("a"), py::arg("c_or_none"), py::arg("b_q_weight"),
py::arg("b_scales"), py::arg("global_scale_or_none"), py::arg("b_zeros_or_none"),
py::arg("g_idx_or_none"), py::arg("perm_or_none"), py::arg("workspace"), py::arg("sorted_token_ids"),
py::arg("expert_ids"), py::arg("num_tokens_post_padded"), py::arg("topk_weights"), py::arg("moe_block_size"),
py::arg("top_k"), py::arg("mul_topk_weights"), py::arg("is_ep"), py::arg("b_q_type_str"),
py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"), py::arg("use_atomic_add"),
py::arg("use_fp32_reduce"), py::arg("is_zp_float"));
m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
"get_position_ids_and_mask_encoder_batch function");
/**
* cutlass_scaled_mm.cu
* cutlass_scaled_mm
@@ -753,6 +758,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
"dynamic_per_token_scaled_fp8_quant function",
py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function");
m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function");
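For contrast with the per-token binding above: the new GEMM path takes one tensorwise scale shared by the whole activation tensor. A minimal sketch of the usual derivation, assuming the e4m3 format (largest finite value 448); the helper name is hypothetical:

```python
import paddle

def per_tensor_fp8_scale(x: paddle.Tensor) -> float:
    """One dequantization scale for the whole tensor (e4m3 assumed)."""
    amax = float(paddle.max(paddle.abs(x)))  # global absolute maximum
    return amax / 448.0                      # 448 = max finite value of fp8 e4m3
```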
@@ -762,4 +768,16 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function");
m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function");
m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func,
py::arg("input"), py::arg("scale"), py::arg("topk_ids"),
py::arg("top_k"), py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function");
m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
}
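A hypothetical call sketch for the two new Hadamard-quant bindings registered above. Argument names and types follow the declarations earlier in this diff (the MoE variant takes `scale` as a tensor, the dense variant as a float); shapes, expert counts, and values are illustrative assumptions.

```python
import paddle
from fastdeploy_ops import (  # assumed import path
    fused_hadamard_quant_fp8,
    moe_fused_hadamard_quant_fp8,
)

hidden = paddle.randn([16, 4096]).astype("bfloat16")

# Dense path: Hadamard transform + per-tensor FP8 quantization in one kernel.
fp8_hidden = fused_hadamard_quant_fp8(hidden, scale=0.02)

# MoE path: same fused quantization, gathered per selected expert.
topk_ids = paddle.randint(0, 8, shape=[16, 2])  # 8 experts, top_k = 2 (assumed)
fp8_moe_in = moe_fused_hadamard_quant_fp8(
    hidden,
    scale=paddle.to_tensor(0.02),  # Tensor, per the C++ declaration
    topk_ids=topk_ids,
    top_k=2,
    intermediate_size=4096,
    tiled=False,
)
```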