From 6d323769ddfd2cb8134ac4ce692b71163824f45d Mon Sep 17 00:00:00 2001 From: lizexu123 <39205361+lizexu123@users.noreply.github.com> Date: Mon, 22 Dec 2025 13:39:41 +0800 Subject: [PATCH] fix w4afp8 (#5634) --- .../w4afp8_gemm/w4afp8_gemm_kernel.hpp | 26 ++++--- .../utils/auto_gen_w4afp8_gemm_kernel.py | 2 +- tests/operators/test_w4afp8_gemm.py | 78 +++++++++++-------- 3 files changed, 59 insertions(+), 47 deletions(-) diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp index 587855c5b..36f1ab0c9 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp @@ -206,22 +206,24 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp, template auto get_gmem_layout(const int Rows, const int Cols) { - return make_layout(make_shape(static_cast(Rows), - static_cast(Cols), - static_cast(Experts)), - make_stride(static_cast(Cols), - cute::_1{}, - static_cast(Rows * Cols))); + return make_layout( + make_shape(static_cast(Rows), + static_cast(Cols), + static_cast(Experts)), + make_stride(static_cast(Cols), + cute::_1{}, + static_cast(Rows) * static_cast(Cols))); } template auto get_scale_layout(const int Rows, const int Cols) { - return make_layout(make_shape(static_cast(Cols), - static_cast(Rows), - static_cast(Experts)), - make_stride(cute::_1{}, - static_cast(Cols), - static_cast(Rows * Cols))); + return make_layout( + make_shape(static_cast(Cols), + static_cast(Rows), + static_cast(Experts)), + make_stride(cute::_1{}, + static_cast(Cols), + static_cast(Rows) * static_cast(Cols))); } template