[Feature] support top_k_top_p sampling (#2753)

* support top_k_top_p sampling

* fix

* add api param

* add api param

* fix

* fix

* fix

* fix

* fix

* fix

* fix
Author: Sunny-bot1
Date: 2025-07-10 11:58:58 +08:00
Committed by: GitHub
Parent: b0f525955c
Commit: e45050cae3
15 changed files with 501 additions and 53 deletions

View File

@@ -18,6 +18,7 @@
 std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
                                                const paddle::Tensor &top_p,
+                                               const paddle::optional<paddle::Tensor> &top_k,
                                                int seed) {
   std::vector<int64_t> probs_shape = probs.shape();
   unsigned int batch_size = probs_shape[0];
@@ -40,10 +41,18 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
   cudaError_t status;

-  status = sampling::TopKTopPSamplingFromProb<float, int64_t>(
-      const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
-      batch_size, top_p.data<float>(), vocab_size,
-      true, philox_seed, philox_offset, cu_stream);
+  if (top_k) {
+    status = sampling::TopKTopPSamplingFromProb<float, int64_t>(
+        const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
+        batch_size, top_p.data<float>(), top_k.get().data<int64_t>(), vocab_size,
+        true, philox_seed, philox_offset, cu_stream);
+  }
+  else {
+    status = sampling::TopPSamplingFromProb<float, int64_t>(
+        const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
+        batch_size, top_p.data<float>(), vocab_size,
+        true, philox_seed, philox_offset, cu_stream);
+  }

   PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
                                       std::string(cudaGetErrorString(status)));
@@ -53,19 +62,21 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
 std::vector<std::vector<int64_t>>
 TopPSamplingRejectInferShape(const std::vector<int64_t> &probs_shape,
-                             const std::vector<int64_t> &top_p_shape) {
+                             const std::vector<int64_t> &top_p_shape,
+                             const paddle::optional<std::vector<int64_t>> &top_k_shape) {
   int64_t bs = probs_shape[0];
   return {{bs, 1}};
 }

 std::vector<paddle::DataType>
 TopPSamplingRejectInferDtype(const paddle::DataType &probs_dtype,
-                             const paddle::DataType &top_p_shape) {
+                             const paddle::DataType &top_p_dtype,
+                             const paddle::optional<paddle::DataType> &top_k_dtype) {
   return {paddle::DataType::INT64};
 }

 PD_BUILD_STATIC_OP(rejection_top_p_sampling)
-    .Inputs({"probs", "top_p"})
+    .Inputs({"probs", "top_p", paddle::Optional("top_k")})
     .Outputs({"samples"})
     .Attrs({"seed: int"})
     .SetKernelFn(PD_KERNEL(TopPSamplingReject))
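Note: with the optional `top_k` input, the op dispatches to the fused `TopKTopPSamplingFromProb` kernel when a per-row top-k tensor is supplied and falls back to plain `TopPSamplingFromProb` otherwise. As a rough functional reference for the resulting per-row distribution (this is not the rejection-sampling algorithm the kernels actually use, and the helper name below is hypothetical):

```python
import numpy as np

def sample_top_k_top_p_ref(probs, top_p, top_k=None, rng=None):
    """probs: [batch, vocab]; top_p: [batch]; top_k: [batch] int64 or None (0 = no top-k)."""
    rng = rng or np.random.default_rng()
    out = np.empty(probs.shape[0], dtype=np.int64)
    for i, row in enumerate(probs):
        p = row.astype(np.float64).copy()
        if top_k is not None and top_k[i] > 0:
            kth = np.sort(p)[-int(top_k[i])]       # k-th largest probability
            p[p < kth] = 0.0                       # keep only the top-k tokens
        order = np.argsort(-p)
        csum = np.cumsum(p[order])
        cutoff = int(np.searchsorted(csum, top_p[i] * csum[-1])) + 1
        keep = order[:cutoff]                      # smallest nucleus reaching top_p
        out[i] = rng.choice(keep, p=p[keep] / p[keep].sum())
    return out
```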

View File

@@ -279,7 +279,8 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
           typename DType, typename IdType>
-__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, float* top_p_arr,
+__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
+                                               float* top_p_arr, IdType* top_k_arr,
                                                uint32_t d, uint64_t philox_seed,
                                                uint64_t philox_offset) {
   const uint32_t batch_size = gridDim.x;
@@ -287,7 +288,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
   curandStatePhilox4_32_10_t state;
   curand_init(philox_seed, bx, philox_offset, &state);
   const uint32_t row_idx = bx;
-  const uint32_t k = top_p_arr[row_idx] == 0 ? 1 : 20;
+  const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
   const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx];
   extern __shared__ __align__(
@@ -479,7 +480,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
     if (aggregate_gt_pivot_0 < top_p) {
       // case 1: pivot_0 accepted
       break;
     }
     if (aggregate_gt_pivot_1 < top_p) {
       // case 2: pivot_0 rejected, pivot_1 accepted
       low = pivot_0;
@@ -497,6 +498,183 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
  }
}
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
typename TempStorage>
__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
TempStorage& temp_storage) {
const uint32_t tx = threadIdx.x;
vec_t<float, VEC_SIZE> in_data_vec;
float max_val = 0;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
in_data_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
float in_data_[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
in_data_[j] = in_data_vec[j];
}
max_val = max(
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
__syncthreads();
}
if (tx == 0) {
temp_storage.max_val = max_val;
}
__syncthreads();
return temp_storage.max_val;
}
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
struct RenormTempStorage {
union {
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
reduce_value_count;
} block_prim;
struct {
float max_val;
float min_val;
union {
struct {
float values[2];
};
struct {
int counts[2];
};
struct {
ValueCount<float> pairs[2];
};
} block_aggregate;
};
};
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
typename DType, typename IdType>
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = bx;
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
vec_t<float, VEC_SIZE> probs_vec;
if (k < d) {
extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
uint8_t smem_renorm[];
auto& temp_storage =
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
temp_storage.max_val = 0;
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);
double low = 0, high = max_val;
float min_gt_low, max_le_high;
float sum_low = 1;
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
// loop invariant:
// - f(low) >= k, f(high) < k
// - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
// stopping condition: min_gt_low == max_le_high
// - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
do {
double pivot_0 = (high + 2 * low) / 3;
double pivot_1 = (2 * high + low) / 3;
ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
min_gt_low = high;
max_le_high = low;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot_0_pair[j] = {
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
probs_gt_pivot_1_pair[j] = {
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
min_gt_low = min(min_gt_low, probs_vec[j]);
}
if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
max_le_high = max(max_le_high, probs_vec[j]);
}
}
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
__syncthreads();
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
__syncthreads();
}
min_gt_low =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(min_gt_low, cub::Min());
__syncthreads();
max_le_high =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(max_le_high, cub::Max());
if (tx == 0) {
temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
temp_storage.min_val = min_gt_low;
temp_storage.max_val = max_le_high;
}
__syncthreads();
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0];
aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1];
min_gt_low = temp_storage.min_val;
max_le_high = temp_storage.max_val;
if (aggregate_gt_pivot_1.count >= k) {
low = pivot_1;
sum_low = float(aggregate_gt_pivot_1.value);
} else if (aggregate_gt_pivot_0.count >= k) {
low = pivot_0;
high = min(pivot_1, max_le_high);
sum_low = float(aggregate_gt_pivot_0.value);
} else {
high = min(pivot_0, max_le_high);
}
} while (min_gt_low != max_le_high);
normalizer = ptx_rcp(max(sum_low, 1e-8));
pivot = low;
}
// normalize
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
}
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
}
}
template <typename T, typename IdType>
cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
                                 uint32_t batch_size, const T *top_p_val,
@@ -529,7 +707,7 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
-                                     uint32_t batch_size, const T *top_p_val,
+                                     uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
                                     uint32_t d, bool deterministic,
                                     uint64_t philox_seed, uint64_t philox_offset,
                                     cudaStream_t stream = 0) {
@@ -540,7 +718,7 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
  const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
-  void* args[] = {&probs, &output, &top_p_val,
+  void* args[] = {&probs, &output, &top_p_val, &top_k_val,
                  &d, &philox_seed, &philox_offset};
  DISPATCH_ALIGNED_VEC_SIZE(
@@ -556,4 +734,26 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
  });
}

template <typename DType, typename IdType>
cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
uint32_t batch_size, uint32_t d,
cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
auto compute_capacity = GetCudaComputeCapability();
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &renormed_prob, &top_k_arr, &d};
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
});
return cudaSuccess;
});
}
} // namespace sampling
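Note: the new `TopKRenormProbKernel` ternary-searches a per-row pivot so that only the k largest probabilities survive, then rescales the survivors to sum to one; `top_k == 0` passes the row through untouched. A minimal NumPy sketch of the intended result, ignoring the on-device search and how ties at the threshold are resolved:

```python
import numpy as np

def top_k_renorm_ref(probs, top_k):
    """probs: [batch, vocab] float; top_k: [batch] int64, 0 means 'keep all'."""
    out = np.array(probs, dtype=np.float64, copy=True)
    for i, row in enumerate(out):
        k = int(top_k[i])
        if k == 0 or k >= row.size:
            continue                      # no truncation, row is left as-is
        pivot = np.sort(row)[-k]          # k-th largest probability
        row[row < pivot] = 0.0            # drop everything below the pivot
        row /= row.sum()                  # renormalize the survivors
    return out
```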

View File

@@ -0,0 +1,61 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/phi/backends/context_pool.h"
#include "sample_kernels/sampling.cuh"
std::vector<paddle::Tensor> TopKRenorm(const paddle::Tensor &probs,
const paddle::Tensor &top_k) {
std::vector<int64_t> probs_shape = probs.shape();
uint32_t batch_size = probs_shape[0];
uint32_t vocab_size = probs_shape[1];
auto cu_stream = probs.stream();
auto renorm_probs =
GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place());
cudaError_t status;
status = sampling::TopKRenormProb<float>(
const_cast<float *>(probs.data<float>()),
renorm_probs.data<float>(),
const_cast<int64_t *>(top_k.data<int64_t>()),
batch_size, vocab_size, cu_stream);
PD_CHECK(status == cudaSuccess, "TopKRenormProb failed with error code " +
std::string(cudaGetErrorString(status)));
return {renorm_probs};
}
std::vector<std::vector<int64_t>>
TopKRenormInferShape(const std::vector<int64_t> &probs_shape,
const std::vector<int64_t> &top_k_shape) {
return {probs_shape};
}
std::vector<paddle::DataType>
TopKRenormInferDtype(const paddle::DataType &probs_dtype,
const paddle::DataType &top_k_shape) {
return {probs_dtype};
}
PD_BUILD_STATIC_OP(top_k_renorm_probs)
.Inputs({"probs", "top_k"})
.Outputs({"renorm_probs"})
.SetKernelFn(PD_KERNEL(TopKRenorm))
.SetInferShapeFn(PD_INFER_SHAPE(TopKRenormInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(TopKRenormInferDtype));
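Once built, the op is exposed alongside the other GPU custom ops; a small usage sketch (assumes a CUDA build, shapes follow the `[batch, 1]` convention used elsewhere in this PR):

```python
import paddle
from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs

probs = paddle.to_tensor([[0.1, 0.2, 0.3, 0.4],
                          [0.25, 0.25, 0.25, 0.25]], dtype="float32")
top_k = paddle.to_tensor([[2], [0]], dtype="int64")   # 0 disables truncation

renormed = top_k_renorm_probs(probs, top_k)
# row 0 keeps only the two largest entries: [0, 0, 3/7, 4/7]
# row 1 is passed through unchanged
```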

View File

@@ -282,6 +282,7 @@ elif paddle.is_compiled_with_cuda():
            "gpu_ops/text_image_index_out.cu",
            "gpu_ops/text_image_gather_scatter.cu",
            "gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
+           "gpu_ops/sample_kernels/top_k_renorm_probs.cu",
            "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
            "gpu_ops/fused_rotary_position_encoding.cu",
            "gpu_ops/noaux_tc.cu",

View File

@@ -178,6 +178,7 @@ for output in outputs:
* repetition_penalty(float): coefficient that directly penalizes repeated tokens (>1 penalizes repetition, <1 encourages it)
* temperature(float): controls generation randomness; higher values give more random output, lower values more deterministic output
* top_p(float): cumulative-probability truncation threshold; only the smallest set of most likely tokens whose cumulative probability reaches this threshold is considered
+* top_k(int): number of highest-probability tokens to sample from; only the k most likely tokens are considered
* max_tokens(int): maximum number of tokens to generate (input plus output)
* min_tokens(int): minimum number of tokens the model must generate, to avoid stopping too early
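An offline example that enables both truncations, following the parameter list above (a sketch; the import and `generate` call mirror FastDeploy's offline-inference example surrounding this list, and the model path is a placeholder):

```python
from fastdeploy import LLM, SamplingParams

# top_p and top_k can now be combined; top_k=0 or top_p=1.0 disables that filter.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, top_k=20)
llm = LLM(model="/path/to/your/model")      # placeholder model path
outputs = llm.generate(["Write a haiku about sampling."], sampling_params)
```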

View File

@@ -16,8 +16,8 @@
import json
import os
-from datetime import datetime
from dataclasses import dataclass
+from datetime import datetime
from typing import Any, Dict, List, Literal, Optional

from fastdeploy import envs

View File

@@ -52,6 +52,7 @@ class SamplingParams:
            the model more random. Zero means greedy sampling.
        top_p: Float that controls the cumulative probability of the top tokens
            to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
+       top_k: Int that controls the number of top tokens to consider. Must be a positive integer.
        seed: Random seed to use for the generation.
        stop: list of strings that stop the generation when they are generated.
            The returned output will not contain the stop strings.
@@ -81,7 +82,8 @@ class SamplingParams:
    frequency_penalty: float = None
    repetition_penalty: float = None
    temperature: float = None
-    top_p: float = None
+    top_p: float = 1.0
+    top_k: int = 0
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
@@ -111,6 +113,7 @@ class SamplingParams:
        repetition_penalty,
        temperature,
        top_p,
+       top_k,
        seed=None,
        stop=None,
        stop_token_ids=None,
@@ -129,7 +132,8 @@ class SamplingParams:
            repetition_penalty=repetition_penalty
            if repetition_penalty is not None else 1.0,
            temperature=temperature if temperature is not None else 1.0,
-            top_p=top_p if top_p is not None else 0.7,
+            top_p=top_p if top_p is not None else 1.0,
+            top_k=top_k if top_k is not None else 0,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
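Note the default change: an unset top_p used to fall back to 0.7, while both truncations now default to "off" (top_p=1.0, top_k=0). A quick sketch of the resulting behaviour (module path assumed):

```python
from fastdeploy.engine.sampling_params import SamplingParams  # assumed module path

p = SamplingParams()                     # defaults: top_p=1.0, top_k=0 -> no truncation
q = SamplingParams(top_p=0.9, top_k=50)  # enable both top-p and top-k truncation
assert p.top_p == 1.0 and p.top_k == 0
```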

View File

@@ -292,6 +292,7 @@ class CompletionRequest(BaseModel):
    suffix: Optional[dict] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
+   top_k: Optional[int] = None
    user: Optional[str] = None
    response_format: Optional[AnyResponseFormat] = None
@@ -405,6 +406,7 @@ class ChatCompletionRequest(BaseModel):
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
+   top_k: Optional[int] = None
    user: Optional[str] = None
    metadata: Optional[dict] = None
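With `top_k` accepted on both request models, an OpenAI-compatible client can pass it via `extra_body` (it is not a standard OpenAI field); host, port, and model name below are deployment-specific placeholders:

```python
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8188/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.8,
    top_p=0.95,
    extra_body={"top_k": 20},   # forwarded to ChatCompletionRequest.top_k
)
print(resp.choices[0].message.content)
```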

View File

@@ -27,34 +27,56 @@ if current_platform.is_gcu():
def top_p_sampling(
    x: paddle.Tensor,
-    ps: paddle.Tensor,
+    top_p: paddle.Tensor,
+    top_k: Optional[paddle.Tensor] = None,
    threshold: Optional[paddle.Tensor] = None,
    topp_seed: Optional[paddle.Tensor] = None,
    seed: int = -1,
    k: int = 0,
    mode: Literal['truncated', 'non-truncated'] = "truncated",
+    order: Literal['top_k_first', 'joint'] = "top_k_first",
) -> tuple[paddle.Tensor, paddle.Tensor]:
    """
-    top_p_sampling
+    x(Tensor): An input 2-D Tensor with type float32, float16 and bfloat16.
+    top_p(Tensor): A 1-D Tensor with type float32, float16 and bfloat16,
+        used to specify the top_p corresponding to each query.
+    top_k(Tensor|None, optional): A 1-D Tensor with type int64,
+        used to specify the top_k corresponding to each query.
+        Only used when FD_SAMPLING_CLASS is `rejection`.
+    threshold(Tensor|None, optional): A 1-D Tensor with type float32, float16 and bfloat16,
+        used to avoid sampling low score tokens.
+    topp_seed(Tensor|None, optional): A 1-D Tensor with type int64,
+        used to specify the random seed for each query.
+    seed(int, optional): the random seed. Default is -1.
+    k(int): the number of top_k scores/ids to be returned. Default is 0.
+        Only used when FD_SAMPLING_CLASS is `air`.
+    mode(str): The mode to choose sampling strategy. If the mode is `truncated`, sampling will truncate the probability at top_p_value.
+        If the mode is `non-truncated`, it will not be truncated. Default is `truncated`.
+        Only used when FD_SAMPLING_CLASS is `air` or `base`.
+    order(str): The order of applying top-k and top-p sampling, should be either `top_k_first` or `joint`.
+        If `top_k_first`, we first apply top-k filter, then apply top-p sampling on the top-k results.
+        If `joint`, we apply top-k and top-p filter simultaneously in each round. Default is `top_k_first`.
+        Only used when FD_SAMPLING_CLASS is `rejection`.
    """
    top_p_class = envs.FD_SAMPLING_CLASS.lower()
    if top_p_class == "air":
        _, ids = air_top_p_sampling(x,
-                                    ps,
+                                    top_p,
                                    threshold,
                                    topp_seed,
                                    seed=seed,
                                    k=k,
                                    mode=mode)
    elif top_p_class == "rejection":
-        ids = rejection_top_p_sampling(x, ps, seed)
+        ids = rejection_top_p_sampling(x, top_p, top_k, seed, order)
        _ = None
    else:
        if current_platform.is_gcu():
-            _, ids = gcu_top_p_sampling(x, ps)
+            _, ids = gcu_top_p_sampling(x, top_p)
        else:
            _, ids = paddle.tensor.top_p_sampling(x,
-                                                  ps,
+                                                  top_p,
                                                  threshold=threshold,
                                                  topp_seed=topp_seed,
                                                  seed=seed,
@@ -65,7 +87,7 @@ def top_p_sampling(
def air_top_p_sampling(
    x: paddle.Tensor,
-    ps: paddle.Tensor,
+    top_p: paddle.Tensor,
    threshold: Optional[paddle.Tensor] = None,
    topp_seed: Optional[paddle.Tensor] = None,
    seed: int = -1,
@@ -77,7 +99,7 @@ def air_top_p_sampling(
    """
    try:
        from fastdeploy.model_executor.ops.gpu import air_top_p_sampling
-        out, ids = air_top_p_sampling(x, ps, threshold, topp_seed, seed, k,
+        out, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed, k,
                                      mode)
    except ImportError:
        raise RuntimeError("Cannot import air_top_p_sampling op.")
@@ -86,19 +108,46 @@ def air_top_p_sampling(
def rejection_top_p_sampling(
    x: paddle.Tensor,
-    ps: paddle.Tensor,
+    top_p: paddle.Tensor,
+    top_k: Optional[paddle.Tensor] = None,
    seed: int = -1,
+    order: Literal['top_k_first', 'joint'] = "top_k_first",
) -> paddle.Tensor:
    """
    rejection_top_p_sampling
    """
+    assert top_p is not None, "Top_p should not be none when FD_SAMPLING_CLASS is rejection"
    try:
-        from fastdeploy.model_executor.ops.gpu import rejection_top_p_sampling
-        ids = rejection_top_p_sampling(
-            x,
-            ps,
-            seed,
-        )
+        from fastdeploy.model_executor.ops.gpu import (
+            rejection_top_p_sampling, top_k_renorm_probs)
+
+        if top_k is None:
+            ids = rejection_top_p_sampling(
+                x,
+                top_p,
+                None,
+                seed,
+            )
+        elif top_k is not None and top_p is not None:
+            if order == "top_k_first":
+                renorm_probs = top_k_renorm_probs(x, top_k)
+                ids = rejection_top_p_sampling(
+                    renorm_probs,
+                    top_p,
+                    None,
+                    seed,
+                )
+            else:
+                ids = rejection_top_p_sampling(
+                    x,
+                    top_p,
+                    top_k,
+                    seed,
+                )
+        else:
+            raise ValueError(
+                "Top_p cannot be none."
+            )
    except ImportError:
        raise RuntimeError("Cannot import rejection_top_p_sampling op.")
    return ids
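Usage sketch for the two orders described in the docstring above (import path assumed; `FD_SAMPLING_CLASS=rejection` and a GPU build are required for the top_k path to take effect):

```python
import paddle
from fastdeploy.model_executor.layers.sample.ops import top_p_sampling  # assumed module path

probs = paddle.rand([4, 32000], dtype="float32")
probs = probs / probs.sum(axis=-1, keepdim=True)
top_p = paddle.full([4, 1], 0.9, dtype="float32")
top_k = paddle.full([4, 1], 50, dtype="int64")

# order="top_k_first" (default): renormalize to the top-k tokens first,
# then run top-p rejection sampling on the renormalized distribution.
_, ids = top_p_sampling(probs, top_p, top_k)

# order="joint": a single fused kernel applies both filters in each round.
_, ids_joint = top_p_sampling(probs, top_p, top_k, order="joint")
```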

View File

@@ -214,7 +214,7 @@ class Sampler(nn.Layer):
        probs = F.softmax(logits)
-        _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
+        _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
        self.processor.update_output_tokens(next_tokens, skip_idx_list)
        return next_tokens
@@ -367,5 +367,5 @@ class MTPSampler(nn.Layer):
        )
        probs = F.softmax(logits)
-        _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
+        _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
        return next_tokens

View File

@@ -154,12 +154,29 @@ class GCUModelRunner(ModelRunnerBase):
                -1].disaggregate_info["role"] == "prefill":
            os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"

+        top_k_reqs = []
+        top_p_reqs = []
+        max_num_seqs = self.parallel_config.max_num_seqs
+        top_p_buffer = paddle.full([max_num_seqs, 1],
+                                   self.model_config.top_p,
+                                   dtype='float32')
+        top_k_buffer = paddle.full([max_num_seqs, 1],
+                                   0,
+                                   dtype='int64')
        req_len = len(req_dicts)
        for i in range(req_len):
            request = req_dicts[i]
            idx = request.idx
            length = len(request.prompt_token_ids)
+            if sampling_params := request.sampling_params:
+                if sampling_params.top_p < 1:
+                    top_p_reqs.append(idx)
+                top_k = sampling_params.top_k
+                if top_k > 0:
+                    top_k_reqs.append(idx)
            prefill_tokens = []
            if (request.guided_json is not None
                    or request.guided_regex is not None
@@ -234,8 +251,8 @@ class GCUModelRunner(ModelRunnerBase):
                request.eos_token_ids.append(request.eos_token_ids[0])
            self.share_inputs["eos_token_id"][:] = np.array(
                request.eos_token_ids, dtype="int64").reshape(-1, 1)
-            self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
+            top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
+            top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
            self.share_inputs["temperature"][idx:idx + 1] = request.get(
                "temperature", 0.95)
            self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -286,6 +303,16 @@ class GCUModelRunner(ModelRunnerBase):
        if self.speculative_method in ["mtp"]:
            self.proposer.insert_prefill_inputs(req_dicts)

+        if len(top_k_reqs) == 0:
+            self.share_inputs["top_k"] = None
+        else:
+            self.share_inputs["top_k"] = top_k_buffer
+
+        if len(top_p_reqs) == 0:
+            self.share_inputs["top_p"] = None
+        else:
+            self.share_inputs["top_p"] = top_p_buffer
+
    def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
                              expected_decode_len: int):
        """ Set dummy prefill inputs to share_inputs """
@@ -340,8 +367,11 @@ class GCUModelRunner(ModelRunnerBase):
        self.share_inputs["eos_token_id"] = paddle.full(
            [self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
        self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
                                                 self.model_config.top_p,
                                                 dtype='float32')
+        self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
+                                                 0,
+                                                 dtype='int64')
        self.share_inputs["temperature"] = paddle.full(
            [max_num_seqs, 1], self.model_config.temperature, dtype='float32')
        self.share_inputs["penalty_score"] = paddle.full(
@@ -563,6 +593,7 @@ class GCUModelRunner(ModelRunnerBase):
        self.sampling_metadata = SamplingMetadata(
            temperature=self.share_inputs["temperature"],
            top_p=self.share_inputs["top_p"],
+            top_k=self.share_inputs["top_k"],
            step_idx=self.share_inputs["step_idx"],
            pre_token_ids=self.share_inputs["pre_ids"],
            frequency_penalties=self.share_inputs["frequency_score"],

View File

@@ -161,6 +161,15 @@ class GPUModelRunner(ModelRunnerBase):
                -1].disaggregate_info["role"] == "prefill":
            os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"

+        top_k_reqs = []
+        top_p_reqs = []
+        max_num_seqs = self.parallel_config.max_num_seqs
+        top_p_buffer = paddle.full([max_num_seqs, 1],
+                                   self.model_config.top_p,
+                                   dtype='float32')
+        top_k_buffer = paddle.full([max_num_seqs, 1],
+                                   0,
+                                   dtype='int64')
        req_len = len(req_dicts)
        for i in range(req_len):
            request = req_dicts[i]
@@ -168,6 +177,13 @@ class GPUModelRunner(ModelRunnerBase):
            length = len(request.prompt_token_ids)
            assert length > 0, "The prompt requested must not be empty."

+            if sampling_params := request.sampling_params:
+                if sampling_params.top_p < 1:
+                    top_p_reqs.append(idx)
+                top_k = sampling_params.top_k
+                if top_k > 0:
+                    top_k_reqs.append(idx)
            prefill_tokens = []
            if (request.guided_json is not None
                    or request.guided_regex is not None
@@ -242,8 +258,8 @@ class GPUModelRunner(ModelRunnerBase):
                request.eos_token_ids.append(request.eos_token_ids[0])
            self.share_inputs["eos_token_id"][:] = np.array(
                request.eos_token_ids, dtype="int64").reshape(-1, 1)
-            self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
+            top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
+            top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
            self.share_inputs["temperature"][idx:idx + 1] = request.get(
                "temperature", 0.95)
            self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -294,6 +310,16 @@ class GPUModelRunner(ModelRunnerBase):
        if self.speculative_method in ["mtp"]:
            self.proposer.insert_prefill_inputs(req_dicts)

+        if len(top_k_reqs) == 0:
+            self.share_inputs["top_k"] = None
+        else:
+            self.share_inputs["top_k"] = top_k_buffer
+
+        if len(top_p_reqs) == 0:
+            self.share_inputs["top_p"] = None
+        else:
+            self.share_inputs["top_p"] = top_p_buffer
+
    def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
                              expected_decode_len: int):
        """ Set dummy prefill inputs to share_inputs """
@@ -349,8 +375,11 @@ class GPUModelRunner(ModelRunnerBase):
        self.share_inputs["eos_token_id"] = paddle.full(
            [self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
        self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
                                                 self.model_config.top_p,
                                                 dtype='float32')
+        self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
+                                                 0,
+                                                 dtype='int64')
        self.share_inputs["temperature"] = paddle.full(
            [max_num_seqs, 1], self.model_config.temperature, dtype='float32')
        self.share_inputs["penalty_score"] = paddle.full(
@@ -574,6 +603,7 @@ class GPUModelRunner(ModelRunnerBase):
        self.sampling_metadata = SamplingMetadata(
            temperature=self.share_inputs["temperature"],
            top_p=self.share_inputs["top_p"],
+            top_k=self.share_inputs["top_k"],
            step_idx=self.share_inputs["step_idx"],
            pre_token_ids=self.share_inputs["pre_ids"],
            frequency_penalties=self.share_inputs["frequency_score"],
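Each model runner (GPU, GCU, Iluvatar, XPU) follows the same gating pattern: per-request top_p/top_k values are written into staging buffers, and the shared-input tensors are only kept when at least one request in the batch actually asks for that truncation, otherwise they are set to None so the sampler skips that path. A condensed sketch of the pattern (names mirror the runner code; this is not a drop-in snippet):

```python
import paddle

def gate_sampling_buffers(requests, share_inputs, max_num_seqs, default_top_p=1.0):
    """Condensed sketch of the per-batch top_p/top_k gating used by the runners."""
    top_p_buffer = paddle.full([max_num_seqs, 1], default_top_p, dtype='float32')
    top_k_buffer = paddle.full([max_num_seqs, 1], 0, dtype='int64')
    top_p_reqs, top_k_reqs = [], []
    for idx, request in enumerate(requests):
        sp = request.sampling_params
        if sp.top_p < 1:
            top_p_reqs.append(idx)
        if sp.top_k > 0:
            top_k_reqs.append(idx)
        top_p_buffer[idx:idx + 1] = sp.top_p
        top_k_buffer[idx:idx + 1] = sp.top_k
    # None tells the sampler to skip that truncation entirely for this batch.
    share_inputs["top_p"] = top_p_buffer if top_p_reqs else None
    share_inputs["top_k"] = top_k_buffer if top_k_reqs else None
```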

View File

@@ -29,9 +29,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import \
    AttentionBackend
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
-from fastdeploy.model_executor.layers.sample.sampler import (Sampler,
-                                                             SpeculativeSampler
-                                                             )
+from fastdeploy.model_executor.layers.sample.sampler import (
+    Sampler, SpeculativeSampler)
from fastdeploy.model_executor.model_loader import get_model_from_loader
from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
from fastdeploy.model_executor.pre_and_post_process import (post_process,
@@ -145,12 +144,29 @@ class IluvatarModelRunner(ModelRunnerBase):
                -1].disaggregate_info["role"] == "prefill":
            os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"

+        top_k_reqs = []
+        top_p_reqs = []
+        max_num_seqs = self.parallel_config.max_num_seqs
+        top_p_buffer = paddle.full([max_num_seqs, 1],
+                                   self.model_config.top_p,
+                                   dtype='float32')
+        top_k_buffer = paddle.full([max_num_seqs, 1],
+                                   0,
+                                   dtype='int64')
        req_len = len(req_dicts)
        for i in range(req_len):
            request = req_dicts[i]
            idx = request.idx
            length = len(request.prompt_token_ids)
+            if sampling_params := request.sampling_params:
+                if sampling_params.top_p < 1:
+                    top_p_reqs.append(idx)
+                top_k = sampling_params.top_k
+                if top_k > 0:
+                    top_k_reqs.append(idx)
            prefill_tokens = []
            if (request.guided_json is not None
                    or request.guided_regex is not None
@@ -225,8 +241,8 @@ class IluvatarModelRunner(ModelRunnerBase):
                request.eos_token_ids.append(request.eos_token_ids[0])
            self.share_inputs["eos_token_id"][:] = np.array(
                request.eos_token_ids, dtype="int64").reshape(-1, 1)
-            self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
+            top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
+            top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
            self.share_inputs["temperature"][idx:idx + 1] = request.get(
                "temperature", 0.95)
            self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -273,6 +289,15 @@ class IluvatarModelRunner(ModelRunnerBase):
                    idx, request.get("logits_processor"), prefill_tokens)
        self.share_inputs["not_need_stop"][0] = True

+        if len(top_k_reqs) == 0:
+            self.share_inputs["top_k"] = None
+        else:
+            self.share_inputs["top_k"] = top_k_buffer
+
+        if len(top_p_reqs) == 0:
+            self.share_inputs["top_p"] = None
+        else:
+            self.share_inputs["top_p"] = top_p_buffer

    def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
                              expected_decode_len: int):
@@ -329,8 +354,11 @@ class IluvatarModelRunner(ModelRunnerBase):
        self.share_inputs["eos_token_id"] = paddle.full(
            [self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
        self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
                                                 self.model_config.top_p,
                                                 dtype='float32')
+        self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
+                                                 0,
+                                                 dtype='int64')
        self.share_inputs["temperature"] = paddle.full(
            [max_num_seqs, 1], self.model_config.temperature, dtype='float32')
        self.share_inputs["penalty_score"] = paddle.full(
@@ -558,6 +586,7 @@ class IluvatarModelRunner(ModelRunnerBase):
        self.sampling_metadata = SamplingMetadata(
            temperature=self.share_inputs["temperature"],
            top_p=self.share_inputs["top_p"],
+            top_k=self.share_inputs["top_k"],
            step_idx=self.share_inputs["step_idx"],
            pre_token_ids=self.share_inputs["pre_ids"],
            frequency_penalties=self.share_inputs["frequency_score"],

View File

@@ -14,14 +14,14 @@
# limitations under the License.
"""
-from abc import ABC, abstractmethod
import argparse
+from abc import ABC, abstractmethod

import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet

from fastdeploy.config import ModelConfig
from fastdeploy.utils import get_logger

logger = get_logger("worker", "worker.log")

View File

@@ -282,11 +282,26 @@ class XPUModelRunner(ModelRunnerBase):
    def process_prefill_inputs(self, req_dicts: List[Request]):
        """ Process inputs for prefill tasks and update share_inputs buffer """
+        top_k_reqs = []
+        top_p_reqs = []
+        max_num_seqs = self.parallel_config.max_num_seqs
+        top_p_buffer = paddle.full([max_num_seqs, 1],
+                                   self.model_config.top_p,
+                                   dtype='float32')
+        top_k_buffer = paddle.full([max_num_seqs, 1],
+                                   0,
+                                   dtype='int64')
        req_len = len(req_dicts)
        for i in range(req_len):
            request = req_dicts[i]
            idx = request.idx
            length = request.prompt_token_ids_len
+            if sampling_params := request.sampling_params:
+                if sampling_params.top_p < 1:
+                    top_p_reqs.append(idx)
+                top_k = sampling_params.top_k
+                if top_k > 0:
+                    top_k_reqs.append(idx)
            self.share_inputs["input_ids"][idx:idx + 1, :length] = np.array(
                request.prompt_token_ids)
            if len(request.eos_token_ids
@@ -295,7 +310,8 @@ class XPUModelRunner(ModelRunnerBase):
            self.share_inputs["eos_token_id"][:] = np.array(
                request.eos_token_ids, dtype="int64").reshape(-1, 1)
            self.share_inputs["pre_ids"][idx:idx + 1] = -1
-            self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
+            top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
+            top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
            self.share_inputs["temperature"][idx:idx + 1] = request.get(
                "temperature", 0.95)
            self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -344,6 +360,15 @@ class XPUModelRunner(ModelRunnerBase):
                    request.get("stop_token_ids"), dtype="int64")
        self.share_inputs["not_need_stop"][0] = True

+        if len(top_k_reqs) == 0:
+            self.share_inputs["top_k"] = None
+        else:
+            self.share_inputs["top_k"] = top_k_buffer
+
+        if len(top_p_reqs) == 0:
+            self.share_inputs["top_p"] = None
+        else:
+            self.share_inputs["top_p"] = top_p_buffer

    def _init_share_inputs(self, max_num_seqs: int):
        """Initialize all share buffers for model inputs.
@@ -363,8 +388,11 @@ class XPUModelRunner(ModelRunnerBase):
        self.share_inputs["eos_token_id"] = paddle.full(
            [self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
        self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
                                                 self.model_config.top_p,
                                                 dtype='float32')
+        self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
+                                                 0,
+                                                 dtype='int64')
        self.share_inputs["temperature"] = paddle.full(
            [max_num_seqs, 1], self.model_config.temperature, dtype='float32')
        self.share_inputs["penalty_score"] = paddle.full(
@@ -514,6 +542,7 @@ class XPUModelRunner(ModelRunnerBase):
        self.sampling_metadata = SamplingMetadata(
            temperature=self.share_inputs["temperature"],
            top_p=self.share_inputs["top_p"],
+            top_k=self.share_inputs["top_k"],
            step_idx=self.share_inputs["step_idx"],
            pre_token_ids=self.share_inputs["pre_ids"],
            frequency_penalties=self.share_inputs["frequency_score"],