From 7c5e34e72d6511024b7be57e6abd728e8834d491 Mon Sep 17 00:00:00 2001 From: Sunny-bot1 <68891411+Sunny-bot1@users.noreply.github.com> Date: Tue, 22 Jul 2025 20:53:37 +0800 Subject: [PATCH] [FIX] Fix rejection sampling when top_p=0 using _SAMPLING_EPS (#2967) * fix rejection sampling when top_p=0 * fix --- custom_ops/gpu_ops/sample_kernels/sampling.cuh | 2 +- fastdeploy/input/ernie_processor.py | 4 ++++ fastdeploy/input/text_processor.py | 4 ++++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/custom_ops/gpu_ops/sample_kernels/sampling.cuh b/custom_ops/gpu_ops/sample_kernels/sampling.cuh index f14694fa1..e8c70398f 100644 --- a/custom_ops/gpu_ops/sample_kernels/sampling.cuh +++ b/custom_ops/gpu_ops/sample_kernels/sampling.cuh @@ -292,7 +292,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, curand_init(philox_seed, bx, philox_offset, &state); const uint32_t row_idx = bx; const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx]; - const float p = top_p_arr[row_idx] == 0 ? 
1e-6 : top_p_arr[row_idx]; + const float p = top_p_arr[row_idx]; extern __shared__ __align__( alignof(SamplingTempStorage)) diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py index 1ccf3e13f..a56c7f9fb 100644 --- a/fastdeploy/input/ernie_processor.py +++ b/fastdeploy/input/ernie_processor.py @@ -123,6 +123,8 @@ class ErnieProcessor(BaseDataProcessor): if request.get("temperature") < _SAMPLING_EPS: # zero temperature is equivalent to greedy sampling request.set("temperature", 1) + if request.get("top_p") < _SAMPLING_EPS: + request.set("top_p", _SAMPLING_EPS) data_processor_logger.info(f"Processed request {request}") return request @@ -174,6 +176,8 @@ class ErnieProcessor(BaseDataProcessor): if request.get("temperature") < _SAMPLING_EPS: # zero temperature is equivalent to greedy sampling request["temperature"] = 1 + if request.get("top_p") < _SAMPLING_EPS: + request["top_p"] = _SAMPLING_EPS data_processor_logger.info(f"Processed request {request}") return request diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py index d4d70bbc3..a9f8c2c49 100644 --- a/fastdeploy/input/text_processor.py +++ b/fastdeploy/input/text_processor.py @@ -252,6 +252,8 @@ class DataProcessor(BaseDataProcessor): if request.get("temperature") < _SAMPLING_EPS: # zero temperature is equivalent to greedy sampling request.set("temperature", 1) + if request.get("top_p") < _SAMPLING_EPS: + request.set("top_p", _SAMPLING_EPS) data_processor_logger.info(f"Processed request {request}") return request @@ -297,6 +299,8 @@ class DataProcessor(BaseDataProcessor): if request.get("temperature") < _SAMPLING_EPS: # zero temperature is equivalent to greedy sampling request["temperature"] = 1 + if request.get("top_p") < _SAMPLING_EPS: + request["top_p"] = _SAMPLING_EPS data_processor_logger.info(f"Processed request {request}") return request