[Optimize] Optimize tensorwise fp8 performance (#2729)

* [Optimize] Optimize tensorwise fp8 performance

Author: ming1753
Date: 2025-07-07 20:06:28 +08:00
Committed by: GitHub
Parent: 1b54a2831e
Commit: ef6649a577
6 changed files with 318 additions and 88 deletions


@@ -468,6 +468,28 @@ std::vector<paddle::Tensor> NoauxTc(
    int topk,
    float routed_scaling_factor);
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
const paddle::Tensor& x,
const paddle::Tensor& y,
const paddle::optional<paddle::Tensor>& bias,
bool trans_x,
bool trans_y,
float scale, // only support per-tensor quantization
std::string output_dtype,
std::string activation_type);
paddle::Tensor MoeFusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const paddle::Tensor &scale,
const paddle::Tensor &topk_ids,
const int top_k,
const int intermediate_size,
const bool tiled);
paddle::Tensor FusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const float scale);
PYBIND11_MODULE(fastdeploy_ops, m) {
  m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -697,38 +719,21 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
"text_image_gather_scatter function"); "text_image_gather_scatter function");
m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func); m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func);
m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel); m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);
m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi, m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
py::arg("a"), py::arg("a"), py::arg("c_or_none"), py::arg("b_q_weight"),
py::arg("c_or_none"), py::arg("b_scales"), py::arg("global_scale_or_none"), py::arg("b_zeros_or_none"),
py::arg("b_q_weight"), py::arg("g_idx_or_none"), py::arg("perm_or_none"), py::arg("workspace"), py::arg("sorted_token_ids"),
py::arg("b_scales"), py::arg("expert_ids"), py::arg("num_tokens_post_padded"), py::arg("topk_weights"), py::arg("moe_block_size"),
py::arg("global_scale_or_none"), py::arg("top_k"), py::arg("mul_topk_weights"), py::arg("is_ep"), py::arg("b_q_type_str"),
py::arg("b_zeros_or_none"), py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"), py::arg("use_atomic_add"),
py::arg("g_idx_or_none"), py::arg("use_fp32_reduce"), py::arg("is_zp_float"));
py::arg("perm_or_none"),
py::arg("workspace"),
py::arg("sorted_token_ids"),
py::arg("expert_ids"),
py::arg("num_tokens_post_padded"),
py::arg("topk_weights"),
py::arg("moe_block_size"),
py::arg("top_k"),
py::arg("mul_topk_weights"),
py::arg("is_ep"),
py::arg("b_q_type_str"),
py::arg("size_m"),
py::arg("size_n"),
py::arg("size_k"),
py::arg("is_k_full"),
py::arg("use_atomic_add"),
py::arg("use_fp32_reduce"),
py::arg("is_zp_float"));
m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch, m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
"get_position_ids_and_mask_encoder_batch function"); "get_position_ids_and_mask_encoder_batch function");
/** /**
* cutlass_scaled_mm.cu * cutlass_scaled_mm.cu
* cutlass_scaled_mm * cutlass_scaled_mm
@@ -753,6 +758,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant, m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
"dynamic_per_token_scaled_fp8_quant function", "dynamic_per_token_scaled_fp8_quant function",
py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub")); py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function"); m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function");
m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function"); m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function");
@@ -762,4 +768,16 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function"); m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function");
m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function");
m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func,
py::arg("input"), py::arg("scale"), py::arg("topk_ids"),
py::arg("top_k"), py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function");
m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
}
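For orientation, a minimal Python sketch of calling the newly bound per-tensor Hadamard quant op once the extension is built. The tensor size and the scale value are illustrative assumptions, not taken from this diff; `cutlass_fp8_fp8_half_gemm_fused` and `moe_fused_hadamard_quant_fp8` are likewise exposed with the keyword names registered above.

# Illustrative only: the shape and the 0.5 scale are made-up values.
import paddle
from fastdeploy.model_executor.ops.gpu import fused_hadamard_quant_fp8

x = paddle.randn([16, 4096]).astype("bfloat16")   # row length a multiple of 32
fp8_x = fused_hadamard_quant_fp8(x, scale=0.5)    # Hadamard rotate + per-tensor FP8 quant
assert fp8_x.dtype == paddle.float8_e4m3fn and fp8_x.shape == x.shape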


@@ -19,7 +19,7 @@
#include "fp8_fp8_half_cuda_core_gemm.h" #include "fp8_fp8_half_cuda_core_gemm.h"
std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm( paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
const paddle::Tensor& x, const paddle::Tensor& x,
const paddle::Tensor& y, const paddle::Tensor& y,
const paddle::optional<paddle::Tensor>& bias, const paddle::optional<paddle::Tensor>& bias,
@@ -174,7 +174,21 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
        fuse_gemm_config};
    fp8_fp8_gemm_scale_bias_act(params);
  }
-  return {out};
+  return out;
}
std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
const paddle::Tensor& x,
const paddle::Tensor& y,
const paddle::optional<paddle::Tensor>& bias,
bool trans_x,
bool trans_y,
float scale, // only support per-tensor quantization
std::string output_dtype,
std::string activation_type) {
return {cutlass_fp8_fp8_half_gemm_func(
x, y, bias, trans_x, trans_y, scale,
output_dtype, activation_type)};
}
std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfGemmFusedInferShape(


@@ -0,0 +1,198 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
#include "helper.h"
__device__ __forceinline__ void hadamard32_warp(__nv_bfloat16& x) {
int lane_id = threadIdx.x % 32;
#pragma unroll
for (int step = 0; step < 5; ++step) {
const int lane_mask = 1 << step;
const __nv_bfloat16 sign = (lane_id & lane_mask) ? -1.f : 1.f;
__nv_bfloat16 x_val_other = __shfl_xor_sync(0xffffffff, x, lane_mask);
x = sign * x + x_val_other;
}
}
__global__ void MoeFusedHadamardQuantFp8Kernel(
const __nv_bfloat16* __restrict__ input,
const float* __restrict__ scale,
const int64_t* __restrict__ topk_ids,
__nv_fp8_e4m3* out,
const int top_k,
const int intermediate_size,
const int64_t numel
) {
int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (out_idx >= numel) return;
int64_t token_idx = out_idx / (top_k * intermediate_size);
int64_t topk_idx = (out_idx / intermediate_size) % top_k;
int64_t inter_idx = out_idx % intermediate_size;
int64_t input_idx = token_idx * intermediate_size + inter_idx;
if (input_idx >= numel / top_k) return;
int64_t expert_id = topk_ids[token_idx * top_k + topk_idx];
float scale_value = scale[expert_id];
__nv_bfloat16 x = input[input_idx];
hadamard32_warp(x);
float x_fp32 = __bfloat162float(x);
float quantized = x_fp32 / scale_value;
out[out_idx] = static_cast<__nv_fp8_e4m3>(quantized);
}
__global__ void MoeFusedHadamardQuantFp8TiledKernel(
const __nv_bfloat16* __restrict__ input,
const float* __restrict__ scale,
const int64_t* __restrict__ topk_ids,
__nv_fp8_e4m3* out,
const int top_k,
const int intermediate_size,
const int64_t numel
) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= numel) return;
int64_t token_idx = idx / intermediate_size;
int64_t expert_id = topk_ids[token_idx];
float scale_value = scale[expert_id];
__nv_bfloat16 x = input[idx];
hadamard32_warp(x);
float x_fp32 = __bfloat162float(x);
float quantized = x_fp32 / scale_value;
out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
}
std::vector<paddle::Tensor> MoeFusedHadamardQuantFp8(
const paddle::Tensor &input,
const paddle::Tensor &scale,
const paddle::Tensor &topk_ids,
const int top_k,
const int intermediate_size,
const bool tiled) {
int64_t numel = input.numel();
if (!tiled) numel *= top_k;
paddle::Tensor out = GetEmptyTensor(
{numel / intermediate_size, intermediate_size},
paddle::DataType::FLOAT8_E4M3FN,
input.place());
constexpr int64_t thread_per_block = 256;
int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
auto stream = input.stream();
if (tiled) {
MoeFusedHadamardQuantFp8TiledKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
scale.data<float>(),
topk_ids.data<int64_t>(),
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
top_k,
intermediate_size,
numel
);
} else {
MoeFusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
reinterpret_cast<const __nv_bfloat16*>(input.data<phi::dtype::bfloat16>()),
scale.data<float>(),
topk_ids.data<int64_t>(),
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
top_k,
intermediate_size,
numel
);
}
return {out};
}
PD_BUILD_STATIC_OP(moe_fused_hadamard_quant_fp8)
.Inputs({"input", "scale", "topk_ids"})
.Outputs({"output"})
.Attrs({"top_k: int",
"intermediate_size: int",
"tiled: bool"})
.SetKernelFn(PD_KERNEL(MoeFusedHadamardQuantFp8));
paddle::Tensor MoeFusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const paddle::Tensor &scale,
const paddle::Tensor &topk_ids,
const int top_k,
const int intermediate_size,
const bool tiled) {
return MoeFusedHadamardQuantFp8(input, scale, topk_ids, top_k, intermediate_size, tiled)[0];
}
__global__ void FusedHadamardQuantFp8Kernel(
const __nv_bfloat16* __restrict__ input,
__nv_fp8_e4m3* out,
const float scale,
const int64_t numel) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= numel) return;
__nv_bfloat16 x = input[idx];
hadamard32_warp(x);
float x_fp32 = __bfloat162float(x);
float quantized = x_fp32 / scale;
out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
}
std::vector<paddle::Tensor> FusedHadamardQuantFp8(
const paddle::Tensor &input,
const float scale) {
int64_t numel = input.numel();
paddle::Tensor out = GetEmptyTensor(
input.dims(),
paddle::DataType::FLOAT8_E4M3FN,
input.place());
constexpr int64_t thread_per_block = 256;
int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
auto stream = input.stream();
FusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
scale,
numel
);
return {out};
}
PD_BUILD_STATIC_OP(fused_hadamard_quant_fp8)
.Inputs({"input"})
.Outputs({"output"})
.Attrs({"scale: float"})
.SetKernelFn(PD_KERNEL(FusedHadamardQuantFp8));
paddle::Tensor FusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const float scale) {
return FusedHadamardQuantFp8(input, scale)[0];
}
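The `hadamard32_warp` butterfly above is the unnormalized 32-point Walsh-Hadamard transform: each of the five rounds exchanges values between lanes whose ids differ in one bit via `__shfl_xor_sync`, so every warp rotates one contiguous 32-element chunk of its row (which assumes the row length is a multiple of 32). Below is a minimal NumPy reference for the same recursion, useful as a host-side check; the function name is an assumption, not part of this commit.

import numpy as np

def hadamard32_reference(x):
    # Mirrors hadamard32_warp: 5 butterfly rounds over 32 "lanes".
    x = np.asarray(x, dtype=np.float32).copy()
    lanes = np.arange(32)
    for step in range(5):
        mask = 1 << step
        partner = x[lanes ^ mask]                 # value held by the shfl_xor partner lane
        sign = np.where(lanes & mask, -1.0, 1.0)  # -1 when the lane's bit is set
        x = sign * x + partner                    # same update as the CUDA kernel
    return x  # equals the order-32 (Sylvester) Hadamard matrix applied to x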


@@ -442,6 +442,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu", "gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu",
"gpu_ops/cutlass_kernels/cutlass_heuristic.cu", "gpu_ops/cutlass_kernels/cutlass_heuristic.cu",
"gpu_ops/cutlass_kernels/cutlass_preprocessors.cu", "gpu_ops/cutlass_kernels/cutlass_preprocessors.cu",
"gpu_ops/fused_hadamard_quant_fp8.cu"
] ]
sources += find_end_files(fp8_auto_gen_directory, ".cu") sources += find_end_files(fp8_auto_gen_directory, ".cu")


@@ -1,5 +1,5 @@
""" """
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -20,8 +20,7 @@ from paddle import nn
import fastdeploy
from fastdeploy.distributed.communication_op import \
    tensor_model_parallel_all_reduce
-from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
-                                                    get_tensor)
+from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.utils import ceil_div
from ..quantization.quant_base import QuantMethodBase
@@ -272,8 +271,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
            layer.moe_intermediate_size, layer.hidden_size
        ]
-        ffn1_tensor = paddle.stack(ffn1_tensor, axis=0)
-        ffn2_tensor = paddle.stack(ffn2_tensor, axis=0)
+        ffn1_tensor = paddle.stack(ffn1_tensor, axis=0).view(paddle.float8_e4m3fn)
+        ffn2_tensor = paddle.stack(ffn2_tensor, axis=0).view(paddle.float8_e4m3fn)
        added_wfp8afp8_attrs = [
            "moe_ffn1_weight", "moe_ffn2_weight", "moe_ffn1_weight_scale",
@@ -309,6 +308,9 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
                    dtype=weight_tensor.dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ))
-            getattr(layer, name).set_value(weight_tensor)
+            if weight_tensor.dtype == paddle.float8_e4m3fn:
+                getattr(layer, name).copy_(weight_tensor, False)
+            else:
+                getattr(layer, name).set_value(weight_tensor)

    def create_weights(self, layer: nn.Layer, state_dict):
@@ -333,13 +335,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
        moe_intermediate_size = layer.moe_intermediate_size
        hidden_size = layer.hidden_size
-        scores = paddle.nn.functional.softmax(gate_out, axis=-1)
-        topk_weights, topk_ids = paddle.topk(scores,
-                                             k=top_k,
-                                             axis=-1,
-                                             sorted=False)
-        topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
+        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+            gate_out,
+            layer.gate_correction_bias,
+            top_k,
+            True,  # apply_norm_weight,
+            False,
+        )
        intermediate_cache1 = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
@@ -354,34 +356,31 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
            dtype=x.dtype,
        )
-        config = {
+        config_ffn1 = {
            "BLOCK_SIZE_M": 32,
            "BLOCK_SIZE_N": 128,
-            "BLOCK_SIZE_K": 128,
+            "BLOCK_SIZE_K": 256,
            "GROUP_SIZE_M": 1,
        }
        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
-            topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
+            topk_ids, num_local_experts, config_ffn1["BLOCK_SIZE_M"])
        max_possible_num_post_padded = sorted_token_ids.shape[0]
        grid = (
-            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
-            ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
+            ceil_div(max_possible_num_post_padded, config_ffn1["BLOCK_SIZE_M"]) *
+            ceil_div(moe_intermediate_size * 2, config_ffn1["BLOCK_SIZE_N"]), )
-        adamard_matrix = create_hadamard_matrix_map[hidden_size]
-        x = paddle.matmul(x.cast("float32"), adamard_matrix)
-        permute_x = x[:, None, :].tile([1, top_k, 1])
-        permute_x = permute_x.reshape([-1, hidden_size])
-        quant_activation_scale = layer.moe_ffn1_in_scale[topk_ids].reshape(
-            [-1, 1])
-        permute_x = permute_x / quant_activation_scale
-        permute_x = permute_x.astype("float8_e4m3fn")
+        permute_x = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
+            x,
+            scale=layer.moe_ffn1_in_scale,
+            topk_ids=topk_ids,
+            top_k=top_k,
+            intermediate_size=hidden_size,
+            tiled=False)
        fused_moe_kernel_paddle[grid](
            permute_x,
-            layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
+            layer.moe_ffn1_weight,
            intermediate_cache1,
            layer.moe_ffn1_in_scale,
            layer.moe_ffn1_weight_scale,
@@ -409,36 +408,43 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
            group_n=-1,
            group_k=-1,
            # Meta-parameters
-            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
-            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
-            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
-            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            BLOCK_SIZE_M=config_ffn1["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config_ffn1["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config_ffn1["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config_ffn1["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=False,
            top_k=1,
            compute_type_enum=1,
            use_fp8_w8a8=True,
            use_int8_w8a16=False,
-            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
+            even_Ks=hidden_size % config_ffn1["BLOCK_SIZE_K"] == 0,
        )
        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
            intermediate_cache1)
-        hadamard_matrix = create_hadamard_matrix_map[moe_intermediate_size]
-        intermediate_cache2 = paddle.matmul(
-            intermediate_cache2.cast("float32"), hadamard_matrix)
-        quant_activation_scale = layer.moe_ffn2_in_scale[topk_ids].reshape(
-            [-1, 1])
-        intermediate_cache2 = intermediate_cache2 / quant_activation_scale
-        intermediate_cache2 = intermediate_cache2.astype("float8_e4m3fn")
+        intermediate_cache2 = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
+            intermediate_cache2,
+            scale=layer.moe_ffn2_in_scale,
+            topk_ids=topk_ids,
+            top_k=top_k,
+            intermediate_size=moe_intermediate_size,
+            tiled=True)
+        config_ffn2 = {
+            "BLOCK_SIZE_M": 32,
+            "BLOCK_SIZE_N": 128,
+            "BLOCK_SIZE_K": 64,
+            "GROUP_SIZE_M": 1,
+        }
        grid = (
-            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
-            ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
+            ceil_div(max_possible_num_post_padded, config_ffn2["BLOCK_SIZE_M"]) *
+            ceil_div(hidden_size, config_ffn2["BLOCK_SIZE_N"]), )
        fused_moe_kernel_paddle[grid](
            intermediate_cache2,
-            layer.moe_ffn2_weight.view(paddle.float8_e4m3fn),
+            layer.moe_ffn2_weight,
            intermediate_cache3,
            layer.moe_ffn2_in_scale,
            layer.moe_ffn2_weight_scale,
@@ -465,16 +471,16 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
            group_n=-1,
            group_k=-1,
            # Meta-parameters
-            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
-            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
-            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
-            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            BLOCK_SIZE_M=config_ffn2["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config_ffn2["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config_ffn2["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config_ffn2["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=True,
            top_k=1,
            compute_type_enum=1,
            use_fp8_w8a8=True,
            use_int8_w8a16=False,
-            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
+            even_Ks=moe_intermediate_size % config_ffn2["BLOCK_SIZE_K"] == 0,
        )
        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
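The two `moe_fused_hadamard_quant_fp8` calls above replace the old Python-level path (Hadamard matmul, `tile`/`reshape` expansion, per-expert scale division, FP8 cast) with a single kernel launch each. A rough shape sketch of the two modes follows; the dimensions are illustrative assumptions, not taken from this diff.

# Shape sketch for the two fused Hadamard + FP8 quant modes (illustrative sizes).
import paddle
from fastdeploy.model_executor.ops.gpu import moe_fused_hadamard_quant_fp8

token_num, hidden_size, inter_size, top_k, n_experts = 8, 4096, 1024, 2, 64
x = paddle.randn([token_num, hidden_size]).astype("bfloat16")
in_scale = paddle.rand([n_experts])  # one float scale per expert
topk_ids = paddle.randint(0, n_experts, [token_num, top_k], dtype="int64")

# tiled=False: expands to [token_num * top_k, hidden_size], rotating and
# quantizing each copy with the scale of the expert it is routed to (ffn1 input).
permute_x = moe_fused_hadamard_quant_fp8(
    x, scale=in_scale, topk_ids=topk_ids, top_k=top_k,
    intermediate_size=hidden_size, tiled=False)

# tiled=True: rows are already one-per-(token, expert); the shape is kept (ffn2 input).
act = paddle.randn([token_num * top_k, inter_size]).astype("bfloat16")
cache2 = moe_fused_hadamard_quant_fp8(
    act, scale=in_scale, topk_ids=topk_ids, top_k=top_k,
    intermediate_size=inter_size, tiled=True)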


@@ -15,8 +15,6 @@
""" """
from typing import Optional from typing import Optional
import paddle
from fastdeploy.model_executor.layers.moe import FusedMoE from fastdeploy.model_executor.layers.moe import FusedMoE
from ..utils import get_tensor from ..utils import get_tensor
@@ -113,15 +111,10 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
""" """
compute! compute!
""" """
from fastdeploy.model_executor.ops.gpu import \ from fastdeploy.model_executor.ops.gpu import (
cutlass_fp8_fp8_half_gemm_fused cutlass_fp8_fp8_half_gemm_fused, fused_hadamard_quant_fp8)
from ..utils import create_hadamard_matrix_map fp8_x = fused_hadamard_quant_fp8(x, scale=self.act_scale)
hadamard_matrix = create_hadamard_matrix_map[x.shape[-1]]
new_x = paddle.matmul(x.cast("float32"), hadamard_matrix)
fp8_x = new_x / self.act_scale
fp8_x = fp8_x.astype("float8_e4m3fn")
linear_out = cutlass_fp8_fp8_half_gemm_fused( linear_out = cutlass_fp8_fp8_half_gemm_fused(
fp8_x, fp8_x,