[Feature] support min_p_sampling (#2872)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

* Fastdeploy support min_p

* add test_min_p

* fix

* min_p_sampling

* update

* delete vl_gpu_model_runner.py

* fix

* Align usage of min_p with vLLM

* fix

* modified unit test

* fix test_min_sampling

* pre-commit all files

* fix

* fix

* fix

* fix xpu_model_runner.py
This commit is contained in:
lizexu123
2025-07-21 14:17:59 +08:00
committed by GitHub
parent 95a214ae43
commit 67990e0572
15 changed files with 302 additions and 1 deletions

View File

@@ -0,0 +1,65 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/phi/backends/context_pool.h"
#include "sample_kernels/sampling.cuh"
std::vector<paddle::Tensor> MinPSamplingFromProbs(const paddle::Tensor &probs,
const paddle::Tensor &min_p) {
std::vector<int64_t> probs_shape = probs.shape();
unsigned int batch_size = probs_shape[0];
unsigned int vocab_size = probs_shape[1];
auto cu_stream = probs.stream();
auto renorm_probs =
GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place());
cudaError_t status;
status = sampling::MinPSamplingFromProb<float, int>(
const_cast<float *>(probs.data<float>()),
const_cast<float *>(min_p.data<float>()),
renorm_probs.data<float>(),
batch_size,
vocab_size,
true, // deterministic
cu_stream);
PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));
return {renorm_probs};
}
std::vector<std::vector<int64_t>>
MinPSamplingFromProbsInferShape(const std::vector<int64_t> &probs_shape,
const paddle::optional<std::vector<int64_t>> &min_p_shape) {
return {probs_shape};
}
std::vector<paddle::DataType>
MinPSamplingFromProbsInferDtype(const paddle::DataType &probs_dtype,
const paddle::optional<paddle::DataType> &min_p_dtype) {
return {probs_dtype};
}
PD_BUILD_STATIC_OP(min_p_sampling)
.Inputs({"probs", "min_p"})
.Outputs({"renorm_probs"})
.SetKernelFn(PD_KERNEL(MinPSamplingFromProbs))
.SetInferShapeFn(PD_INFER_SHAPE(MinPSamplingFromProbsInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MinPSamplingFromProbsInferDtype));

View File

@@ -276,6 +276,9 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
aggregate += aggregate_local;
}
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
typename DType, typename IdType>
@@ -391,6 +394,8 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
}
}
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
bool DETERMINISTIC, typename DType, typename IdType>
@@ -553,6 +558,47 @@ struct RenormTempStorage {
};
};
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
typename DType,typename IdType>
__global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
DType* renormed_prob,uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
float p = (min_p_arr == nullptr) ? 0 : min_p_arr[bx];
const uint32_t row_idx = bx;
extern __shared__ __align__(
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage =
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);
float pivot = max_val * p;
vec_t<float, VEC_SIZE> probs_vec;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_vec[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0;
}
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
}
}
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
typename DType, typename IdType>
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
@@ -705,6 +751,33 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
return cudaSuccess;
}
template <typename T,typename IdType>
cudaError_t MinPSamplingFromProb(T *probs, const T* min_p_arr,T *renormed_prob,
uint32_t batch_size,
uint32_t d, bool deterministic,
cudaStream_t stream = 0){
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &min_p_arr,&renormed_prob,&d};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE,
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel =
MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
VEC_SIZE, DETERMINISTIC, T,IdType>;
CUDA_CALL(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
smem_size, stream));
})});
return cudaSuccess;
}
template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,

View File

@@ -287,6 +287,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
"gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu",

View File

@@ -180,6 +180,7 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md).
* temperature(float): Controls randomness (higher = more random)
* top_p(float): Probability threshold for token selection
* top_k(int): Number of tokens considered for sampling
* min_p(float): Minimum probability relative to the maximum probability for a token to be considered (>0 filters low-probability tokens to improve quality)
* max_tokens(int): Maximum generated tokens (input + output)
* min_tokens(int): Minimum forced generation length

View File

@@ -180,6 +180,7 @@ for output in outputs:
* temperature(float): 控制生成随机性的参数,值越高结果越随机,值越低结果越确定
* top_p(float): 概率累积分布截断阈值仅考虑累计概率达到此阈值的最可能token集合
* top_k(int): 采样概率最高的token数量考虑概率最高的k个token进行采样
* min_p(float): token入选的最小概率阈值(相对于最高概率token的比值设为>0可通过过滤低概率token来提升文本生成质量)
* max_tokens(int): 限制模型生成的最大token数量包括输入和输出
* min_tokens(int): 强制模型生成的最少token数量避免过早结束

View File

@@ -53,6 +53,9 @@ class SamplingParams:
top_p: Float that controls the cumulative probability of the top tokens
to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
top_k: Int that controls the number of top tokens to consider. Must be a positive integer.
min_p: Float that represents the minimum probability for a token to be
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation.
stop: list of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
@@ -84,6 +87,7 @@ class SamplingParams:
temperature: float = None
top_p: float = None
top_k: int = 0
min_p: float = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
@@ -114,6 +118,7 @@ class SamplingParams:
temperature,
top_p,
top_k,
min_p,
seed=None,
stop=None,
stop_token_ids=None,
@@ -133,6 +138,7 @@ class SamplingParams:
temperature=temperature if temperature is not None else 1.0,
top_p=top_p,
top_k=top_k if top_k is not None else 0,
min_p=min_p if min_p is not None else 0.0,
seed=seed,
stop=stop,
stop_token_ids=stop_token_ids,
@@ -170,6 +176,8 @@ class SamplingParams:
raise ValueError(f"top_k must be 0 (disable), or at least 1, " f"got {self.top_k}.")
if not isinstance(self.top_k, int):
raise TypeError(f"top_k must be an integer, got {type(self.top_k).__name__}")
if not 0.0 <= self.min_p <= 1.0:
raise ValueError("min_p must be in [0,1],got f{self.min_p}")
if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")

View File

@@ -339,6 +339,7 @@ class CompletionRequest(BaseModel):
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
min_p: Optional[float] = None
user: Optional[str] = None
response_format: Optional[AnyResponseFormat] = None
@@ -460,6 +461,7 @@ class ChatCompletionRequest(BaseModel):
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
min_p: Optional[float] = None
user: Optional[str] = None
metadata: Optional[dict] = None

View File

@@ -42,6 +42,7 @@ class SamplingMetadata:
top_p: paddle.Tensor
top_k: Optional[paddle.Tensor] = None
min_p: Optional[paddle.Tensor] = None
max_num_logprobs: Optional[int] = None
prompt_ids: Optional[paddle.Tensor] = None
prompt_lens: Optional[paddle.Tensor] = None

View File

@@ -18,10 +18,11 @@ from .apply_penalty_multi_scores import (
apply_penalty_multi_scores,
apply_speculative_penalty_multi_scores,
)
from .top_k_top_p_sampling import top_k_top_p_sampling
from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling
__all__ = [
"apply_penalty_multi_scores",
"apply_speculative_penalty_multi_scores",
"top_k_top_p_sampling",
"min_p_sampling",
]

View File

@@ -60,6 +60,7 @@ def top_k_top_p_sampling(
"""
top_p_class = envs.FD_SAMPLING_CLASS.lower()
if top_p_class == "air":
_, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode)
elif top_p_class == "rejection":
@@ -154,3 +155,25 @@ def rejection_top_p_sampling(
except ImportError:
raise RuntimeError("Cannot import rejection_top_p_sampling op.")
return ids
def min_p_sampling(
probs: paddle.tensor,
min_p_arr: Optional[paddle.Tensor],
) -> tuple[paddle.Tensor, paddle.Tensor]:
"""
min_p_sampling
"""
if paddle.count_nonzero(min_p_arr) == 0:
return probs
else:
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import min_p_sampling
probs = min_p_sampling(probs, min_p_arr)
else:
max_probabilities = paddle.amax(probs, axis=-1, keepdim=True)
adjusted_min_p = max_probabilities * min_p_arr
invalid_token_mask = probs < adjusted_min_p.reshape([-1, 1])
probs = paddle.where(invalid_token_mask, paddle.full_like(probs, 0.0), probs)
return probs

View File

@@ -30,6 +30,7 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores,
apply_speculative_penalty_multi_scores,
min_p_sampling,
top_k_top_p_sampling,
)
from fastdeploy.platforms import current_platform
@@ -266,6 +267,8 @@ class Sampler(nn.Layer):
probs = F.softmax(logits)
probs = min_p_sampling(probs, sampling_metadata.min_p)
_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
logprobs_tensors = (
@@ -281,6 +284,7 @@ class Sampler(nn.Layer):
sampled_token_ids=next_tokens,
logprobs_tensors=logprobs_tensors,
)
return sampler_output

View File

@@ -320,6 +320,8 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95)
self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request(
request, "repetition_penalty", 1.0
@@ -430,6 +432,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
self.share_inputs["temperature"] = paddle.full(
[max_num_seqs, 1], self.model_config.temperature, dtype="float32"
)
@@ -626,6 +629,7 @@ class GPUModelRunner(ModelRunnerBase):
temperature=self.share_inputs["temperature"],
top_p=self.share_inputs["top_p"],
top_k=self.share_inputs["top_k"],
min_p=self.share_inputs["min_p"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
prompt_ids=self.share_inputs["prompt_ids"],

View File

@@ -304,6 +304,7 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["pre_ids"][idx : idx + 1] = -1
self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
@@ -363,6 +364,7 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
self.share_inputs["temperature"] = paddle.full(
[max_num_seqs, 1], self.model_config.temperature, dtype="float32"
)
@@ -473,6 +475,7 @@ class XPUModelRunner(ModelRunnerBase):
temperature=self.share_inputs["temperature"],
top_p=self.share_inputs["top_p"],
top_k=self.share_inputs["top_k"],
min_p=self.share_inputs["min_p"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
frequency_penalties=self.share_inputs["frequency_score"],

View File

@@ -0,0 +1,113 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
from fastdeploy.model_executor.ops.gpu import min_p_sampling
class TestMinPSampling(unittest.TestCase):
def setUp(self):
self.sample_time = 1000000
self.vocab_size = 1000
self.min_p_value = 0.5
self.batch_size = 3
self.batch_min_p_values = [0.1, 0.0, 0.9]
self.additional_batch_min_p_values = [0.1, 0.0, 0.3]
# min_p:0.5FastDeploy
def min_p_sampling_cpu(self, min_p):
logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32")
logits[0][0] = 10
logits[0][1] = 8
low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2)
logits[0][2:] = low_prob_tensor
probs = F.softmax(logits)
max_probabilities = paddle.amax(probs, axis=-1, keepdim=True)
adjusted_min_p = max_probabilities * min_p.reshape([-1, 1])
invalid_token_mask = probs < adjusted_min_p
probs = paddle.where(invalid_token_mask, paddle.full_like(probs, 0.0), probs)
return probs
# min_p:0.5FastDeploy
def fastdeploy_min_p_sampling(self, min_p):
logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32")
logits[0][0] = 10
logits[0][1] = 8
low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2)
logits[0][2:] = low_prob_tensor
probs = F.softmax(logits)
probs = min_p_sampling(probs, min_p)
return probs
# batch:[0.1.0.0,0.9]FastDeploy
def fastdeploy_batch_min_p_sampling(self, batch_size, min_p_values):
logits = paddle.ones(shape=[batch_size, self.vocab_size], dtype="float32")
for b in range(batch_size):
logits[b][0] = 10
logits[b][1] = 8
logits[b][2:] = paddle.linspace(2.0, 0.0, self.vocab_size - 2)
probs = F.softmax(logits, axis=-1)
min_p_arr = paddle.to_tensor(min_p_values, dtype="float32")
probs = min_p_sampling(probs, min_p_arr)
return probs
def compare_results(self, probs, probs_cpu, atol=1e-6, rtol=1e-6):
probs_np = probs.numpy()
probs_cpu_np = probs_cpu.numpy()
try:
np.testing.assert_allclose(
probs_np,
probs_cpu_np,
rtol=rtol,
atol=atol,
)
print("The results are same between fastdeploy_min_p_sampling and min_p_sampling_cpu")
except AssertionError as e:
raise AssertionError(
f"The results are different between fastdeploy_min_p_sampling and min_p_sampling_cpu:\n{str(e)}"
)
def test_single_min_p_sampling(self):
min_p = paddle.to_tensor([self.min_p_value], dtype="float32")
probs = self.fastdeploy_min_p_sampling(min_p)
probs_cpu = self.min_p_sampling_cpu(min_p)
self.compare_results(probs, probs_cpu)
def test_batch_min_p_sampling(self):
batch_min_p = paddle.to_tensor(self.batch_min_p_values, dtype="float32")
batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, batch_min_p)
batch_probs_cpu = self.min_p_sampling_cpu(batch_min_p)
self.compare_results(batch_probs, batch_probs_cpu)
def test_additional_batch_min_p_sampling(self):
additional_batch_min_p = paddle.to_tensor(self.additional_batch_min_p_values, dtype="float32")
additional_batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, additional_batch_min_p)
additional_batch_probs_cpu = self.min_p_sampling_cpu(additional_batch_min_p)
self.compare_results(additional_batch_probs, additional_batch_probs_cpu)
if __name__ == "__main__":
if paddle.is_compiled_with_cuda():
unittest.main()

View File

@@ -73,5 +73,6 @@ def test_sampler():
print(next_tokens)
if __name__ == "__main__":
test_sampler()