Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-04 16:22:57 +08:00
[Feature] support min_p_sampling (#2872)
* Fastdeploy support min_p
* add test_min_p
* fix
* min_p_sampling
* update
* delete vl_gpu_model_runner.py
* fix
* Align usage of min_p with vLLM
* fix
* modified unit test
* fix test_min_sampling
* pre-commit all files
* fix
* fix
* fix
* fix xpu_model_runner.py
@@ -0,0 +1,65 @@ (new file; presumably custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu, added to the build in the setup hunk below)
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/phi/backends/context_pool.h"
#include "sample_kernels/sampling.cuh"

std::vector<paddle::Tensor> MinPSamplingFromProbs(const paddle::Tensor &probs,
                                                  const paddle::Tensor &min_p) {
  std::vector<int64_t> probs_shape = probs.shape();
  unsigned int batch_size = probs_shape[0];
  unsigned int vocab_size = probs_shape[1];
  auto cu_stream = probs.stream();

  auto renorm_probs =
      GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place());

  cudaError_t status;

  status = sampling::MinPSamplingFromProb<float, int>(
      const_cast<float *>(probs.data<float>()),
      const_cast<float *>(min_p.data<float>()),
      renorm_probs.data<float>(),
      batch_size,
      vocab_size,
      true,  // deterministic
      cu_stream);

  PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
                                      std::string(cudaGetErrorString(status)));

  return {renorm_probs};
}

std::vector<std::vector<int64_t>>
MinPSamplingFromProbsInferShape(const std::vector<int64_t> &probs_shape,
                                const paddle::optional<std::vector<int64_t>> &min_p_shape) {
  return {probs_shape};
}

std::vector<paddle::DataType>
MinPSamplingFromProbsInferDtype(const paddle::DataType &probs_dtype,
                                const paddle::optional<paddle::DataType> &min_p_dtype) {
  return {probs_dtype};
}

PD_BUILD_STATIC_OP(min_p_sampling)
    .Inputs({"probs", "min_p"})
    .Outputs({"renorm_probs"})
    .SetKernelFn(PD_KERNEL(MinPSamplingFromProbs))
    .SetInferShapeFn(PD_INFER_SHAPE(MinPSamplingFromProbsInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MinPSamplingFromProbsInferDtype));
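For reference, the op registered above is exposed to Python as `min_p_sampling`; the unit test added later in this PR drives it roughly like this (a minimal sketch, shapes illustrative, CUDA build required):

import paddle
import paddle.nn.functional as F

from fastdeploy.model_executor.ops.gpu import min_p_sampling  # available after building the custom ops

probs = F.softmax(paddle.rand([2, 1000]), axis=-1)      # [batch_size, vocab_size]
min_p = paddle.to_tensor([0.1, 0.5], dtype="float32")   # one threshold per request
renorm_probs = min_p_sampling(probs, min_p)             # entries below max_prob * min_p are zeroed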
@@ -276,6 +276,9 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
    aggregate += aggregate_local;
  }
}

template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
          typename DType, typename IdType>
@@ -391,6 +394,8 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
  }
}

template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
          bool DETERMINISTIC, typename DType, typename IdType>
@@ -553,6 +558,47 @@ struct RenormTempStorage {
  };
};

template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
          typename DType, typename IdType>
__global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
                                           DType* renormed_prob, uint32_t d) {
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
  float p = (min_p_arr == nullptr) ? 0 : min_p_arr[bx];
  const uint32_t row_idx = bx;

  extern __shared__ __align__(
      alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
      uint8_t smem_sampling[];
  auto& temp_storage =
      reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
          smem_sampling);

  float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
                              SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
      probs, row_idx, d, temp_storage);
  float pivot = max_val * p;

  vec_t<float, VEC_SIZE> probs_vec;
#pragma unroll 2
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    probs_vec.fill(0);
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
    }

#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      probs_vec[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0;
    }
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
    }
  }
}
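Per row, the kernel reduces to `pivot = max(probs) * min_p` and zeroes every probability below that pivot; a NumPy sketch of the same filtering (illustrative reference only, the kernel does this with vectorized loads across the thread block):

import numpy as np

def min_p_filter(probs: np.ndarray, min_p: np.ndarray) -> np.ndarray:
    # probs: [batch, vocab], min_p: [batch]; keep entries >= max_prob * min_p, zero the rest
    pivot = probs.max(axis=-1, keepdims=True) * min_p[:, None]
    return np.where(probs >= pivot, probs, 0.0)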

template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
          typename DType, typename IdType>
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
@@ -705,6 +751,33 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
  return cudaSuccess;
}

template <typename T, typename IdType>
cudaError_t MinPSamplingFromProb(T *probs, const T *min_p_arr, T *renormed_prob,
                                 uint32_t batch_size,
                                 uint32_t d, bool deterministic,
                                 cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
  void* args[] = {&probs, &min_p_arr, &renormed_prob, &d};
  DISPATCH_ALIGNED_VEC_SIZE(
      vec_size, VEC_SIZE,
      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
        auto kernel =
            MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
                                       VEC_SIZE, DETERMINISTIC, T, IdType>;
        CUDA_CALL(cudaFuncSetAttribute(
            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
                                   smem_size, stream));
      })});
  return cudaSuccess;
}
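The launch picks the widest aligned vector load that still divides the vocab size: `vec_size = gcd(16 / sizeof(T), d)`. A quick check of the arithmetic with illustrative numbers (float32, vocab size 1000):

from math import gcd

# float32 is 4 bytes, so at most 16 / 4 = 4 elements fit in one 16-byte vector load;
# taking the gcd with the vocab size keeps every load aligned and in bounds.
vec_size = gcd(16 // 4, 1000)   # -> 4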

template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
                                     uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
@@ -287,6 +287,7 @@ elif paddle.is_compiled_with_cuda():
                "gpu_ops/text_image_gather_scatter.cu",
                "gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
                "gpu_ops/sample_kernels/top_k_renorm_probs.cu",
                "gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
                "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
                "gpu_ops/fused_rotary_position_encoding.cu",
                "gpu_ops/noaux_tc.cu",
@@ -180,6 +180,7 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md).
* temperature(float): Controls randomness (higher = more random)
* top_p(float): Probability threshold for token selection
* top_k(int): Number of tokens considered for sampling
* min_p(float): Minimum probability relative to the maximum probability for a token to be considered (>0 filters low-probability tokens to improve quality)
* max_tokens(int): Maximum generated tokens (input + output)
* min_tokens(int): Minimum forced generation length
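As an illustration of the new parameter, a minimal offline-inference sketch (assuming the `LLM`/`SamplingParams` interface this guide refers to; the model path is a placeholder):

from fastdeploy import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, min_p=0.05, max_tokens=128)
llm = LLM(model="path/to/your/model")  # placeholder path
outputs = llm.generate(["Hello, my name is"], sampling_params)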
@@ -180,6 +180,7 @@ for output in outputs:
* temperature(float): Controls generation randomness; higher values give more random results, lower values give more deterministic results
* top_p(float): Cumulative-probability truncation threshold; only the most likely tokens whose cumulative probability reaches this threshold are considered
* top_k(int): Number of highest-probability tokens considered for sampling
* min_p(float): Minimum probability threshold for a token to be considered (as a ratio of the highest-probability token; setting it >0 filters low-probability tokens and improves generation quality)
* max_tokens(int): Maximum number of tokens the model may generate (input + output)
* min_tokens(int): Minimum number of tokens the model is forced to generate, to avoid ending too early
@@ -53,6 +53,9 @@ class SamplingParams:
        top_p: Float that controls the cumulative probability of the top tokens
            to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
        top_k: Int that controls the number of top tokens to consider. Must be a positive integer.
        min_p: Float that represents the minimum probability for a token to be
            considered, relative to the probability of the most likely token.
            Must be in [0, 1]. Set to 0 to disable this.
        seed: Random seed to use for the generation.
        stop: list of strings that stop the generation when they are generated.
            The returned output will not contain the stop strings.
@@ -84,6 +87,7 @@ class SamplingParams:
    temperature: float = None
    top_p: float = None
    top_k: int = 0
    min_p: float = 0.0
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
@@ -114,6 +118,7 @@ class SamplingParams:
        temperature,
        top_p,
        top_k,
        min_p,
        seed=None,
        stop=None,
        stop_token_ids=None,
@@ -133,6 +138,7 @@ class SamplingParams:
            temperature=temperature if temperature is not None else 1.0,
            top_p=top_p,
            top_k=top_k if top_k is not None else 0,
            min_p=min_p if min_p is not None else 0.0,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
@@ -170,6 +176,8 @@ class SamplingParams:
            raise ValueError(f"top_k must be 0 (disable), or at least 1, " f"got {self.top_k}.")
        if not isinstance(self.top_k, int):
            raise TypeError(f"top_k must be an integer, got {type(self.top_k).__name__}")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")

        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
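With these checks in place, construction behaves roughly as follows (a sketch; the import path is assumed, and it assumes the checks run when the object is created):

from fastdeploy.engine.sampling_params import SamplingParams  # import path assumed

params = SamplingParams(temperature=0.8, top_p=0.95, min_p=0.05)  # accepted
bad = SamplingParams(min_p=1.5)  # rejected by the check above: min_p must be in [0, 1]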
@@ -339,6 +339,7 @@ class CompletionRequest(BaseModel):
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    min_p: Optional[float] = None
    user: Optional[str] = None

    response_format: Optional[AnyResponseFormat] = None
@@ -460,6 +461,7 @@ class ChatCompletionRequest(BaseModel):
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    min_p: Optional[float] = None
    user: Optional[str] = None
    metadata: Optional[dict] = None
@@ -42,6 +42,7 @@ class SamplingMetadata:

    top_p: paddle.Tensor
    top_k: Optional[paddle.Tensor] = None
    min_p: Optional[paddle.Tensor] = None
    max_num_logprobs: Optional[int] = None
    prompt_ids: Optional[paddle.Tensor] = None
    prompt_lens: Optional[paddle.Tensor] = None
@@ -18,10 +18,11 @@ from .apply_penalty_multi_scores import (
    apply_penalty_multi_scores,
    apply_speculative_penalty_multi_scores,
)
from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling

__all__ = [
    "apply_penalty_multi_scores",
    "apply_speculative_penalty_multi_scores",
    "top_k_top_p_sampling",
    "min_p_sampling",
]
@@ -60,6 +60,7 @@ def top_k_top_p_sampling(

    """
    top_p_class = envs.FD_SAMPLING_CLASS.lower()

    if top_p_class == "air":
        _, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode)
    elif top_p_class == "rejection":
@@ -154,3 +155,25 @@ def rejection_top_p_sampling(
    except ImportError:
        raise RuntimeError("Cannot import rejection_top_p_sampling op.")
    return ids


def min_p_sampling(
    probs: paddle.Tensor,
    min_p_arr: Optional[paddle.Tensor],
) -> paddle.Tensor:
    """
    min_p_sampling: zero out tokens whose probability falls below max_prob * min_p.
    """
    if paddle.count_nonzero(min_p_arr) == 0:
        return probs
    else:
        if current_platform.is_cuda():
            from fastdeploy.model_executor.ops.gpu import min_p_sampling

            probs = min_p_sampling(probs, min_p_arr)
        else:
            max_probabilities = paddle.amax(probs, axis=-1, keepdim=True)
            adjusted_min_p = max_probabilities * min_p_arr
            invalid_token_mask = probs < adjusted_min_p.reshape([-1, 1])
            probs = paddle.where(invalid_token_mask, paddle.full_like(probs, 0.0), probs)
        return probs
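A worked example of this wrapper's behaviour (a sketch; the module path is the one exported from `fastdeploy.model_executor.layers.sample.ops` above):

import paddle
from fastdeploy.model_executor.layers.sample.ops import min_p_sampling

probs = paddle.to_tensor([[0.6, 0.3, 0.08, 0.02]])
min_p = paddle.to_tensor([0.2])          # pivot = 0.2 * 0.6 = 0.12
filtered = min_p_sampling(probs, min_p)  # -> [[0.6, 0.3, 0.0, 0.0]]

# When every request has min_p == 0 the call is a no-op and returns probs unchanged.
zero = paddle.to_tensor([0.0])
assert min_p_sampling(probs, zero) is probs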
@@ -30,6 +30,7 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.ops import (
    apply_penalty_multi_scores,
    apply_speculative_penalty_multi_scores,
    min_p_sampling,
    top_k_top_p_sampling,
)
from fastdeploy.platforms import current_platform
@@ -266,6 +267,8 @@ class Sampler(nn.Layer):

        probs = F.softmax(logits)

        probs = min_p_sampling(probs, sampling_metadata.min_p)

        _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)

        logprobs_tensors = (
@@ -281,6 +284,7 @@ class Sampler(nn.Layer):
            sampled_token_ids=next_tokens,
            logprobs_tensors=logprobs_tensors,
        )

        return sampler_output
@@ -320,6 +320,8 @@ class GPUModelRunner(ModelRunnerBase):
            self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
            self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
            self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
            self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)

            self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95)
            self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request(
                request, "repetition_penalty", 1.0
@@ -430,6 +432,7 @@ class GPUModelRunner(ModelRunnerBase):
        self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
        self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
        self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
        self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
        self.share_inputs["temperature"] = paddle.full(
            [max_num_seqs, 1], self.model_config.temperature, dtype="float32"
        )
@@ -626,6 +629,7 @@ class GPUModelRunner(ModelRunnerBase):
            temperature=self.share_inputs["temperature"],
            top_p=self.share_inputs["top_p"],
            top_k=self.share_inputs["top_k"],
            min_p=self.share_inputs["min_p"],
            step_idx=self.share_inputs["step_idx"],
            pre_token_ids=self.share_inputs["pre_ids"],
            prompt_ids=self.share_inputs["prompt_ids"],
@@ -304,6 +304,7 @@ class XPUModelRunner(ModelRunnerBase):
            self.share_inputs["pre_ids"][idx : idx + 1] = -1
            self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
            self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
            self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
            self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
            self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
            self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
@@ -363,6 +364,7 @@ class XPUModelRunner(ModelRunnerBase):
        self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
        self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
        self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
        self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
        self.share_inputs["temperature"] = paddle.full(
            [max_num_seqs, 1], self.model_config.temperature, dtype="float32"
        )
@@ -473,6 +475,7 @@ class XPUModelRunner(ModelRunnerBase):
            temperature=self.share_inputs["temperature"],
            top_p=self.share_inputs["top_p"],
            top_k=self.share_inputs["top_k"],
            min_p=self.share_inputs["min_p"],
            step_idx=self.share_inputs["step_idx"],
            pre_token_ids=self.share_inputs["pre_ids"],
            frequency_penalties=self.share_inputs["frequency_score"],
test/layers/test_min_sampling.py (new file, 113 lines)
@@ -0,0 +1,113 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

import numpy as np
import paddle
import paddle.nn.functional as F

from fastdeploy.model_executor.ops.gpu import min_p_sampling


class TestMinPSampling(unittest.TestCase):
    def setUp(self):
        self.sample_time = 1000000
        self.vocab_size = 1000
        self.min_p_value = 0.5
        self.batch_size = 3
        self.batch_min_p_values = [0.1, 0.0, 0.9]
        self.additional_batch_min_p_values = [0.1, 0.0, 0.3]

    # Reference implementation: min-p filtering computed with plain paddle ops.
    def min_p_sampling_cpu(self, min_p):
        logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32")
        logits[0][0] = 10
        logits[0][1] = 8
        low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2)
        logits[0][2:] = low_prob_tensor

        probs = F.softmax(logits)
        max_probabilities = paddle.amax(probs, axis=-1, keepdim=True)
        adjusted_min_p = max_probabilities * min_p.reshape([-1, 1])
        invalid_token_mask = probs < adjusted_min_p
        probs = paddle.where(invalid_token_mask, paddle.full_like(probs, 0.0), probs)
        return probs

    # FastDeploy custom op, single request (min_p = 0.5).
    def fastdeploy_min_p_sampling(self, min_p):
        logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32")
        logits[0][0] = 10
        logits[0][1] = 8
        low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2)
        logits[0][2:] = low_prob_tensor

        probs = F.softmax(logits)
        probs = min_p_sampling(probs, min_p)
        return probs

    # FastDeploy custom op, batched requests (e.g. min_p = [0.1, 0.0, 0.9]).
    def fastdeploy_batch_min_p_sampling(self, batch_size, min_p_values):
        logits = paddle.ones(shape=[batch_size, self.vocab_size], dtype="float32")
        for b in range(batch_size):
            logits[b][0] = 10
            logits[b][1] = 8
            logits[b][2:] = paddle.linspace(2.0, 0.0, self.vocab_size - 2)

        probs = F.softmax(logits, axis=-1)
        min_p_arr = paddle.to_tensor(min_p_values, dtype="float32")

        probs = min_p_sampling(probs, min_p_arr)

        return probs

    def compare_results(self, probs, probs_cpu, atol=1e-6, rtol=1e-6):
        probs_np = probs.numpy()
        probs_cpu_np = probs_cpu.numpy()
        try:
            np.testing.assert_allclose(
                probs_np,
                probs_cpu_np,
                rtol=rtol,
                atol=atol,
            )
            print("The results are the same between fastdeploy_min_p_sampling and min_p_sampling_cpu")
        except AssertionError as e:
            raise AssertionError(
                f"The results are different between fastdeploy_min_p_sampling and min_p_sampling_cpu:\n{str(e)}"
            )

    def test_single_min_p_sampling(self):
        min_p = paddle.to_tensor([self.min_p_value], dtype="float32")
        probs = self.fastdeploy_min_p_sampling(min_p)
        probs_cpu = self.min_p_sampling_cpu(min_p)
        self.compare_results(probs, probs_cpu)

    def test_batch_min_p_sampling(self):
        batch_min_p = paddle.to_tensor(self.batch_min_p_values, dtype="float32")
        batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, batch_min_p)
        batch_probs_cpu = self.min_p_sampling_cpu(batch_min_p)
        self.compare_results(batch_probs, batch_probs_cpu)

    def test_additional_batch_min_p_sampling(self):
        additional_batch_min_p = paddle.to_tensor(self.additional_batch_min_p_values, dtype="float32")
        additional_batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, additional_batch_min_p)
        additional_batch_probs_cpu = self.min_p_sampling_cpu(additional_batch_min_p)
        self.compare_results(additional_batch_probs, additional_batch_probs_cpu)


if __name__ == "__main__":
    if paddle.is_compiled_with_cuda():
        unittest.main()
@@ -73,5 +73,6 @@ def test_sampler():
    print(next_tokens)


if __name__ == "__main__":
    test_sampler()