mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -12,12 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import numpy as np
|
||||
from fastdeploy.model_executor.ops.gpu import gemm_dequant
|
||||
from fastdeploy.model_executor.ops.gpu import dequant_int8
|
||||
from itertools import product
|
||||
import unittest
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import dequant_int8, gemm_dequant
|
||||
|
||||
|
||||
class Test(unittest.TestCase):
|
||||
@@ -43,9 +44,7 @@ class Test(unittest.TestCase):
|
||||
act_int_tensor = (act * 128).astype("int8")
|
||||
weight_int_tensor = (weight * 128).astype("int8")
|
||||
scale = paddle.rand([n])
|
||||
linear_out = paddle.matmul(
|
||||
act_int_tensor, weight_int_tensor, transpose_y=True
|
||||
)
|
||||
linear_out = paddle.matmul(act_int_tensor, weight_int_tensor, transpose_y=True)
|
||||
result = dequant_int8(linear_out, scale, "bfloat16")
|
||||
|
||||
result_gemm_dequant = gemm_dequant(
|
||||
@@ -55,7 +54,10 @@ class Test(unittest.TestCase):
|
||||
out_dtype="bfloat16",
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
result.numpy(), result_gemm_dequant.numpy(), rtol=1e-05, atol=1e-05
|
||||
result.numpy(),
|
||||
result_gemm_dequant.numpy(),
|
||||
rtol=1e-05,
|
||||
atol=1e-05,
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user