FastDeploy/test/operators/test_scaled_gemm_f8_i4_f16.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" UT for fp8_int4_gemm kernel """
import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import (
    scaled_gemm_f8_i4_f16,
    scaled_gemm_f8_i4_f16_weight_quantize,
)


class Test(unittest.TestCase):
    def setUp(self):
        """
        Initialize.
        """
        paddle.seed(2024)
        print(paddle.device.cuda.get_device_properties())
        print(paddle.__git_commit__)

    def quant_fp8_pertensor(self, tensor):
        """
        Per-tensor FP8 (e4m3) fake quantization: round-trip the tensor through
        float8_e4m3fn and return it together with its per-tensor scale.
        """
        scale = paddle.max(paddle.abs(tensor))
        # 448 is the largest finite value representable in float8_e4m3fn.
        tensor = paddle.cast(
            (tensor * 448 / scale).clip(-448, 448), "float8_e4m3fn"
        ).astype(tensor.dtype)
        return tensor, scale

    def dequant_fp8_pertensor(self, tensor, scale):
        """
        Undo quant_fp8_pertensor: map values from the [-448, 448] FP8 range
        back to the original scale.
        """
        tensor = (tensor / 448 * scale).astype(tensor.dtype)
        return tensor

    def quant_int4_fp8_matmul(self, A, B, dtype):
        """
        Quantize A to FP8 and B to FP8 then INT4, and run the fused
        scaled_gemm_f8_i4_f16 kernel.
        """
        A_fp8, A_fp8_scale = self.quant_fp8_pertensor(A)
        B_fp8, B_fp8_scale = self.quant_fp8_pertensor(B)
        # Repack the FP8 weight into the kernel's INT4 layout
        # (groupsize=-1: no group-wise quantization).
        processed_B, w_scale = scaled_gemm_f8_i4_f16_weight_quantize(
            B_fp8, groupsize=-1, scale_dtype="float16"
        )
        w_scale = paddle.view(w_scale, dtype)
        # Rescale the FP8 x INT4 product back to the magnitude of A @ B.
        out_scale = (A_fp8_scale / 448) * (B_fp8_scale / 448)
        out = scaled_gemm_f8_i4_f16(
            x=paddle.cast(A_fp8, "float8_e4m3fn").cuda(),
            y=processed_B.cuda(),
            scale=w_scale.cuda(),
            zero_points=None,
            bias=None,
            out_scale=out_scale,
            groupsize=0,
            out_dtype=dtype,
        )
        return out

    def test_fp16(self):
        """
        Check fp16.
        """
        A_fp32 = paddle.ones((4, 128)).clip(-448, 448)
        B_fp32 = paddle.ones((128, 512)).clip(-448, 448)
        C_fp32 = paddle.matmul(A_fp32, B_fp32)
        out = self.quant_int4_fp8_matmul(A_fp32, B_fp32, "float16")
        out = paddle.cast(out, "float32")
        np.testing.assert_allclose(C_fp32.numpy(), out.numpy(), rtol=1e-04, atol=1e-04)

    def test_bf16(self):
        """
        Check bf16.
        """
        A_fp32 = paddle.ones((4, 128)).clip(-448, 448)
        B_fp32 = paddle.ones((128, 512)).clip(-448, 448)
        C_fp32 = paddle.matmul(A_fp32, B_fp32)
        out = self.quant_int4_fp8_matmul(A_fp32, B_fp32, "bfloat16")
        out = paddle.cast(out, "float32")
        np.testing.assert_allclose(C_fp32.numpy(), out.numpy(), rtol=1e-04, atol=1e-04)


if __name__ == "__main__":
    unittest.main()
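
For reference, a standalone sketch of the per-tensor FP8 round trip that quant_fp8_pertensor and dequant_fp8_pertensor implement above (448 is the largest finite float8_e4m3fn value). This is illustrative only and not part of the test file:

import paddle

x = paddle.randn((4, 128), dtype="float32")
scale = paddle.max(paddle.abs(x))
# Quantize: scale values into the FP8 e4m3 range [-448, 448], then cast.
x_fp8 = paddle.cast((x * 448 / scale).clip(-448, 448), "float8_e4m3fn")
# Dequantize: cast back and undo the scaling; matches x up to FP8 rounding error.
x_back = x_fp8.astype("float32") * scale / 448
print(float(paddle.max(paddle.abs(x - x_back))))

In the GEMM test, the two 448 / scale factors are undone in one step via out_scale = (A_fp8_scale / 448) * (B_fp8_scale / 448), which is why the quantized output can be compared directly against the plain FP32 matmul reference.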