mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
fix w4afp8 (#5634)
This commit is contained in:
@@ -206,22 +206,24 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
|
||||
|
||||
template <int Experts>
|
||||
auto get_gmem_layout(const int Rows, const int Cols) {
|
||||
return make_layout(make_shape(static_cast<int64_t>(Rows),
|
||||
static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Experts)),
|
||||
make_stride(static_cast<int64_t>(Cols),
|
||||
cute::_1{},
|
||||
static_cast<int64_t>(Rows * Cols)));
|
||||
return make_layout(
|
||||
make_shape(static_cast<int64_t>(Rows),
|
||||
static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Experts)),
|
||||
make_stride(static_cast<int64_t>(Cols),
|
||||
cute::_1{},
|
||||
static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
|
||||
}
|
||||
|
||||
template <int Experts>
|
||||
auto get_scale_layout(const int Rows, const int Cols) {
|
||||
return make_layout(make_shape(static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Rows),
|
||||
static_cast<int64_t>(Experts)),
|
||||
make_stride(cute::_1{},
|
||||
static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Rows * Cols)));
|
||||
return make_layout(
|
||||
make_shape(static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Rows),
|
||||
static_cast<int64_t>(Experts)),
|
||||
make_stride(cute::_1{},
|
||||
static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
|
||||
}
|
||||
|
||||
template <typename InputType,
|
||||
|
||||
@@ -85,7 +85,7 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
|
||||
"""
|
||||
|
||||
# [M, K, Number of experts, token Padding Size, weight K group size]
|
||||
gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128]]
|
||||
gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128], [256, 5120, 128, 0, 128]]
|
||||
|
||||
dtype = ["BF16"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user