diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp index 587855c5b..36f1ab0c9 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp @@ -206,22 +206,24 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp, template auto get_gmem_layout(const int Rows, const int Cols) { - return make_layout(make_shape(static_cast(Rows), - static_cast(Cols), - static_cast(Experts)), - make_stride(static_cast(Cols), - cute::_1{}, - static_cast(Rows * Cols))); + return make_layout( + make_shape(static_cast(Rows), + static_cast(Cols), + static_cast(Experts)), + make_stride(static_cast(Cols), + cute::_1{}, + static_cast(Rows) * static_cast(Cols))); } template auto get_scale_layout(const int Rows, const int Cols) { - return make_layout(make_shape(static_cast(Cols), - static_cast(Rows), - static_cast(Experts)), - make_stride(cute::_1{}, - static_cast(Cols), - static_cast(Rows * Cols))); + return make_layout( + make_shape(static_cast(Cols), + static_cast(Rows), + static_cast(Experts)), + make_stride(cute::_1{}, + static_cast(Cols), + static_cast(Rows) * static_cast(Cols))); } template