From 04fc7eb93187ba8e9da1f49f319be6946560e324 Mon Sep 17 00:00:00 2001
From: chen <103103266+ckl117@users.noreply.github.com>
Date: Tue, 5 Aug 2025 15:47:50 +0800
Subject: [PATCH] fix test_air_top_p_sampling name (#3211)

---
 ...air_topp_sampling.py => test_air_top_p_sampling.py} | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)
 rename test/operators/{test_air_topp_sampling.py => test_air_top_p_sampling.py} (89%)

diff --git a/test/operators/test_air_topp_sampling.py b/test/operators/test_air_top_p_sampling.py
similarity index 89%
rename from test/operators/test_air_topp_sampling.py
rename to test/operators/test_air_top_p_sampling.py
index d3ec669cd..eebe56a79 100644
--- a/test/operators/test_air_topp_sampling.py
+++ b/test/operators/test_air_top_p_sampling.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""UT for air_topp_sampling kernel"""
+"""UT for air_top_p_sampling kernel"""
 import subprocess
 import unittest

@@ -36,19 +36,19 @@ class Test(unittest.TestCase):
         release_idx = output.index("release") + 1
         self.nvcc_cuda_version = float(output[release_idx].split(",")[0])

-    def test_air_topp_sampling(self):
+    def test_air_top_p_sampling(self):
         """
-        Check air_topp_sampling output with paddle.tensor.top_p_sampling.
+        Check air_top_p_sampling output against paddle.tensor.top_p_sampling.
         """
         if self.nvcc_cuda_version < 12.0:
-            self.skipTest("air_topp_sampling only support cu12+")
+            self.skipTest("air_top_p_sampling only supports cu12+")
         bsz = 8
         vocab_size = 103424
         x = paddle.randn([bsz, vocab_size])
         x = paddle.nn.functional.softmax(x)
         x = paddle.cast(x, "float32")
         top_ps = paddle.to_tensor(np.random.uniform(0, 1, [bsz]).astype(np.float32))
-        _, next_tokens = fastdeploy.model_executor.ops.gpu.air_topp_sampling(
+        _, next_tokens = fastdeploy.model_executor.ops.gpu.air_top_p_sampling(
             x.cuda(), top_ps.cuda(), None, None, seed=0, k=1, mode="truncated"
         )
         print(next_tokens)
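
Reviewer note, not part of the patch: the sketch below shows one way to exercise the renamed kernel ad hoc, outside the unittest harness, and eyeball its output against the paddle.tensor.top_p_sampling reference named in the test docstring. It assumes a FastDeploy GPU build on CUDA 12+; the air_top_p_sampling call signature is copied verbatim from the test above, while the reference call and the variable names are illustrative.

import numpy as np
import paddle

import fastdeploy

# Build a random probability distribution, mirroring the unit test setup.
bsz, vocab_size = 8, 103424
probs = paddle.nn.functional.softmax(paddle.randn([bsz, vocab_size]))
probs = paddle.cast(probs, "float32")
top_ps = paddle.to_tensor(np.random.uniform(0, 1, [bsz]).astype(np.float32))

# Renamed kernel under test; arguments taken verbatim from the patched test.
_, air_tokens = fastdeploy.model_executor.ops.gpu.air_top_p_sampling(
    probs.cuda(), top_ps.cuda(), None, None, seed=0, k=1, mode="truncated"
)

# Paddle's built-in top-p sampling, the reference the test compares against.
# Sampling is stochastic, so token ids are inspected by eye, not asserted equal.
_, ref_tokens = paddle.tensor.top_p_sampling(probs.cuda(), top_ps.cuda())

print("air_top_p_sampling:", air_tokens.numpy().flatten())
print("paddle reference:  ", ref_tokens.numpy().flatten())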