mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-21 15:49:31 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
@@ -11,16 +11,19 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
""" UT for air_topp_sampling kernel """
|
||||
|
||||
import paddle
|
||||
import subprocess
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
import fastdeploy.model_executor.ops.gpu
|
||||
|
||||
|
||||
class Test(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Initialize.
|
||||
@@ -29,22 +32,32 @@ class Test(unittest.TestCase):
|
||||
np.random.seed(42)
|
||||
print(paddle.device.cuda.get_device_properties())
|
||||
print(paddle.__git_commit__)
|
||||
nvcc_output = subprocess.check_output(["nvcc", "--version"],
|
||||
universal_newlines=True)
|
||||
output = nvcc_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
self.nvcc_cuda_version = float(output[release_idx].split(",")[0])
|
||||
|
||||
def test_air_topp_sampling(self):
|
||||
"""
|
||||
Check air_topp_sampling output with paddle.tensor.top_p_sampling.
|
||||
"""
|
||||
prop = paddle.device.cuda.get_device_properties()
|
||||
cc = prop.major * 10 + prop.minor
|
||||
if cc < 89:
|
||||
self.skipTest("air_topp_sampling only support sm89+")
|
||||
x = paddle.randn([1, 100])
|
||||
if self.nvcc_cuda_version < 12.0:
|
||||
self.skipTest("air_topp_sampling only support cu12+")
|
||||
bsz = 8
|
||||
vocab_size = 103424
|
||||
x = paddle.randn([bsz, vocab_size])
|
||||
x = paddle.nn.functional.softmax(x)
|
||||
x = paddle.cast(x, "float32")
|
||||
top_ps = paddle.to_tensor(np.random.uniform(0, 1, [1]).astype(np.float32))
|
||||
out = fastdeploy.model_executor.ops.gpu.air_topp_sampling(
|
||||
x.cuda(), top_ps.cuda(), None, None, seed=0, k=1, mode="truncated"
|
||||
)
|
||||
top_ps = paddle.to_tensor(
|
||||
np.random.uniform(0, 1, [bsz]).astype(np.float32))
|
||||
_, next_tokens = fastdeploy.model_executor.ops.gpu.air_topp_sampling(
|
||||
x.cuda(), top_ps.cuda(), None, None, seed=0, k=1, mode="truncated")
|
||||
print(next_tokens)
|
||||
less_than_zero = next_tokens >= 0
|
||||
greater_than_vocab_size = next_tokens <= vocab_size
|
||||
accuracy = paddle.logical_and(less_than_zero, greater_than_vocab_size)
|
||||
print(f'Accuracy of results: {accuracy}')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
101
test/operators/test_cutlass_scaled_mm.py
Normal file
101
test/operators/test_cutlass_scaled_mm.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" UT for air_topp_sampling kernel """
|
||||
|
||||
import subprocess
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.layers.quantization.ops import (
|
||||
cutlass_scaled_mm, scaled_fp8_quant)
|
||||
|
||||
|
||||
class Test(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Initialize.
|
||||
"""
|
||||
paddle.seed(2024)
|
||||
np.random.seed(42)
|
||||
self.prop = paddle.device.cuda.get_device_properties()
|
||||
self.sm_version = self.prop.major * 10 + self.prop.minor
|
||||
print(self.prop)
|
||||
print(paddle.__git_commit__)
|
||||
nvcc_output = subprocess.check_output(["nvcc", "--version"],
|
||||
universal_newlines=True)
|
||||
output = nvcc_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
self.nvcc_cuda_version = float(output[release_idx].split(",")[0])
|
||||
|
||||
def test_cutlass_scaled_mm_fp8(self):
|
||||
"""
|
||||
Check cutlass_scaled_mm output.
|
||||
"""
|
||||
if self.sm_version < 89:
|
||||
self.skipTest(
|
||||
"cutlass_scaled_mm with fp8 input only support sm89+")
|
||||
M = 32
|
||||
N = 1024
|
||||
K = 1024
|
||||
a = paddle.rand([M, K], dtype=paddle.bfloat16)
|
||||
b = paddle.rand([N, K], dtype=paddle.bfloat16)
|
||||
b_q, b_scales = scaled_fp8_quant(b, use_per_token_if_dynamic=False)
|
||||
a_q, a_scales = scaled_fp8_quant(a, use_per_token_if_dynamic=True)
|
||||
|
||||
# Ensure quantized tensors and scales are valid
|
||||
assert a_q.numel() > 0, "Quantized tensor 'a_q' must not be empty"
|
||||
assert b_q.numel() > 0, "Quantized tensor 'b_q' must not be empty"
|
||||
assert a_scales.numel(
|
||||
) > 0, "Scale tensor 'a_scales' must not be empty"
|
||||
assert b_scales.numel(
|
||||
) > 0, "Scale tensor 'b_scales' must not be empty"
|
||||
|
||||
bias = paddle.rand([N], dtype=paddle.bfloat16)
|
||||
baseline = paddle.matmul(a, b, transpose_x=False, transpose_y=True)
|
||||
if bias is not None:
|
||||
baseline = paddle.add(baseline, bias)
|
||||
out_type = a.dtype
|
||||
c = cutlass_scaled_mm(a_q, b_q, a_scales, b_scales, out_type, bias)
|
||||
euqal = np.allclose(baseline.numpy(), c.numpy(), rtol=1e-2, atol=1e-2)
|
||||
print(euqal) #
|
||||
|
||||
def test_cutlass_scaled_mm_int8(self):
|
||||
"""
|
||||
Check cutlass_scaled_mm output.
|
||||
"""
|
||||
M = 32
|
||||
N = 1024
|
||||
K = 512
|
||||
a = paddle.rand([M, K], dtype=paddle.bfloat16)
|
||||
b = paddle.rand([N, K], dtype=paddle.bfloat16)
|
||||
a_scales = (a.cast(paddle.float32).abs().max(axis=-1) / 127)[:, None]
|
||||
a_q = paddle.clip(a / a_scales, -127, 127).cast(paddle.int8)
|
||||
b_scales = (b.cast(paddle.float32).abs().max(axis=-1) / 127)[:, None]
|
||||
b_q = paddle.clip(b / b_scales, -127, 127).cast(paddle.int8)
|
||||
|
||||
bias = paddle.rand([N], dtype=paddle.bfloat16)
|
||||
baseline = paddle.matmul(a, b, transpose_x=False, transpose_y=True)
|
||||
if bias is not None:
|
||||
baseline = paddle.add(baseline, bias)
|
||||
out_type = a.dtype
|
||||
c = cutlass_scaled_mm(a_q, b_q, a_scales, b_scales, out_type, bias)
|
||||
euqal = np.allclose(baseline.numpy(), c.numpy(), rtol=1e-2, atol=1e-2)
|
||||
print(euqal) #
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@@ -11,16 +11,16 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
""" UT for air_topp_sampling kernel """
|
||||
|
||||
import os
|
||||
import paddle
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
|
||||
class Test(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Initialize.
|
||||
@@ -33,10 +33,10 @@ class Test(unittest.TestCase):
|
||||
"""
|
||||
Check air_topp_sampling output with paddle.tensor.top_p_sampling.
|
||||
"""
|
||||
if dynamic_mode:
|
||||
os.environ["ELLM_DYNAMIC_MODE"] = "1"
|
||||
if not dynamic_mode:
|
||||
paddle.enable_static()
|
||||
else:
|
||||
os.environ["ELLM_DYNAMIC_MODE"] = "0"
|
||||
paddle.disable_static()
|
||||
from fastdeploy.model_executor.ops.gpu import dequant_int8
|
||||
|
||||
input_tensor = paddle.cast(paddle.ones([128, 128]), "int32")
|
||||
@@ -46,10 +46,14 @@ class Test(unittest.TestCase):
|
||||
|
||||
def test(self):
|
||||
op_out = self.dequant_int8_test()
|
||||
exe = paddle.static.Executor()
|
||||
exe.run(paddle.static.default_startup_program())
|
||||
op_out = exe.run(fetch_list=[op_out])[0]
|
||||
func_out = self.dequant_int8_test(True)
|
||||
np.testing.assert_allclose(
|
||||
op_out.numpy(), func_out.numpy(), rtol=1e-04, atol=1e-04
|
||||
)
|
||||
np.testing.assert_allclose(op_out,
|
||||
func_out.numpy(),
|
||||
rtol=1e-04,
|
||||
atol=1e-04)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
66
test/operators/test_rejection_top_p_sampling.py
Normal file
66
test/operators/test_rejection_top_p_sampling.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import paddle
|
||||
from fastdeploy.model_executor.ops.gpu import rejection_top_p_sampling
|
||||
|
||||
class TestRejectionTopPSampling(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Initialize common test data"""
|
||||
self.batch_size = 10
|
||||
self.vocab_size = 103424
|
||||
paddle.seed(2023)
|
||||
|
||||
# Generate test data once for all tests
|
||||
self.pre_norm_prob_np = np.random.rand(self.batch_size, self.vocab_size).astype(np.float32)
|
||||
self.paddle_pre_norm_prob = paddle.to_tensor(self.pre_norm_prob_np)
|
||||
self.paddle_norm_prob = self.paddle_pre_norm_prob / self.paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
|
||||
|
||||
def test_top_p_sampling_reject_case1(self):
|
||||
"""Test with fixed top_p=0.8 and different random seeds"""
|
||||
top_p_paddle = paddle.full((self.batch_size,), 0.8)
|
||||
|
||||
# Test with different seeds
|
||||
for seed in [1024, 2033, 2033]:
|
||||
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, seed)
|
||||
self._validate_samples(samples)
|
||||
|
||||
# Basic validation
|
||||
self.assertTrue(paddle.all(samples >= 0))
|
||||
self.assertTrue(paddle.all(samples < self.vocab_size))
|
||||
|
||||
def test_top_p_sampling_reject_case2(self):
|
||||
"""Test with varying top_p values across batch"""
|
||||
top_p_paddle = paddle.uniform(shape=[self.batch_size], min=0.1, max=1.0)
|
||||
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, -1)
|
||||
|
||||
self._validate_samples(samples)
|
||||
|
||||
# Additional check that we're getting different results for different top_p
|
||||
unique_samples = len(paddle.unique(samples))
|
||||
print(f"Unique samples: {unique_samples}")
|
||||
self.assertGreater(unique_samples, 1) # Should have some diversity
|
||||
|
||||
def _validate_samples(self, samples):
|
||||
"""Common validation for all test cases"""
|
||||
self.assertTrue(paddle.all(samples >= 0))
|
||||
self.assertTrue(paddle.all(samples < self.vocab_size))
|
||||
|
||||
# Check dtype
|
||||
self.assertEqual(samples.dtype, paddle.int64)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user