mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
fix w4afp8 (#5634)
This commit is contained in:
@@ -206,22 +206,24 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
|
||||
|
||||
template <int Experts>
|
||||
auto get_gmem_layout(const int Rows, const int Cols) {
|
||||
return make_layout(make_shape(static_cast<int64_t>(Rows),
|
||||
static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Experts)),
|
||||
make_stride(static_cast<int64_t>(Cols),
|
||||
cute::_1{},
|
||||
static_cast<int64_t>(Rows * Cols)));
|
||||
return make_layout(
|
||||
make_shape(static_cast<int64_t>(Rows),
|
||||
static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Experts)),
|
||||
make_stride(static_cast<int64_t>(Cols),
|
||||
cute::_1{},
|
||||
static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
|
||||
}
|
||||
|
||||
template <int Experts>
|
||||
auto get_scale_layout(const int Rows, const int Cols) {
|
||||
return make_layout(make_shape(static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Rows),
|
||||
static_cast<int64_t>(Experts)),
|
||||
make_stride(cute::_1{},
|
||||
static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Rows * Cols)));
|
||||
return make_layout(
|
||||
make_shape(static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Rows),
|
||||
static_cast<int64_t>(Experts)),
|
||||
make_stride(cute::_1{},
|
||||
static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
|
||||
}
|
||||
|
||||
template <typename InputType,
|
||||
|
||||
Reference in New Issue
Block a user