diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 5eb56c14f..97c212691 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -468,6 +468,28 @@ std::vector<paddle::Tensor> NoauxTc(
     int topk,
     float routed_scaling_factor);
 
+paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
+    const paddle::Tensor& x,
+    const paddle::Tensor& y,
+    const paddle::optional<paddle::Tensor>& bias,
+    bool trans_x,
+    bool trans_y,
+    float scale,  // only support per-tensor quantization
+    std::string output_dtype,
+    std::string activation_type);
+
+paddle::Tensor MoeFusedHadamardQuantFp8Func(
+    const paddle::Tensor &input,
+    const paddle::Tensor &scale,
+    const paddle::Tensor &topk_ids,
+    const int top_k,
+    const int intermediate_size,
+    const bool tiled);
+
+paddle::Tensor FusedHadamardQuantFp8Func(
+    const paddle::Tensor &input,
+    const float scale);
+
 PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("get_expert_token_num", &GetExpertTokenNum,
         py::arg("topk_ids"),
@@ -697,38 +719,21 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         "text_image_gather_scatter function");
 
   m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func);
+  m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);
 
   m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
-        py::arg("a"),
-        py::arg("c_or_none"),
-        py::arg("b_q_weight"),
-        py::arg("b_scales"),
-        py::arg("global_scale_or_none"),
-        py::arg("b_zeros_or_none"),
-        py::arg("g_idx_or_none"),
-        py::arg("perm_or_none"),
-        py::arg("workspace"),
-        py::arg("sorted_token_ids"),
-        py::arg("expert_ids"),
-        py::arg("num_tokens_post_padded"),
-        py::arg("topk_weights"),
-        py::arg("moe_block_size"),
-        py::arg("top_k"),
-        py::arg("mul_topk_weights"),
-        py::arg("is_ep"),
-        py::arg("b_q_type_str"),
-        py::arg("size_m"),
-        py::arg("size_n"),
-        py::arg("size_k"),
-        py::arg("is_k_full"),
-        py::arg("use_atomic_add"),
-        py::arg("use_fp32_reduce"),
-        py::arg("is_zp_float"));
+        py::arg("a"), py::arg("c_or_none"), py::arg("b_q_weight"),
+        py::arg("b_scales"), py::arg("global_scale_or_none"), py::arg("b_zeros_or_none"),
+        py::arg("g_idx_or_none"), py::arg("perm_or_none"), py::arg("workspace"), py::arg("sorted_token_ids"),
+        py::arg("expert_ids"), py::arg("num_tokens_post_padded"), py::arg("topk_weights"), py::arg("moe_block_size"),
+        py::arg("top_k"), py::arg("mul_topk_weights"), py::arg("is_ep"), py::arg("b_q_type_str"),
+        py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"), py::arg("use_atomic_add"),
+        py::arg("use_fp32_reduce"), py::arg("is_zp_float"));
+
   m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
         "get_position_ids_and_mask_encoder_batch function");
-
 /**
  * cutlass_scaled_mm.cu
  * cutlass_scaled_mm
@@ -753,6 +758,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
         "dynamic_per_token_scaled_fp8_quant function",
         py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
+
   m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel,
         "decode_mla_write_cache function");
 
   m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel,
         "prefill_mla_write_cache function");
@@ -762,4 +768,16 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("multi_head_latent_attention", &MultiHeadLatentAttention,
         "multi_head_latent_attention function");
   m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
+
+  m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
+        py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
+        py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
+        py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function");
+
+  m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func,
+        py::arg("input"), py::arg("scale"), py::arg("topk_ids"),
+        py::arg("top_k"), py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function");
+
+  m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
+        py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
 }
diff --git a/custom_ops/gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu b/custom_ops/gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu
index 76d087a07..c62b7effa 100644
--- a/custom_ops/gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu
+++ b/custom_ops/gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu
@@ -19,7 +19,7 @@
 
 #include "fp8_fp8_half_cuda_core_gemm.h"
 
-std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
+paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
     const paddle::Tensor& x,
     const paddle::Tensor& y,
     const paddle::optional<paddle::Tensor>& bias,
@@ -142,7 +142,7 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
     {
         if(output_dtype == "bfloat16") {
             cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(params);
-            
+
         } else {
             cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
         }
@@ -174,7 +174,21 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
                 fuse_gemm_config};
         fp8_fp8_gemm_scale_bias_act(params);
     }
-    return {out};
+    return out;
+}
+
+std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
+    const paddle::Tensor& x,
+    const paddle::Tensor& y,
+    const paddle::optional<paddle::Tensor>& bias,
+    bool trans_x,
+    bool trans_y,
+    float scale,  // only support per-tensor quantization
+    std::string output_dtype,
+    std::string activation_type) {
+    return {cutlass_fp8_fp8_half_gemm_func(
+        x, y, bias, trans_x, trans_y, scale,
+        output_dtype, activation_type)};
 }
 
 std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfGemmFusedInferShape(
diff --git a/custom_ops/gpu_ops/fused_hadamard_quant_fp8.cu b/custom_ops/gpu_ops/fused_hadamard_quant_fp8.cu
new file mode 100644
index 000000000..6ad190102
--- /dev/null
+++ b/custom_ops/gpu_ops/fused_hadamard_quant_fp8.cu
@@ -0,0 +1,198 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <iostream>
+#include <string>
+#include <vector>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_bf16.h>
+#include <cuda_fp8.h>
+#include "helper.h"
+
+__device__ __forceinline__ void hadamard32_warp(__nv_bfloat16& x) {
+    int lane_id = threadIdx.x % 32;
+#pragma unroll
+    for (int step = 0; step < 5; ++step) {
+        const int lane_mask = 1 << step;
+        const __nv_bfloat16 sign = (lane_id & lane_mask) ? -1.f : 1.f;
+        __nv_bfloat16 x_val_other = __shfl_xor_sync(0xffffffff, x, lane_mask);
+        x = sign * x + x_val_other;
+    }
+}
+
+__global__ void MoeFusedHadamardQuantFp8Kernel(
+    const __nv_bfloat16* __restrict__ input,
+    const float* __restrict__ scale,
+    const int64_t* __restrict__ topk_ids,
+    __nv_fp8_e4m3* out,
+    const int top_k,
+    const int intermediate_size,
+    const int64_t numel
+) {
+    int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (out_idx >= numel) return;
+
+    int64_t token_idx = out_idx / (top_k * intermediate_size);
+    int64_t topk_idx = (out_idx / intermediate_size) % top_k;
+    int64_t inter_idx = out_idx % intermediate_size;
+
+    int64_t input_idx = token_idx * intermediate_size + inter_idx;
+    if (input_idx >= numel / top_k) return;
+
+    int64_t expert_id = topk_ids[token_idx * top_k + topk_idx];
+    float scale_value = scale[expert_id];
+
+    __nv_bfloat16 x = input[input_idx];
+    hadamard32_warp(x);
+
+    float x_fp32 = __bfloat162float(x);
+    float quantized = x_fp32 / scale_value;
+    out[out_idx] = static_cast<__nv_fp8_e4m3>(quantized);
+}
+
+__global__ void MoeFusedHadamardQuantFp8TiledKernel(
+    const __nv_bfloat16* __restrict__ input,
+    const float* __restrict__ scale,
+    const int64_t* __restrict__ topk_ids,
+    __nv_fp8_e4m3* out,
+    const int top_k,
+    const int intermediate_size,
+    const int64_t numel
+) {
+    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= numel) return;
+
+    int64_t token_idx = idx / intermediate_size;
+    int64_t expert_id = topk_ids[token_idx];
+    float scale_value = scale[expert_id];
+
+    __nv_bfloat16 x = input[idx];
+    hadamard32_warp(x);
+
+    float x_fp32 = __bfloat162float(x);
+    float quantized = x_fp32 / scale_value;
+    out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
+}
+
+std::vector<paddle::Tensor> MoeFusedHadamardQuantFp8(
+    const paddle::Tensor &input,
+    const paddle::Tensor &scale,
+    const paddle::Tensor &topk_ids,
+    const int top_k,
+    const int intermediate_size,
+    const bool tiled) {
+    int64_t numel = input.numel();
+    if (!tiled) numel *= top_k;
+    paddle::Tensor out = GetEmptyTensor(
+        {numel / intermediate_size, intermediate_size},
+        paddle::DataType::FLOAT8_E4M3FN,
+        input.place());
+    constexpr int64_t thread_per_block = 256;
+    int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
+    auto stream = input.stream();
+    if (tiled) {
+        MoeFusedHadamardQuantFp8TiledKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
+            reinterpret_cast<const __nv_bfloat16*>(input.data<phi::dtype::bfloat16>()),
+            scale.data<float>(),
+            topk_ids.data<int64_t>(),
+            reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
+            top_k,
+            intermediate_size,
+            numel
+        );
+    } else {
+        MoeFusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
+            reinterpret_cast<const __nv_bfloat16*>(input.data<phi::dtype::bfloat16>()),
+            scale.data<float>(),
+            topk_ids.data<int64_t>(),
+            reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
+            top_k,
+            intermediate_size,
+            numel
+        );
+    }
+    return {out};
+}
+
+PD_BUILD_STATIC_OP(moe_fused_hadamard_quant_fp8)
+    .Inputs({"input", "scale", "topk_ids"})
+    .Outputs({"output"})
+    .Attrs({"top_k: int",
+            "intermediate_size: int",
+            "tiled: bool"})
+    .SetKernelFn(PD_KERNEL(MoeFusedHadamardQuantFp8));
+
+
+paddle::Tensor MoeFusedHadamardQuantFp8Func(
+    const paddle::Tensor &input,
+    const paddle::Tensor &scale,
+    const paddle::Tensor &topk_ids,
+    const int top_k,
+    const int intermediate_size,
+    const bool tiled) {
+    return MoeFusedHadamardQuantFp8(input, scale, topk_ids, top_k, intermediate_size, tiled)[0];
+}
+
+
+__global__ void FusedHadamardQuantFp8Kernel(
+    const __nv_bfloat16* __restrict__ input,
+    __nv_fp8_e4m3* out,
+    const float scale,
+    const int64_t numel) {
+    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= numel) return;
+
+    __nv_bfloat16 x = input[idx];
+    hadamard32_warp(x);
+
+    float x_fp32 = __bfloat162float(x);
+    float quantized = x_fp32 / scale;
+    out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
+}
+
+std::vector<paddle::Tensor> FusedHadamardQuantFp8(
+    const paddle::Tensor &input,
+    const float scale) {
+    int64_t numel = input.numel();
+    paddle::Tensor out = GetEmptyTensor(
+        input.dims(),
+        paddle::DataType::FLOAT8_E4M3FN,
+        input.place());
+    constexpr int64_t thread_per_block = 256;
+    int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
+    auto stream = input.stream();
+    FusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
+        reinterpret_cast<const __nv_bfloat16*>(input.data<phi::dtype::bfloat16>()),
+        reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
+        scale,
+        numel
+    );
+    return {out};
+}
+
+PD_BUILD_STATIC_OP(fused_hadamard_quant_fp8)
+    .Inputs({"input"})
+    .Outputs({"output"})
+    .Attrs({"scale: float"})
+    .SetKernelFn(PD_KERNEL(FusedHadamardQuantFp8));
+
+
+paddle::Tensor FusedHadamardQuantFp8Func(
+    const paddle::Tensor &input,
+    const float scale) {
+    return FusedHadamardQuantFp8(input, scale)[0];
+}
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index eca3349ac..75a9f4621 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -442,6 +442,7 @@ elif paddle.is_compiled_with_cuda():
             "gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu",
             "gpu_ops/cutlass_kernels/cutlass_heuristic.cu",
             "gpu_ops/cutlass_kernels/cutlass_preprocessors.cu",
+            "gpu_ops/fused_hadamard_quant_fp8.cu"
         ]
 
         sources += find_end_files(fp8_auto_gen_directory, ".cu")
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index 267dab451..801953087 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -1,5 +1,5 @@
 """
-# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,8 +20,7 @@ from paddle import nn
 import fastdeploy
 from fastdeploy.distributed.communication_op import \
     tensor_model_parallel_all_reduce
-from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
-                                                    get_tensor)
+from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.utils import ceil_div
 
 from ..quantization.quant_base import QuantMethodBase
@@ -272,8 +271,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             layer.moe_intermediate_size, layer.hidden_size
         ]
 
-        ffn1_tensor = paddle.stack(ffn1_tensor, axis=0)
-        ffn2_tensor = paddle.stack(ffn2_tensor, axis=0)
+        ffn1_tensor = paddle.stack(ffn1_tensor, axis=0).view(paddle.float8_e4m3fn)
+        ffn2_tensor = paddle.stack(ffn2_tensor, axis=0).view(paddle.float8_e4m3fn)
 
         added_wfp8afp8_attrs = [
             "moe_ffn1_weight", "moe_ffn2_weight", "moe_ffn1_weight_scale",
@@ -309,7 +308,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
                     dtype=weight_tensor.dtype,
                     default_initializer=paddle.nn.initializer.Constant(0),
                 ))
-            getattr(layer, name).set_value(weight_tensor)
+            if weight_tensor.dtype == paddle.float8_e4m3fn:
+                getattr(layer, name).copy_(weight_tensor, False)
+            else:
+                getattr(layer, name).set_value(weight_tensor)
 
     def create_weights(self, layer: nn.Layer, state_dict):
         """
@@ -333,13 +335,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
         moe_intermediate_size = layer.moe_intermediate_size
         hidden_size = layer.hidden_size
 
-        scores = paddle.nn.functional.softmax(gate_out, axis=-1)
-
-        topk_weights, topk_ids = paddle.topk(scores,
-                                             k=top_k,
-                                             axis=-1,
-                                             sorted=False)
-        topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
+        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+            gate_out,
+            layer.gate_correction_bias,
+            top_k,
+            True,  # apply_norm_weight,
+            False,
+        )
 
         intermediate_cache1 = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
@@ -354,34 +356,31 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             dtype=x.dtype,
         )
 
-        config = {
+        config_ffn1 = {
             "BLOCK_SIZE_M": 32,
             "BLOCK_SIZE_N": 128,
-            "BLOCK_SIZE_K": 128,
+            "BLOCK_SIZE_K": 256,
             "GROUP_SIZE_M": 1,
         }
 
         sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
-            topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
+            topk_ids, num_local_experts, config_ffn1["BLOCK_SIZE_M"])
         max_possible_num_post_padded = sorted_token_ids.shape[0]
         grid = (
-            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
-            ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
+            ceil_div(max_possible_num_post_padded, config_ffn1["BLOCK_SIZE_M"]) *
+            ceil_div(moe_intermediate_size * 2, config_ffn1["BLOCK_SIZE_N"]), )
 
-        adamard_matrix = create_hadamard_matrix_map[hidden_size]
-        x = paddle.matmul(x.cast("float32"), adamard_matrix)
-
-        permute_x = x[:, None, :].tile([1, top_k, 1])
-        permute_x = permute_x.reshape([-1, hidden_size])
-
-        quant_activation_scale = layer.moe_ffn1_in_scale[topk_ids].reshape(
-            [-1, 1])
-        permute_x = permute_x / quant_activation_scale
-        permute_x = permute_x.astype("float8_e4m3fn")
+        permute_x = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
+            x,
+            scale=layer.moe_ffn1_in_scale,
+            topk_ids=topk_ids,
+            top_k=top_k,
+            intermediate_size=hidden_size,
+            tiled=False)
 
         fused_moe_kernel_paddle[grid](
             permute_x,
-            layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
+            layer.moe_ffn1_weight,
             intermediate_cache1,
             layer.moe_ffn1_in_scale,
             layer.moe_ffn1_weight_scale,
@@ -409,36 +408,43 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             group_n=-1,
             group_k=-1,
             # Meta-parameters
-            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
-            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
-            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
-            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            BLOCK_SIZE_M=config_ffn1["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config_ffn1["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config_ffn1["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config_ffn1["GROUP_SIZE_M"],
             MUL_ROUTED_WEIGHT=False,
             top_k=1,
             compute_type_enum=1,
             use_fp8_w8a8=True,
             use_int8_w8a16=False,
-            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
+            even_Ks=hidden_size % config_ffn1["BLOCK_SIZE_K"] == 0,
         )
 
         intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
             intermediate_cache1)
 
-        hadamard_matrix = create_hadamard_matrix_map[moe_intermediate_size]
-        intermediate_cache2 = paddle.matmul(
-            intermediate_cache2.cast("float32"), hadamard_matrix)
-        quant_activation_scale = layer.moe_ffn2_in_scale[topk_ids].reshape(
-            [-1, 1])
-        intermediate_cache2 = intermediate_cache2 / quant_activation_scale
-        intermediate_cache2 = intermediate_cache2.astype("float8_e4m3fn")
+        intermediate_cache2 = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
+            intermediate_cache2,
+            scale=layer.moe_ffn2_in_scale,
+            topk_ids=topk_ids,
+            top_k=top_k,
+            intermediate_size=moe_intermediate_size,
+            tiled=True)
+
+        config_ffn2 = {
+            "BLOCK_SIZE_M": 32,
+            "BLOCK_SIZE_N": 128,
+            "BLOCK_SIZE_K": 64,
+            "GROUP_SIZE_M": 1,
+        }
 
         grid = (
-            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
-            ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
+            ceil_div(max_possible_num_post_padded, config_ffn2["BLOCK_SIZE_M"]) *
+            ceil_div(hidden_size, config_ffn2["BLOCK_SIZE_N"]), )
 
         fused_moe_kernel_paddle[grid](
             intermediate_cache2,
-            layer.moe_ffn2_weight.view(paddle.float8_e4m3fn),
+            layer.moe_ffn2_weight,
             intermediate_cache3,
             layer.moe_ffn2_in_scale,
             layer.moe_ffn2_weight_scale,
@@ -465,16 +471,16 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             group_n=-1,
             group_k=-1,
             # Meta-parameters
-            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
-            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
-            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
-            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            BLOCK_SIZE_M=config_ffn2["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config_ffn2["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config_ffn2["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config_ffn2["GROUP_SIZE_M"],
             MUL_ROUTED_WEIGHT=True,
             top_k=1,
             compute_type_enum=1,
             use_fp8_w8a8=True,
             use_int8_w8a16=False,
-            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
+            even_Ks=moe_intermediate_size % config_ffn2["BLOCK_SIZE_K"] == 0,
         )
 
         intermediate_cache3.reshape_([token_num, top_k, hidden_size])
diff --git a/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py
index 06992954c..99a8562b8 100644
--- a/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py
+++ b/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py
@@ -15,8 +15,6 @@
 """
 from typing import Optional
 
-import paddle
-
 from fastdeploy.model_executor.layers.moe import FusedMoE
 
 from ..utils import get_tensor
@@ -113,15 +111,10 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
         """
         compute!
         """
-        from fastdeploy.model_executor.ops.gpu import \
-            cutlass_fp8_fp8_half_gemm_fused
+        from fastdeploy.model_executor.ops.gpu import (
+            cutlass_fp8_fp8_half_gemm_fused, fused_hadamard_quant_fp8)
 
-        from ..utils import create_hadamard_matrix_map
-
-        hadamard_matrix = create_hadamard_matrix_map[x.shape[-1]]
-        new_x = paddle.matmul(x.cast("float32"), hadamard_matrix)
-        fp8_x = new_x / self.act_scale
-        fp8_x = fp8_x.astype("float8_e4m3fn")
+        fp8_x = fused_hadamard_quant_fp8(x, scale=self.act_scale)
 
         linear_out = cutlass_fp8_fp8_half_gemm_fused(
             fp8_x,