mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
fix w4afp8 (#5634)
This commit is contained in:
@@ -206,22 +206,24 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
|
||||
|
||||
template <int Experts>
|
||||
auto get_gmem_layout(const int Rows, const int Cols) {
|
||||
return make_layout(make_shape(static_cast<int64_t>(Rows),
|
||||
static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Experts)),
|
||||
make_stride(static_cast<int64_t>(Cols),
|
||||
cute::_1{},
|
||||
static_cast<int64_t>(Rows * Cols)));
|
||||
return make_layout(
|
||||
make_shape(static_cast<int64_t>(Rows),
|
||||
static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Experts)),
|
||||
make_stride(static_cast<int64_t>(Cols),
|
||||
cute::_1{},
|
||||
static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
|
||||
}
|
||||
|
||||
template <int Experts>
|
||||
auto get_scale_layout(const int Rows, const int Cols) {
|
||||
return make_layout(make_shape(static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Rows),
|
||||
static_cast<int64_t>(Experts)),
|
||||
make_stride(cute::_1{},
|
||||
static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Rows * Cols)));
|
||||
return make_layout(
|
||||
make_shape(static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Rows),
|
||||
static_cast<int64_t>(Experts)),
|
||||
make_stride(cute::_1{},
|
||||
static_cast<int64_t>(Cols),
|
||||
static_cast<int64_t>(Rows) * static_cast<int64_t>(Cols)));
|
||||
}
|
||||
|
||||
template <typename InputType,
|
||||
|
||||
@@ -85,7 +85,7 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
|
||||
"""
|
||||
|
||||
# [M, K, Number of experts, token Padding Size, weight K group size]
|
||||
gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128]]
|
||||
gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128], [256, 5120, 128, 0, 128]]
|
||||
|
||||
dtype = ["BF16"]
|
||||
|
||||
|
||||
@@ -27,12 +27,18 @@ from fastdeploy.model_executor.ops.gpu import (
|
||||
class TestW4AFP8GEMM(unittest.TestCase):
|
||||
def setUp(self):
|
||||
paddle.seed(0)
|
||||
self.tokens_per_group = 1
|
||||
self.N = 256
|
||||
self.K = 256
|
||||
self.BATCH = 2
|
||||
self.test_cases = [
|
||||
[1, 2, 256, 256],
|
||||
[4096, 128, 256, 5120],
|
||||
]
|
||||
self.TokenPadding = 0
|
||||
|
||||
def set_data(self, tokens_per_group, Experts, N, K):
|
||||
self.tokens_per_group = tokens_per_group
|
||||
self.N = N
|
||||
self.K = K
|
||||
self.BATCH = Experts
|
||||
|
||||
tokens = [self.tokens_per_group] * self.BATCH
|
||||
self.tokens_prefix_sum = np.cumsum(tokens)
|
||||
|
||||
@@ -90,39 +96,43 @@ class TestW4AFP8GEMM(unittest.TestCase):
|
||||
return processed_weight_scale
|
||||
|
||||
def test_w4afp8_gemm(self):
|
||||
out_naive = self.w4afp8_gemm_naive(
|
||||
self.input_bf16, self.weight_quant_naive, self.tokens, self.weight_dequant_scale
|
||||
)
|
||||
for test_case in self.test_cases:
|
||||
tokens_per_group, Experts, N, K = test_case
|
||||
|
||||
# weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512)
|
||||
weight_dequant_scale = self.get_per_group_scale(self.weight_dequant_scale * 512)
|
||||
weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu()).cuda()
|
||||
|
||||
if self.TokenPadding == 0:
|
||||
out_cuda = w4afp8_gemm(
|
||||
self.input_fp8,
|
||||
weight_int4,
|
||||
self.tokens_prefix_sum,
|
||||
weight_dequant_scale.astype("float32"),
|
||||
None,
|
||||
int(self.TokenPadding),
|
||||
self.all_tokens,
|
||||
True,
|
||||
)
|
||||
else:
|
||||
out_cuda = w4afp8_gemm(
|
||||
self.input_fp8,
|
||||
weight_int4,
|
||||
self.tokens,
|
||||
weight_dequant_scale.astype("float32"),
|
||||
None,
|
||||
int(self.TokenPadding),
|
||||
self.max_tokens,
|
||||
True,
|
||||
self.set_data(tokens_per_group, Experts, N, K)
|
||||
out_naive = self.w4afp8_gemm_naive(
|
||||
self.input_bf16, self.weight_quant_naive, self.tokens, self.weight_dequant_scale
|
||||
)
|
||||
|
||||
gap = (out_cuda - out_naive).abs()
|
||||
self.assertLess(float(gap.mean()), 0.11)
|
||||
# weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512)
|
||||
weight_dequant_scale = self.get_per_group_scale(self.weight_dequant_scale * 512)
|
||||
weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu()).cuda()
|
||||
|
||||
if self.TokenPadding == 0:
|
||||
out_cuda = w4afp8_gemm(
|
||||
self.input_fp8,
|
||||
weight_int4,
|
||||
self.tokens_prefix_sum,
|
||||
weight_dequant_scale.astype("float32"),
|
||||
None,
|
||||
int(self.TokenPadding),
|
||||
self.all_tokens,
|
||||
True,
|
||||
)
|
||||
else:
|
||||
out_cuda = w4afp8_gemm(
|
||||
self.input_fp8,
|
||||
weight_int4,
|
||||
self.tokens,
|
||||
weight_dequant_scale.astype("float32"),
|
||||
None,
|
||||
int(self.TokenPadding),
|
||||
self.max_tokens,
|
||||
True,
|
||||
)
|
||||
|
||||
gap = (out_cuda - out_naive).abs()
|
||||
self.assertLess(float(gap.mean()), 0.11)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user