mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-29 22:02:30 +08:00
[Feature] support top_k_top_p sampling (#2753)
* support top_k_top_p sampling * fix * add api param * add api para * fix * fix * fix * fix * fix * fix * fix
This commit is contained in:
@@ -18,6 +18,7 @@
|
|||||||
|
|
||||||
std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
|
std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
|
||||||
const paddle::Tensor &top_p,
|
const paddle::Tensor &top_p,
|
||||||
|
const paddle::optional<paddle::Tensor> &top_k,
|
||||||
int seed) {
|
int seed) {
|
||||||
std::vector<int64_t> probs_shape = probs.shape();
|
std::vector<int64_t> probs_shape = probs.shape();
|
||||||
unsigned int batch_size = probs_shape[0];
|
unsigned int batch_size = probs_shape[0];
|
||||||
@@ -40,10 +41,18 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
|
|||||||
|
|
||||||
cudaError_t status;
|
cudaError_t status;
|
||||||
|
|
||||||
status = sampling::TopKTopPSamplingFromProb<float, int64_t>(
|
if (top_k) {
|
||||||
const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
|
status = sampling::TopKTopPSamplingFromProb<float, int64_t>(
|
||||||
batch_size, top_p.data<float>(), vocab_size,
|
const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
|
||||||
true, philox_seed, philox_offset, cu_stream);
|
batch_size, top_p.data<float>(), top_k.get().data<int64_t>(), vocab_size,
|
||||||
|
true, philox_seed, philox_offset, cu_stream);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
status = sampling::TopPSamplingFromProb<float, int64_t>(
|
||||||
|
const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
|
||||||
|
batch_size, top_p.data<float>(), vocab_size,
|
||||||
|
true, philox_seed, philox_offset, cu_stream);
|
||||||
|
}
|
||||||
|
|
||||||
PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
|
PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
|
||||||
std::string(cudaGetErrorString(status)));
|
std::string(cudaGetErrorString(status)));
|
||||||
@@ -53,19 +62,21 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
|
|||||||
|
|
||||||
std::vector<std::vector<int64_t>>
|
std::vector<std::vector<int64_t>>
|
||||||
TopPSamplingRejectInferShape(const std::vector<int64_t> &probs_shape,
|
TopPSamplingRejectInferShape(const std::vector<int64_t> &probs_shape,
|
||||||
const std::vector<int64_t> &top_p_shape) {
|
const std::vector<int64_t> &top_p_shape,
|
||||||
|
const paddle::optional<std::vector<int64_t>> &top_k_shape) {
|
||||||
int64_t bs = probs_shape[0];
|
int64_t bs = probs_shape[0];
|
||||||
return {{bs, 1}};
|
return {{bs, 1}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType>
|
std::vector<paddle::DataType>
|
||||||
TopPSamplingRejectInferDtype(const paddle::DataType &probs_dtype,
|
TopPSamplingRejectInferDtype(const paddle::DataType &probs_dtype,
|
||||||
const paddle::DataType &top_p_shape) {
|
const paddle::DataType &top_p_dtype,
|
||||||
|
const paddle::optional<paddle::DataType> &top_k_dtype) {
|
||||||
return {paddle::DataType::INT64};
|
return {paddle::DataType::INT64};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(rejection_top_p_sampling)
|
PD_BUILD_STATIC_OP(rejection_top_p_sampling)
|
||||||
.Inputs({"probs", "top_p"})
|
.Inputs({"probs", "top_p", paddle::Optional("top_k")})
|
||||||
.Outputs({"samples"})
|
.Outputs({"samples"})
|
||||||
.Attrs({"seed: int"})
|
.Attrs({"seed: int"})
|
||||||
.SetKernelFn(PD_KERNEL(TopPSamplingReject))
|
.SetKernelFn(PD_KERNEL(TopPSamplingReject))
|
||||||
|
@@ -279,7 +279,8 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
|||||||
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
||||||
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
|
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
|
||||||
typename DType, typename IdType>
|
typename DType, typename IdType>
|
||||||
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, float* top_p_arr,
|
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||||
|
float* top_p_arr, IdType* top_k_arr,
|
||||||
uint32_t d, uint64_t philox_seed,
|
uint32_t d, uint64_t philox_seed,
|
||||||
uint64_t philox_offset) {
|
uint64_t philox_offset) {
|
||||||
const uint32_t batch_size = gridDim.x;
|
const uint32_t batch_size = gridDim.x;
|
||||||
@@ -287,7 +288,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, flo
|
|||||||
curandStatePhilox4_32_10_t state;
|
curandStatePhilox4_32_10_t state;
|
||||||
curand_init(philox_seed, bx, philox_offset, &state);
|
curand_init(philox_seed, bx, philox_offset, &state);
|
||||||
const uint32_t row_idx = bx;
|
const uint32_t row_idx = bx;
|
||||||
const uint32_t k = top_p_arr[row_idx] == 0 ? 1 : 20;
|
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
|
||||||
const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx];
|
const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx];
|
||||||
|
|
||||||
extern __shared__ __align__(
|
extern __shared__ __align__(
|
||||||
@@ -479,7 +480,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
if (aggregate_gt_pivot_0 < top_p) {
|
if (aggregate_gt_pivot_0 < top_p) {
|
||||||
// case 1: pivot_0 accepted
|
// case 1: pivot_0 accepted
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (aggregate_gt_pivot_1 < top_p) {
|
if (aggregate_gt_pivot_1 < top_p) {
|
||||||
// case 2: pivot_0 rejected, pivot_1 accepted
|
// case 2: pivot_0 rejected, pivot_1 accepted
|
||||||
low = pivot_0;
|
low = pivot_0;
|
||||||
@@ -497,6 +498,183 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||||
|
typename TempStorage>
|
||||||
|
__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
|
||||||
|
TempStorage& temp_storage) {
|
||||||
|
const uint32_t tx = threadIdx.x;
|
||||||
|
vec_t<float, VEC_SIZE> in_data_vec;
|
||||||
|
|
||||||
|
float max_val = 0;
|
||||||
|
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||||
|
in_data_vec.fill(0);
|
||||||
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
|
in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||||
|
}
|
||||||
|
float in_data_[VEC_SIZE];
|
||||||
|
#pragma unroll
|
||||||
|
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||||
|
in_data_[j] = in_data_vec[j];
|
||||||
|
}
|
||||||
|
max_val = max(
|
||||||
|
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||||
|
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
if (tx == 0) {
|
||||||
|
temp_storage.max_val = max_val;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
return temp_storage.max_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
|
||||||
|
struct RenormTempStorage {
|
||||||
|
union {
|
||||||
|
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
|
||||||
|
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
|
||||||
|
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||||
|
reduce_value_count;
|
||||||
|
} block_prim;
|
||||||
|
struct {
|
||||||
|
float max_val;
|
||||||
|
float min_val;
|
||||||
|
union {
|
||||||
|
struct {
|
||||||
|
float values[2];
|
||||||
|
};
|
||||||
|
struct {
|
||||||
|
int counts[2];
|
||||||
|
};
|
||||||
|
struct {
|
||||||
|
ValueCount<float> pairs[2];
|
||||||
|
};
|
||||||
|
} block_aggregate;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
|
||||||
|
typename DType, typename IdType>
|
||||||
|
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
|
||||||
|
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||||
|
const uint32_t row_idx = bx;
|
||||||
|
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
|
||||||
|
double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
|
||||||
|
vec_t<float, VEC_SIZE> probs_vec;
|
||||||
|
if (k < d) {
|
||||||
|
extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
|
||||||
|
uint8_t smem_renorm[];
|
||||||
|
auto& temp_storage =
|
||||||
|
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
|
||||||
|
temp_storage.max_val = 0;
|
||||||
|
|
||||||
|
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
|
||||||
|
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
|
||||||
|
probs, row_idx, d, temp_storage);
|
||||||
|
|
||||||
|
double low = 0, high = max_val;
|
||||||
|
float min_gt_low, max_le_high;
|
||||||
|
float sum_low = 1;
|
||||||
|
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
|
||||||
|
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
|
||||||
|
// loop invariant:
|
||||||
|
// - f(low) >= k, f(high) < k
|
||||||
|
// - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
|
||||||
|
// stopping condition: min_gt_low == max_le_high
|
||||||
|
// - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
|
||||||
|
do {
|
||||||
|
double pivot_0 = (high + 2 * low) / 3;
|
||||||
|
double pivot_1 = (2 * high + low) / 3;
|
||||||
|
|
||||||
|
ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
|
||||||
|
min_gt_low = high;
|
||||||
|
max_le_high = low;
|
||||||
|
#pragma unroll 2
|
||||||
|
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||||
|
probs_vec.fill(0);
|
||||||
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
|
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||||
|
}
|
||||||
|
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
|
||||||
|
#pragma unroll
|
||||||
|
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||||
|
probs_gt_pivot_0_pair[j] = {
|
||||||
|
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
|
||||||
|
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||||
|
probs_gt_pivot_1_pair[j] = {
|
||||||
|
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
|
||||||
|
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||||
|
|
||||||
|
if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
||||||
|
min_gt_low = min(min_gt_low, probs_vec[j]);
|
||||||
|
}
|
||||||
|
if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
||||||
|
max_le_high = max(max_le_high, probs_vec[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
|
temp_storage.block_prim.reduce_value_count)
|
||||||
|
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
|
temp_storage.block_prim.reduce_value_count)
|
||||||
|
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
min_gt_low =
|
||||||
|
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||||
|
.Reduce(min_gt_low, cub::Min());
|
||||||
|
__syncthreads();
|
||||||
|
max_le_high =
|
||||||
|
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||||
|
.Reduce(max_le_high, cub::Max());
|
||||||
|
if (tx == 0) {
|
||||||
|
temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
|
||||||
|
temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
|
||||||
|
temp_storage.min_val = min_gt_low;
|
||||||
|
temp_storage.max_val = max_le_high;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0];
|
||||||
|
aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1];
|
||||||
|
min_gt_low = temp_storage.min_val;
|
||||||
|
max_le_high = temp_storage.max_val;
|
||||||
|
|
||||||
|
if (aggregate_gt_pivot_1.count >= k) {
|
||||||
|
low = pivot_1;
|
||||||
|
sum_low = float(aggregate_gt_pivot_1.value);
|
||||||
|
} else if (aggregate_gt_pivot_0.count >= k) {
|
||||||
|
low = pivot_0;
|
||||||
|
high = min(pivot_1, max_le_high);
|
||||||
|
sum_low = float(aggregate_gt_pivot_0.value);
|
||||||
|
} else {
|
||||||
|
high = min(pivot_0, max_le_high);
|
||||||
|
}
|
||||||
|
} while (min_gt_low != max_le_high);
|
||||||
|
|
||||||
|
normalizer = ptx_rcp(max(sum_low, 1e-8));
|
||||||
|
pivot = low;
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalize
|
||||||
|
#pragma unroll 2
|
||||||
|
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||||
|
probs_vec.fill(0);
|
||||||
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
|
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||||
|
probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
|
||||||
|
}
|
||||||
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
|
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, typename IdType>
|
template <typename T, typename IdType>
|
||||||
cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
|
cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
|
||||||
uint32_t batch_size, const T *top_p_val,
|
uint32_t batch_size, const T *top_p_val,
|
||||||
@@ -529,7 +707,7 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
|
|||||||
|
|
||||||
template <typename T, typename IdType>
|
template <typename T, typename IdType>
|
||||||
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
|
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
|
||||||
uint32_t batch_size, const T *top_p_val,
|
uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
|
||||||
uint32_t d, bool deterministic,
|
uint32_t d, bool deterministic,
|
||||||
uint64_t philox_seed, uint64_t philox_offset,
|
uint64_t philox_seed, uint64_t philox_offset,
|
||||||
cudaStream_t stream = 0) {
|
cudaStream_t stream = 0) {
|
||||||
@@ -540,7 +718,7 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
|
|||||||
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||||
dim3 nblks(batch_size);
|
dim3 nblks(batch_size);
|
||||||
dim3 nthrs(BLOCK_THREADS);
|
dim3 nthrs(BLOCK_THREADS);
|
||||||
void* args[] = {&probs, &output, &top_p_val,
|
void* args[] = {&probs, &output, &top_p_val, &top_k_val,
|
||||||
&d, &philox_seed, &philox_offset};
|
&d, &philox_seed, &philox_offset};
|
||||||
|
|
||||||
DISPATCH_ALIGNED_VEC_SIZE(
|
DISPATCH_ALIGNED_VEC_SIZE(
|
||||||
@@ -556,4 +734,26 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace sampling
|
template <typename DType, typename IdType>
|
||||||
|
cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
|
||||||
|
uint32_t batch_size, uint32_t d,
|
||||||
|
cudaStream_t stream = 0) {
|
||||||
|
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
|
||||||
|
|
||||||
|
auto compute_capacity = GetCudaComputeCapability();
|
||||||
|
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
|
||||||
|
const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
|
||||||
|
dim3 nblks(batch_size);
|
||||||
|
dim3 nthrs(BLOCK_THREADS);
|
||||||
|
void* args[] = {&probs, &renormed_prob, &top_k_arr, &d};
|
||||||
|
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
|
||||||
|
auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
|
||||||
|
CUDA_CALL(
|
||||||
|
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||||
|
CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||||
|
});
|
||||||
|
return cudaSuccess;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sampling
|
||||||
|
61
custom_ops/gpu_ops/sample_kernels/top_k_renorm_probs.cu
Normal file
61
custom_ops/gpu_ops/sample_kernels/top_k_renorm_probs.cu
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "helper.h"
|
||||||
|
#include "paddle/phi/backends/context_pool.h"
|
||||||
|
#include "sample_kernels/sampling.cuh"
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> TopKRenorm(const paddle::Tensor &probs,
|
||||||
|
const paddle::Tensor &top_k) {
|
||||||
|
std::vector<int64_t> probs_shape = probs.shape();
|
||||||
|
uint32_t batch_size = probs_shape[0];
|
||||||
|
uint32_t vocab_size = probs_shape[1];
|
||||||
|
auto cu_stream = probs.stream();
|
||||||
|
|
||||||
|
auto renorm_probs =
|
||||||
|
GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place());
|
||||||
|
|
||||||
|
cudaError_t status;
|
||||||
|
|
||||||
|
|
||||||
|
status = sampling::TopKRenormProb<float>(
|
||||||
|
const_cast<float *>(probs.data<float>()),
|
||||||
|
renorm_probs.data<float>(),
|
||||||
|
const_cast<int64_t *>(top_k.data<int64_t>()),
|
||||||
|
batch_size, vocab_size, cu_stream);
|
||||||
|
|
||||||
|
PD_CHECK(status == cudaSuccess, "TopKRenormProb failed with error code " +
|
||||||
|
std::string(cudaGetErrorString(status)));
|
||||||
|
|
||||||
|
return {renorm_probs};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<int64_t>>
|
||||||
|
TopKRenormInferShape(const std::vector<int64_t> &probs_shape,
|
||||||
|
const std::vector<int64_t> &top_k_shape) {
|
||||||
|
return {probs_shape};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<paddle::DataType>
|
||||||
|
TopKRenormInferDtype(const paddle::DataType &probs_dtype,
|
||||||
|
const paddle::DataType &top_k_shape) {
|
||||||
|
return {probs_dtype};
|
||||||
|
}
|
||||||
|
|
||||||
|
PD_BUILD_STATIC_OP(top_k_renorm_probs)
|
||||||
|
.Inputs({"probs", "top_k"})
|
||||||
|
.Outputs({"renorm_probs"})
|
||||||
|
.SetKernelFn(PD_KERNEL(TopKRenorm))
|
||||||
|
.SetInferShapeFn(PD_INFER_SHAPE(TopKRenormInferShape))
|
||||||
|
.SetInferDtypeFn(PD_INFER_DTYPE(TopKRenormInferDtype));
|
@@ -282,6 +282,7 @@ elif paddle.is_compiled_with_cuda():
|
|||||||
"gpu_ops/text_image_index_out.cu",
|
"gpu_ops/text_image_index_out.cu",
|
||||||
"gpu_ops/text_image_gather_scatter.cu",
|
"gpu_ops/text_image_gather_scatter.cu",
|
||||||
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
|
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
|
||||||
|
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
|
||||||
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
|
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
|
||||||
"gpu_ops/fused_rotary_position_encoding.cu",
|
"gpu_ops/fused_rotary_position_encoding.cu",
|
||||||
"gpu_ops/noaux_tc.cu",
|
"gpu_ops/noaux_tc.cu",
|
||||||
|
@@ -178,6 +178,7 @@ for output in outputs:
|
|||||||
* repetition_penalty(float): 直接对重复生成的token进行惩罚的系数(>1时惩罚重复,<1时鼓励重复)
|
* repetition_penalty(float): 直接对重复生成的token进行惩罚的系数(>1时惩罚重复,<1时鼓励重复)
|
||||||
* temperature(float): 控制生成随机性的参数,值越高结果越随机,值越低结果越确定
|
* temperature(float): 控制生成随机性的参数,值越高结果越随机,值越低结果越确定
|
||||||
* top_p(float): 概率累积分布截断阈值,仅考虑累计概率达到此阈值的最可能token集合
|
* top_p(float): 概率累积分布截断阈值,仅考虑累计概率达到此阈值的最可能token集合
|
||||||
|
* top_k(int): 采样概率最高的的token数量,考虑概率最高的k个token进行采样
|
||||||
* max_tokens(int): 限制模型生成的最大token数量(包括输入和输出)
|
* max_tokens(int): 限制模型生成的最大token数量(包括输入和输出)
|
||||||
* min_tokens(int): 强制模型生成的最少token数量,避免过早结束
|
* min_tokens(int): 强制模型生成的最少token数量,避免过早结束
|
||||||
|
|
||||||
|
@@ -16,8 +16,8 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
|
@@ -52,6 +52,7 @@ class SamplingParams:
|
|||||||
the model more random. Zero means greedy sampling.
|
the model more random. Zero means greedy sampling.
|
||||||
top_p: Float that controls the cumulative probability of the top tokens
|
top_p: Float that controls the cumulative probability of the top tokens
|
||||||
to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
|
to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
|
||||||
|
top_k: Int that controls the number of top tokens to consider. Must be a positive integer.
|
||||||
seed: Random seed to use for the generation.
|
seed: Random seed to use for the generation.
|
||||||
stop: list of strings that stop the generation when they are generated.
|
stop: list of strings that stop the generation when they are generated.
|
||||||
The returned output will not contain the stop strings.
|
The returned output will not contain the stop strings.
|
||||||
@@ -81,7 +82,8 @@ class SamplingParams:
|
|||||||
frequency_penalty: float = None
|
frequency_penalty: float = None
|
||||||
repetition_penalty: float = None
|
repetition_penalty: float = None
|
||||||
temperature: float = None
|
temperature: float = None
|
||||||
top_p: float = None
|
top_p: float = 1.0
|
||||||
|
top_k: int = 0
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
stop: Optional[Union[str, List[str]]] = None
|
stop: Optional[Union[str, List[str]]] = None
|
||||||
stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
|
stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||||
@@ -111,6 +113,7 @@ class SamplingParams:
|
|||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
temperature,
|
temperature,
|
||||||
top_p,
|
top_p,
|
||||||
|
top_k,
|
||||||
seed=None,
|
seed=None,
|
||||||
stop=None,
|
stop=None,
|
||||||
stop_token_ids=None,
|
stop_token_ids=None,
|
||||||
@@ -129,7 +132,8 @@ class SamplingParams:
|
|||||||
repetition_penalty=repetition_penalty
|
repetition_penalty=repetition_penalty
|
||||||
if repetition_penalty is not None else 1.0,
|
if repetition_penalty is not None else 1.0,
|
||||||
temperature=temperature if temperature is not None else 1.0,
|
temperature=temperature if temperature is not None else 1.0,
|
||||||
top_p=top_p if top_p is not None else 0.7,
|
top_p=top_p if top_p is not None else 1.0,
|
||||||
|
top_k=top_k if top_k is not None else 0,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
|
@@ -292,6 +292,7 @@ class CompletionRequest(BaseModel):
|
|||||||
suffix: Optional[dict] = None
|
suffix: Optional[dict] = None
|
||||||
temperature: Optional[float] = None
|
temperature: Optional[float] = None
|
||||||
top_p: Optional[float] = None
|
top_p: Optional[float] = None
|
||||||
|
top_k: Optional[int] = None
|
||||||
user: Optional[str] = None
|
user: Optional[str] = None
|
||||||
|
|
||||||
response_format: Optional[AnyResponseFormat] = None
|
response_format: Optional[AnyResponseFormat] = None
|
||||||
@@ -405,6 +406,7 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
stream_options: Optional[StreamOptions] = None
|
stream_options: Optional[StreamOptions] = None
|
||||||
temperature: Optional[float] = None
|
temperature: Optional[float] = None
|
||||||
top_p: Optional[float] = None
|
top_p: Optional[float] = None
|
||||||
|
top_k: Optional[int] = None
|
||||||
user: Optional[str] = None
|
user: Optional[str] = None
|
||||||
metadata: Optional[dict] = None
|
metadata: Optional[dict] = None
|
||||||
|
|
||||||
|
@@ -27,34 +27,56 @@ if current_platform.is_gcu():
|
|||||||
|
|
||||||
def top_p_sampling(
|
def top_p_sampling(
|
||||||
x: paddle.Tensor,
|
x: paddle.Tensor,
|
||||||
ps: paddle.Tensor,
|
top_p: paddle.Tensor,
|
||||||
|
top_k: Optional[paddle.Tensor] = None,
|
||||||
threshold: Optional[paddle.Tensor] = None,
|
threshold: Optional[paddle.Tensor] = None,
|
||||||
topp_seed: Optional[paddle.Tensor] = None,
|
topp_seed: Optional[paddle.Tensor] = None,
|
||||||
seed: int = -1,
|
seed: int = -1,
|
||||||
k: int = 0,
|
k: int = 0,
|
||||||
mode: Literal['truncated', 'non-truncated'] = "truncated",
|
mode: Literal['truncated', 'non-truncated'] = "truncated",
|
||||||
|
order: Literal['top_k_first', 'joint'] = "top_k_first",
|
||||||
) -> tuple[paddle.Tensor, paddle.Tensor]:
|
) -> tuple[paddle.Tensor, paddle.Tensor]:
|
||||||
"""
|
"""
|
||||||
top_p_sampling
|
x(Tensor): An input 2-D Tensor with type float32, float16 and bfloat16.
|
||||||
|
top_p(Tensor): A 1-D Tensor with type float32, float16 and bfloat16,
|
||||||
|
used to specify the top_p corresponding to each query.
|
||||||
|
top_k(Tensor|None, optional): A 1-D Tensor with type int64,
|
||||||
|
used to specify the top_k corresponding to each query.
|
||||||
|
Only used when FD_SAMPLING_CLASS is `rejection`.
|
||||||
|
threshold(Tensor|None, optional): A 1-D Tensor with type float32, float16 and bfloat16,
|
||||||
|
used to avoid sampling low score tokens.
|
||||||
|
topp_seed(Tensor|None, optional): A 1-D Tensor with type int64,
|
||||||
|
used to specify the random seed for each query.
|
||||||
|
seed(int, optional): the random seed. Default is -1,
|
||||||
|
k(int): the number of top_k scores/ids to be returned. Default is 0.
|
||||||
|
Only used when FD_SAMPLING_CLASS is `air`.
|
||||||
|
mode(str): The mode to choose sampling strategy. If the mode is `truncated`, sampling will truncate the probability at top_p_value.
|
||||||
|
If the mode is `non-truncated`, it will not be truncated. Default is `truncated`.
|
||||||
|
Only used when FD_SAMPLING_CLASS is `air` or `base`.
|
||||||
|
order(str): The order of applying top-k and top-p sampling, should be either `top_k_first` or `joint`.
|
||||||
|
If `top_k_first`, we first apply top-k filter, then apply top-p sampling on the top-k results.
|
||||||
|
If `joint`, we apply top-k and top-p filter simultaneously in each round. Default is `top_k_first`.
|
||||||
|
Only used when FD_SAMPLING_CLASS is `rejection`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
top_p_class = envs.FD_SAMPLING_CLASS.lower()
|
top_p_class = envs.FD_SAMPLING_CLASS.lower()
|
||||||
if top_p_class == "air":
|
if top_p_class == "air":
|
||||||
_, ids = air_top_p_sampling(x,
|
_, ids = air_top_p_sampling(x,
|
||||||
ps,
|
top_p,
|
||||||
threshold,
|
threshold,
|
||||||
topp_seed,
|
topp_seed,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
k=k,
|
k=k,
|
||||||
mode=mode)
|
mode=mode)
|
||||||
elif top_p_class == "rejection":
|
elif top_p_class == "rejection":
|
||||||
ids = rejection_top_p_sampling(x, ps, seed)
|
ids = rejection_top_p_sampling(x, top_p, top_k, seed, order)
|
||||||
_ = None
|
_ = None
|
||||||
else:
|
else:
|
||||||
if current_platform.is_gcu():
|
if current_platform.is_gcu():
|
||||||
_, ids = gcu_top_p_sampling(x, ps)
|
_, ids = gcu_top_p_sampling(x, top_p)
|
||||||
else:
|
else:
|
||||||
_, ids = paddle.tensor.top_p_sampling(x,
|
_, ids = paddle.tensor.top_p_sampling(x,
|
||||||
ps,
|
top_p,
|
||||||
threshold=threshold,
|
threshold=threshold,
|
||||||
topp_seed=topp_seed,
|
topp_seed=topp_seed,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@@ -65,7 +87,7 @@ def top_p_sampling(
|
|||||||
|
|
||||||
def air_top_p_sampling(
|
def air_top_p_sampling(
|
||||||
x: paddle.Tensor,
|
x: paddle.Tensor,
|
||||||
ps: paddle.Tensor,
|
top_p: paddle.Tensor,
|
||||||
threshold: Optional[paddle.Tensor] = None,
|
threshold: Optional[paddle.Tensor] = None,
|
||||||
topp_seed: Optional[paddle.Tensor] = None,
|
topp_seed: Optional[paddle.Tensor] = None,
|
||||||
seed: int = -1,
|
seed: int = -1,
|
||||||
@@ -77,7 +99,7 @@ def air_top_p_sampling(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from fastdeploy.model_executor.ops.gpu import air_top_p_sampling
|
from fastdeploy.model_executor.ops.gpu import air_top_p_sampling
|
||||||
out, ids = air_top_p_sampling(x, ps, threshold, topp_seed, seed, k,
|
out, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed, k,
|
||||||
mode)
|
mode)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise RuntimeError("Cannot import air_top_p_sampling op.")
|
raise RuntimeError("Cannot import air_top_p_sampling op.")
|
||||||
@@ -86,19 +108,46 @@ def air_top_p_sampling(
|
|||||||
|
|
||||||
def rejection_top_p_sampling(
|
def rejection_top_p_sampling(
|
||||||
x: paddle.Tensor,
|
x: paddle.Tensor,
|
||||||
ps: paddle.Tensor,
|
top_p: paddle.Tensor,
|
||||||
|
top_k: Optional[paddle.Tensor] = None,
|
||||||
seed: int = -1,
|
seed: int = -1,
|
||||||
|
order: Literal['top_k_first', 'joint'] = "top_k_first",
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
"""
|
"""
|
||||||
rejection_top_p_sampling
|
rejection_top_p_sampling
|
||||||
"""
|
"""
|
||||||
|
assert top_p is not None, "Top_p should not be none when FD_SAMPLING_CLASS is rejection"
|
||||||
try:
|
try:
|
||||||
from fastdeploy.model_executor.ops.gpu import rejection_top_p_sampling
|
from fastdeploy.model_executor.ops.gpu import (
|
||||||
ids = rejection_top_p_sampling(
|
rejection_top_p_sampling, top_k_renorm_probs)
|
||||||
x,
|
|
||||||
ps,
|
if top_k is None:
|
||||||
seed,
|
ids = rejection_top_p_sampling(
|
||||||
)
|
x,
|
||||||
|
top_p,
|
||||||
|
None,
|
||||||
|
seed,
|
||||||
|
)
|
||||||
|
elif top_k is not None and top_p is not None:
|
||||||
|
if order == "top_k_first":
|
||||||
|
renorm_probs = top_k_renorm_probs(x, top_k)
|
||||||
|
ids = rejection_top_p_sampling(
|
||||||
|
renorm_probs,
|
||||||
|
top_p,
|
||||||
|
None,
|
||||||
|
seed,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ids = rejection_top_p_sampling(
|
||||||
|
x,
|
||||||
|
top_p,
|
||||||
|
top_k,
|
||||||
|
seed,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Top_p cannot be none."
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise RuntimeError("Cannot import rejection_top_p_sampling op.")
|
raise RuntimeError("Cannot import rejection_top_p_sampling op.")
|
||||||
return ids
|
return ids
|
||||||
|
@@ -214,7 +214,7 @@ class Sampler(nn.Layer):
|
|||||||
|
|
||||||
probs = F.softmax(logits)
|
probs = F.softmax(logits)
|
||||||
|
|
||||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
|
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
|
||||||
|
|
||||||
self.processor.update_output_tokens(next_tokens, skip_idx_list)
|
self.processor.update_output_tokens(next_tokens, skip_idx_list)
|
||||||
return next_tokens
|
return next_tokens
|
||||||
@@ -367,5 +367,5 @@ class MTPSampler(nn.Layer):
|
|||||||
)
|
)
|
||||||
probs = F.softmax(logits)
|
probs = F.softmax(logits)
|
||||||
|
|
||||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
|
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
@@ -154,12 +154,29 @@ class GCUModelRunner(ModelRunnerBase):
|
|||||||
-1].disaggregate_info["role"] == "prefill":
|
-1].disaggregate_info["role"] == "prefill":
|
||||||
os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
|
os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
|
||||||
|
|
||||||
|
top_k_reqs = []
|
||||||
|
top_p_reqs = []
|
||||||
|
max_num_seqs = self.parallel_config.max_num_seqs
|
||||||
|
top_p_buffer = paddle.full([max_num_seqs, 1],
|
||||||
|
self.model_config.top_p,
|
||||||
|
dtype='float32')
|
||||||
|
top_k_buffer = paddle.full([max_num_seqs, 1],
|
||||||
|
0,
|
||||||
|
dtype='int64')
|
||||||
|
|
||||||
req_len = len(req_dicts)
|
req_len = len(req_dicts)
|
||||||
for i in range(req_len):
|
for i in range(req_len):
|
||||||
request = req_dicts[i]
|
request = req_dicts[i]
|
||||||
idx = request.idx
|
idx = request.idx
|
||||||
length = len(request.prompt_token_ids)
|
length = len(request.prompt_token_ids)
|
||||||
|
|
||||||
|
if sampling_params := request.sampling_params:
|
||||||
|
if sampling_params.top_p < 1:
|
||||||
|
top_p_reqs.append(idx)
|
||||||
|
top_k = sampling_params.top_k
|
||||||
|
if top_k > 0:
|
||||||
|
top_k_reqs.append(idx)
|
||||||
|
|
||||||
prefill_tokens = []
|
prefill_tokens = []
|
||||||
if (request.guided_json is not None
|
if (request.guided_json is not None
|
||||||
or request.guided_regex is not None
|
or request.guided_regex is not None
|
||||||
@@ -234,8 +251,8 @@ class GCUModelRunner(ModelRunnerBase):
|
|||||||
request.eos_token_ids.append(request.eos_token_ids[0])
|
request.eos_token_ids.append(request.eos_token_ids[0])
|
||||||
self.share_inputs["eos_token_id"][:] = np.array(
|
self.share_inputs["eos_token_id"][:] = np.array(
|
||||||
request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
||||||
|
top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
|
||||||
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
|
top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
|
||||||
self.share_inputs["temperature"][idx:idx + 1] = request.get(
|
self.share_inputs["temperature"][idx:idx + 1] = request.get(
|
||||||
"temperature", 0.95)
|
"temperature", 0.95)
|
||||||
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
|
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
|
||||||
@@ -286,6 +303,16 @@ class GCUModelRunner(ModelRunnerBase):
|
|||||||
if self.speculative_method in ["mtp"]:
|
if self.speculative_method in ["mtp"]:
|
||||||
self.proposer.insert_prefill_inputs(req_dicts)
|
self.proposer.insert_prefill_inputs(req_dicts)
|
||||||
|
|
||||||
|
if len(top_k_reqs) == 0:
|
||||||
|
self.share_inputs["top_k"] = None
|
||||||
|
else:
|
||||||
|
self.share_inputs["top_k"] = top_k_buffer
|
||||||
|
|
||||||
|
if len(top_p_reqs) == 0:
|
||||||
|
self.share_inputs["top_p"] = None
|
||||||
|
else:
|
||||||
|
self.share_inputs["top_p"] = top_p_buffer
|
||||||
|
|
||||||
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
|
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
|
||||||
expected_decode_len: int):
|
expected_decode_len: int):
|
||||||
""" Set dummy prefill inputs to share_inputs """
|
""" Set dummy prefill inputs to share_inputs """
|
||||||
@@ -340,8 +367,11 @@ class GCUModelRunner(ModelRunnerBase):
|
|||||||
self.share_inputs["eos_token_id"] = paddle.full(
|
self.share_inputs["eos_token_id"] = paddle.full(
|
||||||
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
|
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
|
||||||
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
|
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
|
||||||
self.model_config.top_p,
|
self.model_config.top_p,
|
||||||
dtype='float32')
|
dtype='float32')
|
||||||
|
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
|
||||||
|
0,
|
||||||
|
dtype='int64')
|
||||||
self.share_inputs["temperature"] = paddle.full(
|
self.share_inputs["temperature"] = paddle.full(
|
||||||
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
|
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
|
||||||
self.share_inputs["penalty_score"] = paddle.full(
|
self.share_inputs["penalty_score"] = paddle.full(
|
||||||
@@ -563,6 +593,7 @@ class GCUModelRunner(ModelRunnerBase):
|
|||||||
self.sampling_metadata = SamplingMetadata(
|
self.sampling_metadata = SamplingMetadata(
|
||||||
temperature=self.share_inputs["temperature"],
|
temperature=self.share_inputs["temperature"],
|
||||||
top_p=self.share_inputs["top_p"],
|
top_p=self.share_inputs["top_p"],
|
||||||
|
top_k=self.share_inputs["top_k"],
|
||||||
step_idx=self.share_inputs["step_idx"],
|
step_idx=self.share_inputs["step_idx"],
|
||||||
pre_token_ids=self.share_inputs["pre_ids"],
|
pre_token_ids=self.share_inputs["pre_ids"],
|
||||||
frequency_penalties=self.share_inputs["frequency_score"],
|
frequency_penalties=self.share_inputs["frequency_score"],
|
||||||
|
@@ -161,6 +161,15 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
-1].disaggregate_info["role"] == "prefill":
|
-1].disaggregate_info["role"] == "prefill":
|
||||||
os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
|
os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
|
||||||
|
|
||||||
|
top_k_reqs = []
|
||||||
|
top_p_reqs = []
|
||||||
|
max_num_seqs = self.parallel_config.max_num_seqs
|
||||||
|
top_p_buffer = paddle.full([max_num_seqs, 1],
|
||||||
|
self.model_config.top_p,
|
||||||
|
dtype='float32')
|
||||||
|
top_k_buffer = paddle.full([max_num_seqs, 1],
|
||||||
|
0,
|
||||||
|
dtype='int64')
|
||||||
req_len = len(req_dicts)
|
req_len = len(req_dicts)
|
||||||
for i in range(req_len):
|
for i in range(req_len):
|
||||||
request = req_dicts[i]
|
request = req_dicts[i]
|
||||||
@@ -168,6 +177,13 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
length = len(request.prompt_token_ids)
|
length = len(request.prompt_token_ids)
|
||||||
assert length > 0, "The prompt requested must not be empty."
|
assert length > 0, "The prompt requested must not be empty."
|
||||||
|
|
||||||
|
if sampling_params := request.sampling_params:
|
||||||
|
if sampling_params.top_p < 1:
|
||||||
|
top_p_reqs.append(idx)
|
||||||
|
top_k = sampling_params.top_k
|
||||||
|
if top_k > 0:
|
||||||
|
top_k_reqs.append(idx)
|
||||||
|
|
||||||
prefill_tokens = []
|
prefill_tokens = []
|
||||||
if (request.guided_json is not None
|
if (request.guided_json is not None
|
||||||
or request.guided_regex is not None
|
or request.guided_regex is not None
|
||||||
@@ -242,8 +258,8 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
request.eos_token_ids.append(request.eos_token_ids[0])
|
request.eos_token_ids.append(request.eos_token_ids[0])
|
||||||
self.share_inputs["eos_token_id"][:] = np.array(
|
self.share_inputs["eos_token_id"][:] = np.array(
|
||||||
request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
||||||
|
top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
|
||||||
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
|
top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
|
||||||
self.share_inputs["temperature"][idx:idx + 1] = request.get(
|
self.share_inputs["temperature"][idx:idx + 1] = request.get(
|
||||||
"temperature", 0.95)
|
"temperature", 0.95)
|
||||||
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
|
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
|
||||||
@@ -294,6 +310,16 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
if self.speculative_method in ["mtp"]:
|
if self.speculative_method in ["mtp"]:
|
||||||
self.proposer.insert_prefill_inputs(req_dicts)
|
self.proposer.insert_prefill_inputs(req_dicts)
|
||||||
|
|
||||||
|
if len(top_k_reqs) == 0:
|
||||||
|
self.share_inputs["top_k"] = None
|
||||||
|
else:
|
||||||
|
self.share_inputs["top_k"] = top_k_buffer
|
||||||
|
|
||||||
|
if len(top_p_reqs) == 0:
|
||||||
|
self.share_inputs["top_p"] = None
|
||||||
|
else:
|
||||||
|
self.share_inputs["top_p"] = top_p_buffer
|
||||||
|
|
||||||
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
|
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
|
||||||
expected_decode_len: int):
|
expected_decode_len: int):
|
||||||
""" Set dummy prefill inputs to share_inputs """
|
""" Set dummy prefill inputs to share_inputs """
|
||||||
@@ -349,8 +375,11 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
self.share_inputs["eos_token_id"] = paddle.full(
|
self.share_inputs["eos_token_id"] = paddle.full(
|
||||||
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
|
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
|
||||||
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
|
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
|
||||||
self.model_config.top_p,
|
self.model_config.top_p,
|
||||||
dtype='float32')
|
dtype='float32')
|
||||||
|
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
|
||||||
|
0,
|
||||||
|
dtype='int64')
|
||||||
self.share_inputs["temperature"] = paddle.full(
|
self.share_inputs["temperature"] = paddle.full(
|
||||||
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
|
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
|
||||||
self.share_inputs["penalty_score"] = paddle.full(
|
self.share_inputs["penalty_score"] = paddle.full(
|
||||||
@@ -574,6 +603,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
self.sampling_metadata = SamplingMetadata(
|
self.sampling_metadata = SamplingMetadata(
|
||||||
temperature=self.share_inputs["temperature"],
|
temperature=self.share_inputs["temperature"],
|
||||||
top_p=self.share_inputs["top_p"],
|
top_p=self.share_inputs["top_p"],
|
||||||
|
top_k=self.share_inputs["top_k"],
|
||||||
step_idx=self.share_inputs["step_idx"],
|
step_idx=self.share_inputs["step_idx"],
|
||||||
pre_token_ids=self.share_inputs["pre_ids"],
|
pre_token_ids=self.share_inputs["pre_ids"],
|
||||||
frequency_penalties=self.share_inputs["frequency_score"],
|
frequency_penalties=self.share_inputs["frequency_score"],
|
||||||
|
@@ -29,9 +29,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import \
|
|||||||
AttentionBackend
|
AttentionBackend
|
||||||
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
|
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
|
||||||
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
||||||
from fastdeploy.model_executor.layers.sample.sampler import (Sampler,
|
from fastdeploy.model_executor.layers.sample.sampler import (
|
||||||
SpeculativeSampler
|
Sampler, SpeculativeSampler)
|
||||||
)
|
|
||||||
from fastdeploy.model_executor.model_loader import get_model_from_loader
|
from fastdeploy.model_executor.model_loader import get_model_from_loader
|
||||||
from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
|
from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
|
||||||
from fastdeploy.model_executor.pre_and_post_process import (post_process,
|
from fastdeploy.model_executor.pre_and_post_process import (post_process,
|
||||||
@@ -145,12 +144,29 @@ class IluvatarModelRunner(ModelRunnerBase):
|
|||||||
-1].disaggregate_info["role"] == "prefill":
|
-1].disaggregate_info["role"] == "prefill":
|
||||||
os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
|
os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
|
||||||
|
|
||||||
|
top_k_reqs = []
|
||||||
|
top_p_reqs = []
|
||||||
|
max_num_seqs = self.parallel_config.max_num_seqs
|
||||||
|
top_p_buffer = paddle.full([max_num_seqs, 1],
|
||||||
|
self.model_config.top_p,
|
||||||
|
dtype='float32')
|
||||||
|
top_k_buffer = paddle.full([max_num_seqs, 1],
|
||||||
|
0,
|
||||||
|
dtype='int64')
|
||||||
|
|
||||||
req_len = len(req_dicts)
|
req_len = len(req_dicts)
|
||||||
for i in range(req_len):
|
for i in range(req_len):
|
||||||
request = req_dicts[i]
|
request = req_dicts[i]
|
||||||
idx = request.idx
|
idx = request.idx
|
||||||
length = len(request.prompt_token_ids)
|
length = len(request.prompt_token_ids)
|
||||||
|
|
||||||
|
if sampling_params := request.sampling_params:
|
||||||
|
if sampling_params.top_p < 1:
|
||||||
|
top_p_reqs.append(idx)
|
||||||
|
top_k = sampling_params.top_k
|
||||||
|
if top_k > 0:
|
||||||
|
top_k_reqs.append(idx)
|
||||||
|
|
||||||
prefill_tokens = []
|
prefill_tokens = []
|
||||||
if (request.guided_json is not None
|
if (request.guided_json is not None
|
||||||
or request.guided_regex is not None
|
or request.guided_regex is not None
|
||||||
@@ -225,8 +241,8 @@ class IluvatarModelRunner(ModelRunnerBase):
|
|||||||
request.eos_token_ids.append(request.eos_token_ids[0])
|
request.eos_token_ids.append(request.eos_token_ids[0])
|
||||||
self.share_inputs["eos_token_id"][:] = np.array(
|
self.share_inputs["eos_token_id"][:] = np.array(
|
||||||
request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
||||||
|
top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
|
||||||
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
|
top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
|
||||||
self.share_inputs["temperature"][idx:idx + 1] = request.get(
|
self.share_inputs["temperature"][idx:idx + 1] = request.get(
|
||||||
"temperature", 0.95)
|
"temperature", 0.95)
|
||||||
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
|
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
|
||||||
@@ -273,6 +289,15 @@ class IluvatarModelRunner(ModelRunnerBase):
|
|||||||
idx, request.get("logits_processor"), prefill_tokens)
|
idx, request.get("logits_processor"), prefill_tokens)
|
||||||
|
|
||||||
self.share_inputs["not_need_stop"][0] = True
|
self.share_inputs["not_need_stop"][0] = True
|
||||||
|
if len(top_k_reqs) == 0:
|
||||||
|
self.share_inputs["top_k"] = None
|
||||||
|
else:
|
||||||
|
self.share_inputs["top_k"] = top_k_buffer
|
||||||
|
|
||||||
|
if len(top_p_reqs) == 0:
|
||||||
|
self.share_inputs["top_p"] = None
|
||||||
|
else:
|
||||||
|
self.share_inputs["top_p"] = top_p_buffer
|
||||||
|
|
||||||
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
|
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
|
||||||
expected_decode_len: int):
|
expected_decode_len: int):
|
||||||
@@ -329,8 +354,11 @@ class IluvatarModelRunner(ModelRunnerBase):
|
|||||||
self.share_inputs["eos_token_id"] = paddle.full(
|
self.share_inputs["eos_token_id"] = paddle.full(
|
||||||
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
|
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
|
||||||
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
|
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
|
||||||
self.model_config.top_p,
|
self.model_config.top_p,
|
||||||
dtype='float32')
|
dtype='float32')
|
||||||
|
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
|
||||||
|
0,
|
||||||
|
dtype='int64')
|
||||||
self.share_inputs["temperature"] = paddle.full(
|
self.share_inputs["temperature"] = paddle.full(
|
||||||
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
|
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
|
||||||
self.share_inputs["penalty_score"] = paddle.full(
|
self.share_inputs["penalty_score"] = paddle.full(
|
||||||
@@ -558,6 +586,7 @@ class IluvatarModelRunner(ModelRunnerBase):
|
|||||||
self.sampling_metadata = SamplingMetadata(
|
self.sampling_metadata = SamplingMetadata(
|
||||||
temperature=self.share_inputs["temperature"],
|
temperature=self.share_inputs["temperature"],
|
||||||
top_p=self.share_inputs["top_p"],
|
top_p=self.share_inputs["top_p"],
|
||||||
|
top_k=self.share_inputs["top_k"],
|
||||||
step_idx=self.share_inputs["step_idx"],
|
step_idx=self.share_inputs["step_idx"],
|
||||||
pre_token_ids=self.share_inputs["pre_ids"],
|
pre_token_ids=self.share_inputs["pre_ids"],
|
||||||
frequency_penalties=self.share_inputs["frequency_score"],
|
frequency_penalties=self.share_inputs["frequency_score"],
|
||||||
|
@@ -14,14 +14,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
import argparse
|
import argparse
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
import paddle.distributed as dist
|
import paddle.distributed as dist
|
||||||
import paddle.distributed.fleet as fleet
|
import paddle.distributed.fleet as fleet
|
||||||
from fastdeploy.config import ModelConfig
|
|
||||||
|
|
||||||
|
from fastdeploy.config import ModelConfig
|
||||||
from fastdeploy.utils import get_logger
|
from fastdeploy.utils import get_logger
|
||||||
|
|
||||||
logger = get_logger("worker", "worker.log")
|
logger = get_logger("worker", "worker.log")
|
||||||
|
@@ -282,11 +282,26 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
|
|
||||||
def process_prefill_inputs(self, req_dicts: List[Request]):
|
def process_prefill_inputs(self, req_dicts: List[Request]):
|
||||||
""" Process inputs for prefill tasks and update share_inputs buffer """
|
""" Process inputs for prefill tasks and update share_inputs buffer """
|
||||||
|
top_k_reqs = []
|
||||||
|
top_p_reqs = []
|
||||||
|
max_num_seqs = self.parallel_config.max_num_seqs
|
||||||
|
top_p_buffer = paddle.full([max_num_seqs, 1],
|
||||||
|
self.model_config.top_p,
|
||||||
|
dtype='float32')
|
||||||
|
top_k_buffer = paddle.full([max_num_seqs, 1],
|
||||||
|
0,
|
||||||
|
dtype='int64')
|
||||||
req_len = len(req_dicts)
|
req_len = len(req_dicts)
|
||||||
for i in range(req_len):
|
for i in range(req_len):
|
||||||
request = req_dicts[i]
|
request = req_dicts[i]
|
||||||
idx = request.idx
|
idx = request.idx
|
||||||
length = request.prompt_token_ids_len
|
length = request.prompt_token_ids_len
|
||||||
|
if sampling_params := request.sampling_params:
|
||||||
|
if sampling_params.top_p < 1:
|
||||||
|
top_p_reqs.append(idx)
|
||||||
|
top_k = sampling_params.top_k
|
||||||
|
if top_k > 0:
|
||||||
|
top_k_reqs.append(idx)
|
||||||
self.share_inputs["input_ids"][idx:idx + 1, :length] = np.array(
|
self.share_inputs["input_ids"][idx:idx + 1, :length] = np.array(
|
||||||
request.prompt_token_ids)
|
request.prompt_token_ids)
|
||||||
if len(request.eos_token_ids
|
if len(request.eos_token_ids
|
||||||
@@ -295,7 +310,8 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
self.share_inputs["eos_token_id"][:] = np.array(
|
self.share_inputs["eos_token_id"][:] = np.array(
|
||||||
request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
||||||
self.share_inputs["pre_ids"][idx:idx + 1] = -1
|
self.share_inputs["pre_ids"][idx:idx + 1] = -1
|
||||||
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
|
top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
|
||||||
|
top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
|
||||||
self.share_inputs["temperature"][idx:idx + 1] = request.get(
|
self.share_inputs["temperature"][idx:idx + 1] = request.get(
|
||||||
"temperature", 0.95)
|
"temperature", 0.95)
|
||||||
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
|
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
|
||||||
@@ -344,6 +360,15 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
request.get("stop_token_ids"), dtype="int64")
|
request.get("stop_token_ids"), dtype="int64")
|
||||||
|
|
||||||
self.share_inputs["not_need_stop"][0] = True
|
self.share_inputs["not_need_stop"][0] = True
|
||||||
|
if len(top_k_reqs) == 0:
|
||||||
|
self.share_inputs["top_k"] = None
|
||||||
|
else:
|
||||||
|
self.share_inputs["top_k"] = top_k_buffer
|
||||||
|
|
||||||
|
if len(top_p_reqs) == 0:
|
||||||
|
self.share_inputs["top_p"] = None
|
||||||
|
else:
|
||||||
|
self.share_inputs["top_p"] = top_p_buffer
|
||||||
|
|
||||||
def _init_share_inputs(self, max_num_seqs: int):
|
def _init_share_inputs(self, max_num_seqs: int):
|
||||||
"""Initialize all share buffers for model inputs.
|
"""Initialize all share buffers for model inputs.
|
||||||
@@ -363,8 +388,11 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
self.share_inputs["eos_token_id"] = paddle.full(
|
self.share_inputs["eos_token_id"] = paddle.full(
|
||||||
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
|
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
|
||||||
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
|
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
|
||||||
self.model_config.top_p,
|
self.model_config.top_p,
|
||||||
dtype='float32')
|
dtype='float32')
|
||||||
|
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
|
||||||
|
0,
|
||||||
|
dtype='int64')
|
||||||
self.share_inputs["temperature"] = paddle.full(
|
self.share_inputs["temperature"] = paddle.full(
|
||||||
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
|
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
|
||||||
self.share_inputs["penalty_score"] = paddle.full(
|
self.share_inputs["penalty_score"] = paddle.full(
|
||||||
@@ -514,6 +542,7 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
self.sampling_metadata = SamplingMetadata(
|
self.sampling_metadata = SamplingMetadata(
|
||||||
temperature=self.share_inputs["temperature"],
|
temperature=self.share_inputs["temperature"],
|
||||||
top_p=self.share_inputs["top_p"],
|
top_p=self.share_inputs["top_p"],
|
||||||
|
top_k=self.share_inputs["top_k"],
|
||||||
step_idx=self.share_inputs["step_idx"],
|
step_idx=self.share_inputs["step_idx"],
|
||||||
pre_token_ids=self.share_inputs["pre_ids"],
|
pre_token_ids=self.share_inputs["pre_ids"],
|
||||||
frequency_penalties=self.share_inputs["frequency_score"],
|
frequency_penalties=self.share_inputs["frequency_score"],
|
||||||
|
Reference in New Issue
Block a user