FastDeploy/tests/operators/test_scaled_gemm_f8_i4_f16.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UT for fp8_int4_gemm kernel"""

import unittest

import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import (
scaled_gemm_f8_i4_f16,
scaled_gemm_f8_i4_f16_weight_quantize,
)


class Test(unittest.TestCase):
    def setUp(self):
        """
        Seed Paddle's RNG, print device info, and skip unless running on an SM 8.9 GPU.
        """
        paddle.seed(2024)
        prop = paddle.device.cuda.get_device_properties()
        print(prop)
        print(paddle.__git_commit__)
        cc = prop.major * 10 + prop.minor
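        # The kernel is only built for compute capability 8.9 (Ada Lovelace, e.g. L40 / RTX 4090).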
        if cc != 89:
            self.skipTest("scaled_gemm_f8_i4_f16 only supports SM 89!")

    def quant_fp8_pertensor(self, tensor):
        """
        Quantize a tensor to the FP8 (e4m3) range using a single per-tensor absmax scale.
        """
scale = paddle.max(paddle.abs(tensor))
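        # Round-trip through float8_e4m3fn to apply FP8 rounding, then cast back to the
        # original dtype so the reference math below stays in that dtype.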
tensor = paddle.cast((tensor * 448 / scale).clip(-448, 448), "float8_e4m3fn").astype(tensor.dtype)
return tensor, scale

    def dequant_fp8_pertensor(self, tensor, scale):
        """
        Undo the per-tensor FP8 scaling, restoring the original value range.
        """
tensor = (tensor / 448 * scale).astype(tensor.dtype)
return tensor

    def quant_int4_fp8_matmul(self, A, B, dtype):
        """
        Quantize A to FP8, pack B into the kernel's INT4 format, and run scaled_gemm_f8_i4_f16.
        """
A_fp8, A_fp8_scale = self.quant_fp8_pertensor(A)
B_fp8, B_fp8_scale = self.quant_fp8_pertensor(B)
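        # Pre-process the FP8 weight into the packed INT4 weight and scales expected by scaled_gemm_f8_i4_f16.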
processed_B, w_scale = scaled_gemm_f8_i4_f16_weight_quantize(B_fp8, groupsize=-1, scale_dtype="float16")
w_scale = paddle.view(w_scale, dtype)
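        # Fold both per-tensor dequantization factors into a single output scale passed to the kernel.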
out_scale = (A_fp8_scale / 448) * (B_fp8_scale / 448)
out = scaled_gemm_f8_i4_f16(
x=paddle.cast(A_fp8, "float8_e4m3fn").cuda(),
y=processed_B.cuda(),
scale=w_scale.cuda(),
zero_points=None,
bias=None,
out_scale=out_scale,
groupsize=0,
out_dtype=dtype,
)
return out

    def test_fp16(self):
        """
        Check the float16 kernel output against an FP32 reference matmul.
        """
A_fp32 = paddle.ones((4, 128)).clip(-448, 448)
B_fp32 = paddle.ones((128, 512)).clip(-448, 448)
C_fp32 = paddle.matmul(A_fp32, B_fp32)
out = self.quant_int4_fp8_matmul(A_fp32, B_fp32, "float16")
out = paddle.cast(out, "float32")
np.testing.assert_allclose(C_fp32.numpy(), out.numpy(), rtol=1e-04, atol=1e-04)

    def test_bf16(self):
        """
        Check the bfloat16 kernel output against an FP32 reference matmul.
        """
A_fp32 = paddle.ones((4, 128)).clip(-448, 448)
B_fp32 = paddle.ones((128, 512)).clip(-448, 448)
C_fp32 = paddle.matmul(A_fp32, B_fp32)
out = self.quant_int4_fp8_matmul(A_fp32, B_fp32, "bfloat16")
out = paddle.cast(out, "float32")
np.testing.assert_allclose(C_fp32.numpy(), out.numpy(), rtol=1e-04, atol=1e-04)
if __name__ == "__main__":
unittest.main()