[Feature] support min_p_sampling (#2872)

* FastDeploy: support min_p

* add test_min_p

* fix

* min_p_sampling

* update

* delete vl_gpu_model_runner.py

* fix

* Align usage of min_p with vLLM

* fix

* modified unit test

* fix test_min_sampling

* pre-commit all files

* fix

* fix

* fix

* fix xpu_model_runner.py
Authored by lizexu123 on 2025-07-21 14:17:59 +08:00, committed by GitHub
parent 95a214ae43
commit 67990e0572
15 changed files with 302 additions and 1 deletion


@@ -0,0 +1,65 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/phi/backends/context_pool.h"
#include "sample_kernels/sampling.cuh"
std::vector<paddle::Tensor> MinPSamplingFromProbs(const paddle::Tensor &probs,
const paddle::Tensor &min_p) {
std::vector<int64_t> probs_shape = probs.shape();
unsigned int batch_size = probs_shape[0];
unsigned int vocab_size = probs_shape[1];
auto cu_stream = probs.stream();
auto renorm_probs =
GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place());
cudaError_t status;
status = sampling::MinPSamplingFromProb<float, int>(
const_cast<float *>(probs.data<float>()),
const_cast<float *>(min_p.data<float>()),
renorm_probs.data<float>(),
batch_size,
vocab_size,
true, // deterministic
cu_stream);
PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));
return {renorm_probs};
}
std::vector<std::vector<int64_t>>
MinPSamplingFromProbsInferShape(const std::vector<int64_t> &probs_shape,
const paddle::optional<std::vector<int64_t>> &min_p_shape) {
return {probs_shape};
}
std::vector<paddle::DataType>
MinPSamplingFromProbsInferDtype(const paddle::DataType &probs_dtype,
const paddle::optional<paddle::DataType> &min_p_dtype) {
return {probs_dtype};
}
PD_BUILD_STATIC_OP(min_p_sampling)
.Inputs({"probs", "min_p"})
.Outputs({"renorm_probs"})
.SetKernelFn(PD_KERNEL(MinPSamplingFromProbs))
.SetInferShapeFn(PD_INFER_SHAPE(MinPSamplingFromProbsInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MinPSamplingFromProbsInferDtype));
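
For orientation, the rule this op implements is min-p filtering: every token whose probability falls below min_p times the row's maximum probability is zeroed out. Despite the output name renorm_probs, the kernel below only masks; any renormalization is presumably handled downstream in the sampling pipeline. A minimal host-side C++ sketch of the same rule (illustrative only, not part of this commit; the function name is hypothetical):

#include <algorithm>
#include <vector>

// Zero out every entry below min_p * max_prob, mirroring the CUDA kernel's
// pivot rule. Assumes probs is non-empty. Returns unnormalized masked probs,
// as the GPU op does.
std::vector<float> MinPFilterReference(const std::vector<float> &probs,
                                       float min_p) {
  const float max_prob = *std::max_element(probs.begin(), probs.end());
  const float pivot = min_p * max_prob;
  std::vector<float> out(probs.size());
  for (size_t i = 0; i < probs.size(); ++i) {
    out[i] = (probs[i] >= pivot) ? probs[i] : 0.0f;
  }
  return out;
}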


@@ -276,6 +276,9 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
  aggregate += aggregate_local;
}

template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
          bool DETERMINISTIC, typename DType, typename IdType>

@@ -391,6 +394,8 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
  }
}

template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
          bool DETERMINISTIC, typename DType, typename IdType>
@@ -553,6 +558,47 @@ struct RenormTempStorage {
  };
};

template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
          bool DETERMINISTIC, typename DType, typename IdType>
__global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
                                           DType* renormed_prob, uint32_t d) {
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
  // Per-row min-p value; a null pointer or 0 disables filtering for the row.
  float p = (min_p_arr == nullptr) ? 0 : min_p_arr[bx];
  const uint32_t row_idx = bx;
  extern __shared__ __align__(alignof(
      SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
      uint8_t smem_sampling[];
  auto& temp_storage = reinterpret_cast<
      SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
      smem_sampling);

  // Block-wide reduction: maximum probability in this row.
  float max_val = GetMaxValue<
      VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
      SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
      probs, row_idx, d, temp_storage);
  // min-p pivot: tokens with prob below min_p * max_prob are zeroed.
  float pivot = max_val * p;

  vec_t<float, VEC_SIZE> probs_vec;
#pragma unroll 2
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    probs_vec.fill(0);
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      probs_vec.cast_load(probs + row_idx * d +
                          (i * BLOCK_THREADS + tx) * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      probs_vec[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0;
    }
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      probs_vec.store(renormed_prob + row_idx * d +
                      (i * BLOCK_THREADS + tx) * VEC_SIZE);
    }
  }
}
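
A hand-worked instance of the pivot rule above (illustrative numbers, not from the commit): with min_p = 0.1 and a row of probabilities {0.6, 0.25, 0.10, 0.05}, the block reduction finds max_val = 0.6, so pivot = 0.1 * 0.6 = 0.06; the entries 0.6, 0.25, and 0.10 pass the >= pivot test unchanged, while 0.05 is written out as 0.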
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
          uint32_t VEC_SIZE, typename DType, typename IdType>
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob,
                                     IdType* top_k_arr, uint32_t d) {

@@ -705,6 +751,33 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
  return cudaSuccess;
}
template <typename T, typename IdType>
cudaError_t MinPSamplingFromProb(T *probs, const T *min_p_arr, T *renormed_prob,
                                 uint32_t batch_size, uint32_t d,
                                 bool deterministic, cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  // Widest vector load (up to 16 bytes) that still divides the vocab size d.
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
  const uint32_t smem_size =
      sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  dim3 nblks(batch_size);  // one block per batch row
  dim3 nthrs(BLOCK_THREADS);
  void *args[] = {&probs, &min_p_arr, &renormed_prob, &d};
  DISPATCH_ALIGNED_VEC_SIZE(
      vec_size, VEC_SIZE,
      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
        auto kernel = MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO,
                                                 REDUCE_ALGO, VEC_SIZE,
                                                 DETERMINISTIC, T, IdType>;
        CUDA_CALL(cudaFuncSetAttribute(
            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
                                   smem_size, stream));
      })});
  return cudaSuccess;
}
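
On the launch configuration (illustrative numbers, not from the commit): one 1024-thread block handles one batch row, and the vector width is the largest factor of d that fits a 16-byte load. For T = float, 16 / sizeof(T) = 4, so a vocab size of d = 32000 gives vec_size = gcd(4, 32000) = 4 (4-wide vectorized loads), while an odd vocab size such as d = 32001 degrades to vec_size = gcd(4, 32001) = 1 (scalar loads).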
template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
                                     uint32_t batch_size, const T *top_p_val,
                                     const IdType *top_k_val,