Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)

This reverts commit 93fcf7e4ec.
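The revert restores the test's pre-93fcf7e4ec setup: the single-line op import, the small shapes (tokens_per_group=256, N=256, K=256, BATCH=1), the offset int4 weight quantization (codes in [0, 14] around a +7 zero point) together with the Python-side permute_scale scale layout, and the 0.07 mean-error threshold. It drops the w4afp8_gemm_scale_permute-based get_per_group_scale path, the 1792x8192x64 configuration, and the looser 0.11 threshold.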
@@ -17,20 +17,16 @@ import unittest
 import numpy as np
 import paddle
 
-from fastdeploy.model_executor.ops.gpu import (
-    w4afp8_gemm,
-    w4afp8_gemm_scale_permute,
-    w4afp8_gemm_weight_convert,
-)
+from fastdeploy.model_executor.ops.gpu import w4afp8_gemm, w4afp8_gemm_weight_convert
 
 
 class TestW4AFP8GEMM(unittest.TestCase):
     def setUp(self):
         paddle.seed(0)
-        self.tokens_per_group = 1
-        self.N = 1792
-        self.K = 8192
-        self.BATCH = 64
+        self.tokens_per_group = 256
+        self.N = 256
+        self.K = 256
+        self.BATCH = 1
         self.TokenPadding = 0
 
         tokens = [self.tokens_per_group] * self.BATCH
@@ -42,15 +38,14 @@ class TestW4AFP8GEMM(unittest.TestCase):
 
         self.input_fp8 = paddle.randn([self.all_tokens, self.K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
         self.input_bf16 = self.input_fp8.astype("bfloat16")
-        self.weight = paddle.randn([self.BATCH, self.N, self.K], dtype="bfloat16")
+        self.weight = paddle.randn([self.BATCH, self.N, self.K], dtype="bfloat16") / 10
 
         self.weight_scale = 7 / self.weight.abs().max(axis=-1).reshape([self.BATCH, self.N, 1])
-        self.weight_quant = (self.weight * self.weight_scale).astype("int")
-        self.weight_quant = paddle.clip(self.weight_quant, -7, 7)
-        self.weight_quant_naive = self.weight_quant.astype("float32")
+        self.weight_quant = (self.weight * self.weight_scale).astype("int") + 7
+        self.weight_quant = paddle.clip(self.weight_quant, 0, 14)
+        self.weight_quant = self.weight_quant.astype("bfloat16")
+        self.weight_quant = paddle.where(self.weight_quant > 0, self.weight_quant, 8 - self.weight_quant)
         self.weight_dequant_scale = 1 / self.weight_scale.astype("float32")
         self.input_row_sum = self.input_bf16.sum(axis=1) * -7 / 512
         self.max_tokens = int(self.tokens.max())
 
     def w4afp8_gemm_naive(self, input_bf16, weight_quant, tokens, weight_dequant_scale):
@@ -59,7 +54,7 @@ class TestW4AFP8GEMM(unittest.TestCase):
         pre_fix_token = 0
         for i in range(self.BATCH):
             input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
-            weight = weight_quant[i] * weight_dequant_scale[i]
+            weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i]
             out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
             out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
             pre_fix_token += tokens[i]
@@ -76,53 +71,37 @@ class TestW4AFP8GEMM(unittest.TestCase):
                     weight_scale[b, n + j + 1] = temp[j // 2 + 8]
         return weight_scale
 
-    def get_per_group_scale(self, processed_weight_scale):
-        processed_weight_scale = processed_weight_scale.repeat_interleave(self.K // 128, axis=-1)
-        origin_shape = processed_weight_scale.shape
-        processed_weight_scale = processed_weight_scale.transpose([0, 2, 1])
-        processed_weight_scale = processed_weight_scale.reshape([-1, processed_weight_scale.shape[-1]])
-
-        processed_weight_scale = w4afp8_gemm_scale_permute(processed_weight_scale)
-        processed_weight_scale = processed_weight_scale.reshape(
-            [origin_shape[0], origin_shape[2], origin_shape[1] // 128, 128]
-        )
-        processed_weight_scale = processed_weight_scale.transpose([0, 2, 1, 3])
-        return processed_weight_scale
-
     def test_w4afp8_gemm(self):
-        out_naive = self.w4afp8_gemm_naive(
-            self.input_bf16, self.weight_quant_naive, self.tokens, self.weight_dequant_scale
-        )
+        out_naive = self.w4afp8_gemm_naive(self.input_bf16, self.weight_quant, self.tokens, self.weight_dequant_scale)
 
-        # weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512)
-        weight_dequant_scale = self.get_per_group_scale(self.weight_dequant_scale * 512)
-        weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu()).cuda()
+        weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512)
+        weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu())
 
         if self.TokenPadding == 0:
             out_cuda = w4afp8_gemm(
                 self.input_fp8,
-                weight_int4,
+                weight_int4.cuda(),
                 self.tokens_prefix_sum,
                 self.input_row_sum.astype("float32"),
                 weight_dequant_scale.astype("float32"),
                 None,
                 int(self.TokenPadding),
                 self.all_tokens,
                 self.max_tokens,
                 True,
             )
         else:
             out_cuda = w4afp8_gemm(
                 self.input_fp8,
-                weight_int4,
+                weight_int4.cuda(),
                 self.tokens,
                 self.input_row_sum.astype("float32"),
                 weight_dequant_scale.astype("float32"),
                 None,
                 int(self.TokenPadding),
                 self.max_tokens,
                 True,
             )
 
         gap = (out_cuda - out_naive).abs()
-        self.assertLess(float(gap.mean()), 0.11)
+        self.assertLess(float(gap.mean()), 0.07)
 
 
 if __name__ == "__main__":
     unittest.main()
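A note on the restored quantization arithmetic in setUp(): each weight row is mapped to 4-bit codes around a +7 zero point (scale = 7 / max|w| per row, q = clip(int(w * scale) + 7, 0, 14)), and w4afp8_gemm_naive undoes it as (q - 7) * weight_dequant_scale. A minimal round-trip sketch, using numpy as a stand-in for paddle and omitting the bfloat16 cast and the paddle.where nibble remap (those only matter for the packed kernel encoding):

import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 8)).astype(np.float32) / 10   # plays the role of self.weight
scale = 7.0 / np.abs(w).max(axis=-1, keepdims=True)       # self.weight_scale: row max maps to 7
q = np.clip((w * scale).astype(np.int32) + 7, 0, 14)      # 4-bit code with zero point 7
dequant_scale = 1.0 / scale                               # self.weight_dequant_scale
w_hat = (q - 7.0) * dequant_scale                         # what w4afp8_gemm_naive computes
assert np.abs(w - w_hat).max() <= dequant_scale.max()     # off by at most one step per row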
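The input_row_sum term pairs with that zero point. Writing the stored code as q = w * scale + 7, a GEMM against the raw codes overshoots each output by 7 * sum_k(x_k) times the dequant scale, which is the correction the test precomputes as input_row_sum (the -7 factor; the /512 there presumably cancels against the * 512 applied to the scales before permutation, though the kernel's internal contract is not shown in this diff). A small numpy check of the identity on hypothetical dense tensors, not the real packed layout:

import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((3, 8)).astype(np.float32)        # activations
w = rng.standard_normal((5, 8)).astype(np.float32) / 10   # one expert's weight, shape [N, K]
scale = 7.0 / np.abs(w).max(axis=-1, keepdims=True)
q = (w * scale).astype(np.int32) + 7                      # offset int4 codes, in [0, 14]
ds = (1.0 / scale).reshape(-1)                            # dequant scale per output channel

ref = x @ ((q - 7.0) / scale).T                           # dequantize, then matmul
row_sum = x.sum(axis=1) * -7.0                            # analogue of self.input_row_sum
fused = (x @ q.T) * ds + row_sum[:, None] * ds            # matmul on codes + zero-point fixup
assert np.allclose(ref, fused, atol=1e-4)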