fix w4afp8 (#5634)

This commit is contained in:
lizexu123
2025-12-22 13:39:41 +08:00
committed by GitHub
parent 6eada4929d
commit 6d323769dd
3 changed files with 59 additions and 47 deletions

View File

@@ -206,22 +206,24 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
// Builds a rank-3 layout of shape (Rows, Cols, Experts) with strides
// (Cols, 1, Rows * Cols): the second mode is contiguous, consecutive rows are
// Cols elements apart, and consecutive experts are Rows * Cols elements apart.
//
// Template parameter:
//   Experts - extent of the third mode.
// Parameters:
//   Rows - extent of the first mode.
//   Cols - extent of the second (unit-stride) mode.
// Returns:
//   The layout produced by make_layout for the shape/stride pair above, with
//   all dynamic extents and strides carried as int64_t.
template <int Experts>
auto get_gmem_layout(const int Rows, const int Cols) {
  // Widen each operand BEFORE multiplying. `Rows * Cols` evaluated in 32-bit
  // int overflows (undefined behaviour) once the product exceeds INT_MAX;
  // casting the already-overflowed result to int64_t is too late.
  const int64_t expert_stride =
      static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols);
  return make_layout(
      make_shape(static_cast<int64_t>(Rows),
                 static_cast<int64_t>(Cols),
                 static_cast<int64_t>(Experts)),
      make_stride(static_cast<int64_t>(Cols), cute::_1{}, expert_stride));
}
// Builds a rank-3 layout of shape (Cols, Rows, Experts) with strides
// (1, Cols, Rows * Cols): the first mode is contiguous, the second mode
// advances by Cols elements, and consecutive experts are Rows * Cols elements
// apart.
//
// Template parameter:
//   Experts - extent of the third mode.
// Parameters:
//   Rows - extent of the second mode.
//   Cols - extent of the first (unit-stride) mode.
// Returns:
//   The layout produced by make_layout for the shape/stride pair above, with
//   all dynamic extents and strides carried as int64_t.
template <int Experts>
auto get_scale_layout(const int Rows, const int Cols) {
  // Widen each operand BEFORE multiplying. `Rows * Cols` evaluated in 32-bit
  // int overflows (undefined behaviour) once the product exceeds INT_MAX;
  // casting the already-overflowed result to int64_t is too late.
  const int64_t expert_stride =
      static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols);
  return make_layout(
      make_shape(static_cast<int64_t>(Cols),
                 static_cast<int64_t>(Rows),
                 static_cast<int64_t>(Experts)),
      make_stride(cute::_1{}, static_cast<int64_t>(Cols), expert_stride));
}
template <typename InputType,