diff --git a/custom_ops/gpu_ops/sample_kernels/rejection_top_p_sampling.cu b/custom_ops/gpu_ops/sample_kernels/rejection_top_p_sampling.cu
index 598297be5..238c819eb 100644
--- a/custom_ops/gpu_ops/sample_kernels/rejection_top_p_sampling.cu
+++ b/custom_ops/gpu_ops/sample_kernels/rejection_top_p_sampling.cu
@@ -18,6 +18,7 @@
 std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
                                                const paddle::Tensor &top_p,
+                                               const paddle::optional<paddle::Tensor> &top_k,
                                                int seed) {
   std::vector<int64_t> probs_shape = probs.shape();
   unsigned int batch_size = probs_shape[0];
@@ -40,10 +41,18 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
 
   cudaError_t status;
 
-  status = sampling::TopKTopPSamplingFromProb(
-      const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
-      batch_size, top_p.data<float>(), vocab_size,
-      true, philox_seed, philox_offset, cu_stream);
+  if (top_k) {
+    status = sampling::TopKTopPSamplingFromProb(
+        const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
+        batch_size, top_p.data<float>(), top_k.get().data<int64_t>(), vocab_size,
+        true, philox_seed, philox_offset, cu_stream);
+  }
+  else {
+    status = sampling::TopPSamplingFromProb(
+        const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
+        batch_size, top_p.data<float>(), vocab_size,
+        true, philox_seed, philox_offset, cu_stream);
+  }
 
   PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
                                       std::string(cudaGetErrorString(status)));
@@ -53,19 +62,21 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
 
 std::vector<std::vector<int64_t>> TopPSamplingRejectInferShape(const std::vector<int64_t> &probs_shape,
-                                                               const std::vector<int64_t> &top_p_shape) {
+                                                               const std::vector<int64_t> &top_p_shape,
+                                                               const paddle::optional<std::vector<int64_t>> &top_k_shape) {
   int64_t bs = probs_shape[0];
   return {{bs, 1}};
 }
 
 std::vector<paddle::DataType> TopPSamplingRejectInferDtype(const paddle::DataType &probs_dtype,
-                                                           const paddle::DataType &top_p_shape) {
+                                                           const paddle::DataType &top_p_dtype,
+                                                           const paddle::optional<paddle::DataType> &top_k_dtype) {
   return {paddle::DataType::INT64};
 }
 
 PD_BUILD_STATIC_OP(rejection_top_p_sampling)
-    .Inputs({"probs", "top_p"})
+    .Inputs({"probs", "top_p", paddle::Optional("top_k")})
     .Outputs({"samples"})
     .Attrs({"seed: int"})
     .SetKernelFn(PD_KERNEL(TopPSamplingReject))
diff --git a/custom_ops/gpu_ops/sample_kernels/sampling.cuh b/custom_ops/gpu_ops/sample_kernels/sampling.cuh
index eb5f6f1b8..7102c73d6 100644
--- a/custom_ops/gpu_ops/sample_kernels/sampling.cuh
+++ b/custom_ops/gpu_ops/sample_kernels/sampling.cuh
@@ -279,7 +279,8 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
           typename DType, typename IdType>
-__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, float* top_p_arr,
+__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
+                                               float* top_p_arr, IdType* top_k_arr,
                                                uint32_t d, uint64_t philox_seed,
                                                uint64_t philox_offset) {
   const uint32_t batch_size = gridDim.x;
@@ -287,7 +288,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, flo
   curandStatePhilox4_32_10_t state;
   curand_init(philox_seed, bx, philox_offset, &state);
   const uint32_t row_idx = bx;
-  const uint32_t k = top_p_arr[row_idx] == 0 ? 1 : 20;
+  const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
   const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx];
 
   extern __shared__ __align__(
@@ -479,7 +480,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
       if (aggregate_gt_pivot_0 < top_p) {
         // case 1: pivot_0 accepted
         break;
-      } 
+      }
       if (aggregate_gt_pivot_1 < top_p) {
         // case 2: pivot_0 rejected, pivot_1 accepted
         low = pivot_0;
@@ -497,6 +498,183 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
   }
 }
 
+template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
+          typename TempStorage>
+__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
+                                             TempStorage& temp_storage) {
+  const uint32_t tx = threadIdx.x;
+  vec_t<float, VEC_SIZE> in_data_vec;
+
+  float max_val = 0;
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    in_data_vec.fill(0);
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+    }
+    float in_data_[VEC_SIZE];
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      in_data_[j] = in_data_vec[j];
+    }
+    max_val = max(
+        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                     .Reduce(in_data_, cub::Max()));
+    __syncthreads();
+  }
+  if (tx == 0) {
+    temp_storage.max_val = max_val;
+  }
+  __syncthreads();
+  return temp_storage.max_val;
+}
+
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
+struct RenormTempStorage {
+  union {
+    typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
+    typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
+    typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
+        reduce_value_count;
+  } block_prim;
+  struct {
+    float max_val;
+    float min_val;
+    union {
+      struct {
+        float values[2];
+      };
+      struct {
+        int counts[2];
+      };
+      struct {
+        ValueCount<float> pairs[2];
+      };
+    } block_aggregate;
+  };
+};
+
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          typename DType, typename IdType>
+__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = bx;
+  const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
+  double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
+  vec_t<float, VEC_SIZE> probs_vec;
+  if (k < d) {
+    extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>))
+        uint8_t smem_renorm[];
+    auto& temp_storage =
+        reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>&>(smem_renorm);
+    temp_storage.max_val = 0;
+
+    float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
+                                RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
+        probs, row_idx, d, temp_storage);
+
+    double low = 0, high = max_val;
+    float min_gt_low, max_le_high;
+    float sum_low = 1;
+    // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
+    // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
+    // loop invariant:
+    // - f(low) >= k, f(high) < k
+    // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+    // stopping condition: min_gt_low == max_le_high
+    // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
+    do {
+      double pivot_0 = (high + 2 * low) / 3;
+      double pivot_1 = (2 * high + low) / 3;
+
+      ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
+      min_gt_low = high;
+      max_le_high = low;
+#pragma unroll 2
+      for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+        probs_vec.fill(0);
+        if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+          probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+        }
+        ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
+#pragma unroll
+        for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+          probs_gt_pivot_0_pair[j] = {
+              (probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
+              (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+          probs_gt_pivot_1_pair[j] = {
+              (probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
+              (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+
+          if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            min_gt_low = min(min_gt_low, probs_vec[j]);
+          }
+          if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            max_le_high = max(max_le_high, probs_vec[j]);
+          }
+        }
+
+        aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                    temp_storage.block_prim.reduce_value_count)
+                                    .Sum(probs_gt_pivot_0_pair);
+        __syncthreads();
+
+        aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                    temp_storage.block_prim.reduce_value_count)
+                                    .Sum(probs_gt_pivot_1_pair);
+        __syncthreads();
+      }
+      min_gt_low =
+          BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+              .Reduce(min_gt_low, cub::Min());
+      __syncthreads();
+      max_le_high =
+          BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+              .Reduce(max_le_high, cub::Max());
+      if (tx == 0) {
+        temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
+        temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
+        temp_storage.min_val = min_gt_low;
+        temp_storage.max_val = max_le_high;
+      }
+      __syncthreads();
+      aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0];
+      aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1];
+      min_gt_low = temp_storage.min_val;
+      max_le_high = temp_storage.max_val;
+
+      if (aggregate_gt_pivot_1.count >= k) {
+        low = pivot_1;
+        sum_low = float(aggregate_gt_pivot_1.value);
+      } else if (aggregate_gt_pivot_0.count >= k) {
+        low = pivot_0;
+        high = min(pivot_1, max_le_high);
+        sum_low = float(aggregate_gt_pivot_0.value);
+      } else {
+        high = min(pivot_0, max_le_high);
+      }
+    } while (min_gt_low != max_le_high);
+
+    normalizer = ptx_rcp(max(sum_low, 1e-8));
+    pivot = low;
+  }
+
+  // normalize
+#pragma unroll 2
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(0);
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
+    }
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+  }
+}
+
 template <typename T, typename IdType>
 cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
                                  uint32_t batch_size, const T *top_p_val,
@@ -529,7 +707,7 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
 
 template <typename T, typename IdType>
 cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
-                                     uint32_t batch_size, const T *top_p_val,
+                                     uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
                                      uint32_t d, bool deterministic,
                                      uint64_t philox_seed, uint64_t philox_offset,
                                      cudaStream_t stream = 0) {
@@ -540,7 +718,7 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
   const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
-  void* args[] = {&probs, &output, &top_p_val,
+  void* args[] = {&probs, &output, &top_p_val, &top_k_val,
                   &d, &philox_seed, &philox_offset};
 
   DISPATCH_ALIGNED_VEC_SIZE(
@@ -556,4 +734,26 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
   });
 }
 
-}  // namespace sampling
\ No newline at end of file
+template <typename DType, typename IdType>
+cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
+                           uint32_t batch_size, uint32_t d,
+                           cudaStream_t stream = 0) {
+  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+
+  auto compute_capacity = GetCudaComputeCapability();
+  DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
+    const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
+    dim3 nblks(batch_size);
+    dim3 nthrs(BLOCK_THREADS);
+    void* args[] = {&probs, &renormed_prob, &top_k_arr, &d};
+    DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+      auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
+      CUDA_CALL(
+          cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+      CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+    });
+    return cudaSuccess;
+  });
+}
+
+}  // namespace sampling
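For reference, the behaviour `TopKRenormProbKernel` implements per row is: keep the top-k probabilities, zero the rest, and renormalize, with `top_k == 0` meaning "no truncation". The following is an illustrative NumPy sketch of that semantics, not part of the patch; it resolves ties around the k-th value by keeping them, so it may retain slightly more than k entries, just as the pivot search in the kernel can.

```python
import numpy as np

def top_k_renorm_probs_ref(probs: np.ndarray, top_k: np.ndarray) -> np.ndarray:
    """CPU sketch of the per-row top-k renormalization done by TopKRenormProbKernel."""
    probs = probs.astype(np.float32)
    out = np.zeros_like(probs)
    for row, k in enumerate(top_k):
        # top_k == 0 disables truncation (k = vocab size)
        k = probs.shape[1] if k == 0 else min(int(k), probs.shape[1])
        pivot = np.sort(probs[row])[-k]                      # k-th largest probability
        kept = np.where(probs[row] >= pivot, probs[row], 0.0)
        out[row] = kept / kept.sum()                         # renormalize surviving mass
    return out
```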
diff --git a/custom_ops/gpu_ops/sample_kernels/top_k_renorm_probs.cu b/custom_ops/gpu_ops/sample_kernels/top_k_renorm_probs.cu
new file mode 100644
index 000000000..ea4ab0dbb
--- /dev/null
+++ b/custom_ops/gpu_ops/sample_kernels/top_k_renorm_probs.cu
@@ -0,0 +1,61 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "helper.h"
+#include "paddle/phi/backends/context_pool.h"
+#include "sample_kernels/sampling.cuh"
+
+std::vector<paddle::Tensor> TopKRenorm(const paddle::Tensor &probs,
+                                       const paddle::Tensor &top_k) {
+  std::vector<int64_t> probs_shape = probs.shape();
+  uint32_t batch_size = probs_shape[0];
+  uint32_t vocab_size = probs_shape[1];
+  auto cu_stream = probs.stream();
+
+  auto renorm_probs =
+      GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place());
+
+  cudaError_t status;
+
+  status = sampling::TopKRenormProb(
+      const_cast<float *>(probs.data<float>()),
+      renorm_probs.data<float>(),
+      const_cast<int64_t *>(top_k.data<int64_t>()),
+      batch_size, vocab_size, cu_stream);
+
+  PD_CHECK(status == cudaSuccess, "TopKRenormProb failed with error code " +
+                                      std::string(cudaGetErrorString(status)));
+
+  return {renorm_probs};
+}
+
+std::vector<std::vector<int64_t>>
+TopKRenormInferShape(const std::vector<int64_t> &probs_shape,
+                     const std::vector<int64_t> &top_k_shape) {
+  return {probs_shape};
+}
+
+std::vector<paddle::DataType>
+TopKRenormInferDtype(const paddle::DataType &probs_dtype,
+                     const paddle::DataType &top_k_dtype) {
+  return {probs_dtype};
+}
+
+PD_BUILD_STATIC_OP(top_k_renorm_probs)
+    .Inputs({"probs", "top_k"})
+    .Outputs({"renorm_probs"})
+    .SetKernelFn(PD_KERNEL(TopKRenorm))
+    .SetInferShapeFn(PD_INFER_SHAPE(TopKRenormInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(TopKRenormInferDtype));
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index e3fe1a054..c002beeb6 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -282,6 +282,7 @@ elif paddle.is_compiled_with_cuda():
             "gpu_ops/text_image_index_out.cu",
             "gpu_ops/text_image_gather_scatter.cu",
             "gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
+            "gpu_ops/sample_kernels/top_k_renorm_probs.cu",
             "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
             "gpu_ops/fused_rotary_position_encoding.cu",
             "gpu_ops/noaux_tc.cu",
diff --git a/docs/zh/offline_inference.md b/docs/zh/offline_inference.md
index 1e3f3e466..382e65740 100644
--- a/docs/zh/offline_inference.md
+++ b/docs/zh/offline_inference.md
@@ -178,6 +178,7 @@ for output in outputs:
 * repetition_penalty(float): 直接对重复生成的token进行惩罚的系数(>1时惩罚重复,<1时鼓励重复)
 * temperature(float): 控制生成随机性的参数,值越高结果越随机,值越低结果越确定
 * top_p(float): 概率累积分布截断阈值,仅考虑累计概率达到此阈值的最可能token集合
+* top_k(int): 采样考虑的最高概率token数量,仅从概率最高的k个token中进行采样
 * max_tokens(int): 限制模型生成的最大token数量(包括输入和输出)
 * min_tokens(int): 强制模型生成的最少token数量,避免过早结束
diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py
index ca76fa9c5..ab421fdef 100644
--- a/fastdeploy/engine/config.py
+++ b/fastdeploy/engine/config.py
@@ -16,8 +16,8 @@
 import json
 import os
 
-from datetime import datetime
 from dataclasses import dataclass
+from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional
 
 from fastdeploy import envs
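The `SamplingParams` diff below switches the defaults to `top_p=1.0` and `top_k=0`, i.e. both truncations disabled unless a request asks for them. A hedged offline-inference sketch of the documented parameters follows; the `LLM` entrypoint import path, the model path, and the `generate` call shape are assumptions based on the offline-inference docs, not part of this patch.

```python
# Minimal sketch, assuming the offline LLM entrypoint and a local model path.
from fastdeploy.entrypoints.llm import LLM  # assumed entrypoint module
from fastdeploy.engine.sampling_params import SamplingParams

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,  # keep tokens within 95% cumulative probability
    top_k=50,    # only consider the 50 most likely tokens; 0 disables top-k
)
llm = LLM(model="./path/to/model")  # placeholder path
outputs = llm.generate(["Hello, how are you?"], sampling_params)
```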
diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py
index 0f60cf36b..00a6aeb13 100644
--- a/fastdeploy/engine/sampling_params.py
+++ b/fastdeploy/engine/sampling_params.py
@@ -52,6 +52,7 @@ class SamplingParams:
             the model more random. Zero means greedy sampling.
         top_p: Float that controls the cumulative probability of the top tokens
             to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
+        top_k: Int that controls the number of top tokens to consider. Set to 0 (the default) to consider all tokens.
         seed: Random seed to use for the generation.
         stop: list of strings that stop the generation when they are generated.
             The returned output will not contain the stop strings.
@@ -81,7 +82,8 @@ class SamplingParams:
     frequency_penalty: float = None
     repetition_penalty: float = None
     temperature: float = None
-    top_p: float = None
+    top_p: float = 1.0
+    top_k: int = 0
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
@@ -111,6 +113,7 @@ class SamplingParams:
             repetition_penalty,
             temperature,
             top_p,
+            top_k,
             seed=None,
             stop=None,
             stop_token_ids=None,
@@ -129,7 +132,8 @@ class SamplingParams:
             repetition_penalty=repetition_penalty
             if repetition_penalty is not None else 1.0,
             temperature=temperature if temperature is not None else 1.0,
-            top_p=top_p if top_p is not None else 0.7,
+            top_p=top_p if top_p is not None else 1.0,
+            top_k=top_k if top_k is not None else 0,
             seed=seed,
             stop=stop,
             stop_token_ids=stop_token_ids,
diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index 845786704..76a575152 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -292,6 +292,7 @@ class CompletionRequest(BaseModel):
     suffix: Optional[dict] = None
     temperature: Optional[float] = None
     top_p: Optional[float] = None
+    top_k: Optional[int] = None
     user: Optional[str] = None
     response_format: Optional[AnyResponseFormat] = None
 
@@ -405,6 +406,7 @@ class ChatCompletionRequest(BaseModel):
     stream_options: Optional[StreamOptions] = None
     temperature: Optional[float] = None
     top_p: Optional[float] = None
+    top_k: Optional[int] = None
     user: Optional[str] = None
     metadata: Optional[dict] = None
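With `top_k` now accepted on both `CompletionRequest` and `ChatCompletionRequest`, an OpenAI-compatible request can carry it alongside `top_p`. A hedged client sketch follows; the host, port, and model name are assumptions about a locally launched FastDeploy server, not part of the patch.

```python
# Minimal sketch of a chat completion request carrying top_k; endpoint details are assumed.
import requests

payload = {
    "model": "default",  # placeholder model name
    "messages": [{"role": "user", "content": "Write a haiku about GPUs."}],
    "temperature": 0.8,
    "top_p": 0.95,
    "top_k": 50,  # newly accepted field on ChatCompletionRequest / CompletionRequest
}
resp = requests.post("http://localhost:8188/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
```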
diff --git a/fastdeploy/model_executor/layers/sample/ops/top_p_sampling.py b/fastdeploy/model_executor/layers/sample/ops/top_p_sampling.py
index eeebb610b..08635f810 100644
--- a/fastdeploy/model_executor/layers/sample/ops/top_p_sampling.py
+++ b/fastdeploy/model_executor/layers/sample/ops/top_p_sampling.py
@@ -27,34 +27,56 @@ if current_platform.is_gcu():
 
 def top_p_sampling(
     x: paddle.Tensor,
-    ps: paddle.Tensor,
+    top_p: paddle.Tensor,
+    top_k: Optional[paddle.Tensor] = None,
     threshold: Optional[paddle.Tensor] = None,
     topp_seed: Optional[paddle.Tensor] = None,
     seed: int = -1,
    k: int = 0,
     mode: Literal['truncated', 'non-truncated'] = "truncated",
+    order: Literal['top_k_first', 'joint'] = "top_k_first",
 ) -> tuple[paddle.Tensor, paddle.Tensor]:
     """
-    top_p_sampling
+    x(Tensor): An input 2-D Tensor with type float32, float16 and bfloat16.
+    top_p(Tensor): A 1-D Tensor with type float32, float16 and bfloat16,
+        used to specify the top_p corresponding to each query.
+    top_k(Tensor|None, optional): A 1-D Tensor with type int64,
+        used to specify the top_k corresponding to each query.
+        Only used when FD_SAMPLING_CLASS is `rejection`.
+    threshold(Tensor|None, optional): A 1-D Tensor with type float32, float16 and bfloat16,
+        used to avoid sampling low score tokens.
+    topp_seed(Tensor|None, optional): A 1-D Tensor with type int64,
+        used to specify the random seed for each query.
+    seed(int, optional): the random seed. Default is -1.
+    k(int): the number of top_k scores/ids to be returned. Default is 0.
+        Only used when FD_SAMPLING_CLASS is `air`.
+    mode(str): The sampling strategy to use. If the mode is `truncated`, sampling will truncate the probability at top_p_value.
+        If the mode is `non-truncated`, it will not be truncated. Default is `truncated`.
+        Only used when FD_SAMPLING_CLASS is `air` or `base`.
+    order(str): The order of applying top-k and top-p sampling, should be either `top_k_first` or `joint`.
+        If `top_k_first`, we first apply the top-k filter, then apply top-p sampling on the top-k results.
+        If `joint`, we apply the top-k and top-p filters simultaneously in each round. Default is `top_k_first`.
+        Only used when FD_SAMPLING_CLASS is `rejection`.
+
     """
     top_p_class = envs.FD_SAMPLING_CLASS.lower()
     if top_p_class == "air":
         _, ids = air_top_p_sampling(x,
-                                    ps,
+                                    top_p,
                                     threshold,
                                     topp_seed,
                                     seed=seed,
                                     k=k,
                                     mode=mode)
     elif top_p_class == "rejection":
-        ids = rejection_top_p_sampling(x, ps, seed)
+        ids = rejection_top_p_sampling(x, top_p, top_k, seed, order)
         _ = None
     else:
         if current_platform.is_gcu():
-            _, ids = gcu_top_p_sampling(x, ps)
+            _, ids = gcu_top_p_sampling(x, top_p)
         else:
             _, ids = paddle.tensor.top_p_sampling(x,
-                                                  ps,
+                                                  top_p,
                                                   threshold=threshold,
                                                   topp_seed=topp_seed,
                                                   seed=seed,
@@ -65,7 +87,7 @@ def top_p_sampling(
 
 def air_top_p_sampling(
     x: paddle.Tensor,
-    ps: paddle.Tensor,
+    top_p: paddle.Tensor,
     threshold: Optional[paddle.Tensor] = None,
     topp_seed: Optional[paddle.Tensor] = None,
     seed: int = -1,
@@ -77,7 +99,7 @@ def air_top_p_sampling(
     """
     try:
         from fastdeploy.model_executor.ops.gpu import air_top_p_sampling
-        out, ids = air_top_p_sampling(x, ps, threshold, topp_seed, seed, k,
+        out, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed, k,
                                       mode)
     except ImportError:
         raise RuntimeError("Cannot import air_top_p_sampling op.")
@@ -86,19 +108,46 @@ def air_top_p_sampling(
 
 def rejection_top_p_sampling(
     x: paddle.Tensor,
-    ps: paddle.Tensor,
+    top_p: paddle.Tensor,
+    top_k: Optional[paddle.Tensor] = None,
     seed: int = -1,
+    order: Literal['top_k_first', 'joint'] = "top_k_first",
 ) -> paddle.Tensor:
     """
     rejection_top_p_sampling
     """
+    assert top_p is not None, "Top_p should not be none when FD_SAMPLING_CLASS is rejection"
     try:
-        from fastdeploy.model_executor.ops.gpu import rejection_top_p_sampling
-        ids = rejection_top_p_sampling(
-            x,
-            ps,
-            seed,
-        )
+        from fastdeploy.model_executor.ops.gpu import (
+            rejection_top_p_sampling, top_k_renorm_probs)
+
+        if top_k is None:
+            ids = rejection_top_p_sampling(
+                x,
+                top_p,
+                None,
+                seed,
+            )
+        elif top_k is not None and top_p is not None:
+            if order == "top_k_first":
+                renorm_probs = top_k_renorm_probs(x, top_k)
+                ids = rejection_top_p_sampling(
+                    renorm_probs,
+                    top_p,
+                    None,
+                    seed,
+                )
+            else:
+                ids = rejection_top_p_sampling(
+                    x,
+                    top_p,
+                    top_k,
+                    seed,
+                )
+        else:
+            raise ValueError(
+                "Top_p cannot be none."
+            )
     except ImportError:
         raise RuntimeError("Cannot import rejection_top_p_sampling op.")
     return ids
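The `order="top_k_first"` path above composes the two ops: renormalize the distribution to its top-k support with `top_k_renorm_probs`, then run top-p rejection sampling on the renormalized probabilities. The sketch below mirrors that composition on the CPU for a single row; it illustrates the intended semantics only, not the GPU rejection algorithm itself.

```python
import numpy as np

def top_k_then_top_p_sample_ref(probs, k, p, rng=np.random.default_rng(0)):
    """CPU sketch of order='top_k_first': top-k renorm, then top-p truncation, then sample."""
    probs = np.asarray(probs, dtype=np.float64)
    # 1) top-k renormalization (k == 0 disables it), as in top_k_renorm_probs
    if k > 0:
        pivot = np.sort(probs)[-k]
        probs = np.where(probs >= pivot, probs, 0.0)
        probs = probs / probs.sum()
    # 2) top-p truncation: keep the smallest prefix of sorted tokens whose mass reaches p
    order = np.argsort(-probs)
    cdf = np.cumsum(probs[order])
    keep = order[:np.searchsorted(cdf, p) + 1]
    renormed = np.zeros_like(probs)
    renormed[keep] = probs[keep] / probs[keep].sum()
    # 3) sample from the doubly truncated, renormalized distribution
    return rng.choice(len(probs), p=renormed)
```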
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index 2ee2a8fd1..b5e44af11 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -214,7 +214,7 @@ class Sampler(nn.Layer):
 
         probs = F.softmax(logits)
 
-        _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
+        _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
 
         self.processor.update_output_tokens(next_tokens, skip_idx_list)
         return next_tokens
@@ -367,5 +367,5 @@ class MTPSampler(nn.Layer):
         )
         probs = F.softmax(logits)
 
-        _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
+        _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
         return next_tokens
diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py
index 811b2b691..5955aad1c 100644
--- a/fastdeploy/worker/gcu_model_runner.py
+++ b/fastdeploy/worker/gcu_model_runner.py
@@ -154,12 +154,29 @@ class GCUModelRunner(ModelRunnerBase):
                 -1].disaggregate_info["role"] == "prefill":
             os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
 
+        top_k_reqs = []
+        top_p_reqs = []
+        max_num_seqs = self.parallel_config.max_num_seqs
+        top_p_buffer = paddle.full([max_num_seqs, 1],
+                                   self.model_config.top_p,
+                                   dtype='float32')
+        top_k_buffer = paddle.full([max_num_seqs, 1],
+                                   0,
+                                   dtype='int64')
+
         req_len = len(req_dicts)
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
             length = len(request.prompt_token_ids)
 
+            if sampling_params := request.sampling_params:
+                if sampling_params.top_p < 1:
+                    top_p_reqs.append(idx)
+                top_k = sampling_params.top_k
+                if top_k > 0:
+                    top_k_reqs.append(idx)
+
             prefill_tokens = []
             if (request.guided_json is not None
                     or request.guided_regex is not None
@@ -234,8 +251,8 @@ class GCUModelRunner(ModelRunnerBase):
                 request.eos_token_ids.append(request.eos_token_ids[0])
             self.share_inputs["eos_token_id"][:] = np.array(
                 request.eos_token_ids, dtype="int64").reshape(-1, 1)
-
-            self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
+            top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
+            top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
             self.share_inputs["temperature"][idx:idx + 1] = request.get(
                 "temperature", 0.95)
             self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -286,6 +303,16 @@ class GCUModelRunner(ModelRunnerBase):
         if self.speculative_method in ["mtp"]:
             self.proposer.insert_prefill_inputs(req_dicts)
 
+        if len(top_k_reqs) == 0:
+            self.share_inputs["top_k"] = None
+        else:
+            self.share_inputs["top_k"] = top_k_buffer
+
+        if len(top_p_reqs) == 0:
+            self.share_inputs["top_p"] = None
+        else:
+            self.share_inputs["top_p"] = top_p_buffer
+
     def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
                               expected_decode_len: int):
         """ Set dummy prefill inputs to share_inputs """
@@ -340,8 +367,11 @@ class GCUModelRunner(ModelRunnerBase):
         self.share_inputs["eos_token_id"] = paddle.full(
             [self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
         self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
-                                                 self.model_config.top_p,
-                                                 dtype='float32')
+                                                 self.model_config.top_p,
+                                                 dtype='float32')
+        self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
+                                                 0,
+                                                 dtype='int64')
         self.share_inputs["temperature"] = paddle.full(
             [max_num_seqs, 1], self.model_config.temperature, dtype='float32')
self.share_inputs["penalty_score"] = paddle.full( @@ -563,6 +593,7 @@ class GCUModelRunner(ModelRunnerBase): self.sampling_metadata = SamplingMetadata( temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], + top_k=self.share_inputs["top_k"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"], diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index a96db10ad..b49e0e837 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -161,6 +161,15 @@ class GPUModelRunner(ModelRunnerBase): -1].disaggregate_info["role"] == "prefill": os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1" + top_k_reqs = [] + top_p_reqs = [] + max_num_seqs = self.parallel_config.max_num_seqs + top_p_buffer = paddle.full([max_num_seqs, 1], + self.model_config.top_p, + dtype='float32') + top_k_buffer = paddle.full([max_num_seqs, 1], + 0, + dtype='int64') req_len = len(req_dicts) for i in range(req_len): request = req_dicts[i] @@ -168,6 +177,13 @@ class GPUModelRunner(ModelRunnerBase): length = len(request.prompt_token_ids) assert length > 0, "The prompt requested must not be empty." + if sampling_params := request.sampling_params: + if sampling_params.top_p < 1: + top_p_reqs.append(idx) + top_k = sampling_params.top_k + if top_k > 0: + top_k_reqs.append(idx) + prefill_tokens = [] if (request.guided_json is not None or request.guided_regex is not None @@ -242,8 +258,8 @@ class GPUModelRunner(ModelRunnerBase): request.eos_token_ids.append(request.eos_token_ids[0]) self.share_inputs["eos_token_id"][:] = np.array( request.eos_token_ids, dtype="int64").reshape(-1, 1) - - self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7) + top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0) + top_k_buffer[idx:idx + 1] = request.get("top_k", 0) self.share_inputs["temperature"][idx:idx + 1] = request.get( "temperature", 0.95) self.share_inputs["penalty_score"][idx:idx + 1] = request.get( @@ -294,6 +310,16 @@ class GPUModelRunner(ModelRunnerBase): if self.speculative_method in ["mtp"]: self.proposer.insert_prefill_inputs(req_dicts) + if len(top_k_reqs) == 0: + self.share_inputs["top_k"] = None + else: + self.share_inputs["top_k"] = top_k_buffer + + if len(top_p_reqs) == 0: + self.share_inputs["top_p"] = None + else: + self.share_inputs["top_p"] = top_p_buffer + def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): """ Set dummy prefill inputs to share_inputs """ @@ -349,8 +375,11 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["eos_token_id"] = paddle.full( [self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64') self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], - self.model_config.top_p, - dtype='float32') + self.model_config.top_p, + dtype='float32') + self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int64') self.share_inputs["temperature"] = paddle.full( [max_num_seqs, 1], self.model_config.temperature, dtype='float32') self.share_inputs["penalty_score"] = paddle.full( @@ -574,6 +603,7 @@ class GPUModelRunner(ModelRunnerBase): self.sampling_metadata = SamplingMetadata( temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], + top_k=self.share_inputs["top_k"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"], diff --git 
a/fastdeploy/worker/iluvatar_model_runner.py b/fastdeploy/worker/iluvatar_model_runner.py index 42aadd9b6..77f92676c 100644 --- a/fastdeploy/worker/iluvatar_model_runner.py +++ b/fastdeploy/worker/iluvatar_model_runner.py @@ -29,9 +29,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import \ AttentionBackend from fastdeploy.model_executor.layers.rotary_embedding import get_rope from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata -from fastdeploy.model_executor.layers.sample.sampler import (Sampler, - SpeculativeSampler - ) +from fastdeploy.model_executor.layers.sample.sampler import ( + Sampler, SpeculativeSampler) from fastdeploy.model_executor.model_loader import get_model_from_loader from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx from fastdeploy.model_executor.pre_and_post_process import (post_process, @@ -145,12 +144,29 @@ class IluvatarModelRunner(ModelRunnerBase): -1].disaggregate_info["role"] == "prefill": os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1" + top_k_reqs = [] + top_p_reqs = [] + max_num_seqs = self.parallel_config.max_num_seqs + top_p_buffer = paddle.full([max_num_seqs, 1], + self.model_config.top_p, + dtype='float32') + top_k_buffer = paddle.full([max_num_seqs, 1], + 0, + dtype='int64') + req_len = len(req_dicts) for i in range(req_len): request = req_dicts[i] idx = request.idx length = len(request.prompt_token_ids) + if sampling_params := request.sampling_params: + if sampling_params.top_p < 1: + top_p_reqs.append(idx) + top_k = sampling_params.top_k + if top_k > 0: + top_k_reqs.append(idx) + prefill_tokens = [] if (request.guided_json is not None or request.guided_regex is not None @@ -225,8 +241,8 @@ class IluvatarModelRunner(ModelRunnerBase): request.eos_token_ids.append(request.eos_token_ids[0]) self.share_inputs["eos_token_id"][:] = np.array( request.eos_token_ids, dtype="int64").reshape(-1, 1) - - self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7) + top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0) + top_k_buffer[idx:idx + 1] = request.get("top_k", 0) self.share_inputs["temperature"][idx:idx + 1] = request.get( "temperature", 0.95) self.share_inputs["penalty_score"][idx:idx + 1] = request.get( @@ -273,6 +289,15 @@ class IluvatarModelRunner(ModelRunnerBase): idx, request.get("logits_processor"), prefill_tokens) self.share_inputs["not_need_stop"][0] = True + if len(top_k_reqs) == 0: + self.share_inputs["top_k"] = None + else: + self.share_inputs["top_k"] = top_k_buffer + + if len(top_p_reqs) == 0: + self.share_inputs["top_p"] = None + else: + self.share_inputs["top_p"] = top_p_buffer def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): @@ -329,8 +354,11 @@ class IluvatarModelRunner(ModelRunnerBase): self.share_inputs["eos_token_id"] = paddle.full( [self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64') self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], - self.model_config.top_p, - dtype='float32') + self.model_config.top_p, + dtype='float32') + self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int64') self.share_inputs["temperature"] = paddle.full( [max_num_seqs, 1], self.model_config.temperature, dtype='float32') self.share_inputs["penalty_score"] = paddle.full( @@ -558,6 +586,7 @@ class IluvatarModelRunner(ModelRunnerBase): self.sampling_metadata = SamplingMetadata( temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], + top_k=self.share_inputs["top_k"], 
step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"], diff --git a/fastdeploy/worker/vl_model_runner_base.py b/fastdeploy/worker/vl_model_runner_base.py index d6d8cc4f8..7604053a5 100644 --- a/fastdeploy/worker/vl_model_runner_base.py +++ b/fastdeploy/worker/vl_model_runner_base.py @@ -14,14 +14,14 @@ # limitations under the License. """ -from abc import ABC, abstractmethod import argparse +from abc import ABC, abstractmethod import paddle import paddle.distributed as dist import paddle.distributed.fleet as fleet -from fastdeploy.config import ModelConfig +from fastdeploy.config import ModelConfig from fastdeploy.utils import get_logger logger = get_logger("worker", "worker.log") diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index b82eda700..5b6a4cd6b 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -282,11 +282,26 @@ class XPUModelRunner(ModelRunnerBase): def process_prefill_inputs(self, req_dicts: List[Request]): """ Process inputs for prefill tasks and update share_inputs buffer """ + top_k_reqs = [] + top_p_reqs = [] + max_num_seqs = self.parallel_config.max_num_seqs + top_p_buffer = paddle.full([max_num_seqs, 1], + self.model_config.top_p, + dtype='float32') + top_k_buffer = paddle.full([max_num_seqs, 1], + 0, + dtype='int64') req_len = len(req_dicts) for i in range(req_len): request = req_dicts[i] idx = request.idx length = request.prompt_token_ids_len + if sampling_params := request.sampling_params: + if sampling_params.top_p < 1: + top_p_reqs.append(idx) + top_k = sampling_params.top_k + if top_k > 0: + top_k_reqs.append(idx) self.share_inputs["input_ids"][idx:idx + 1, :length] = np.array( request.prompt_token_ids) if len(request.eos_token_ids @@ -295,7 +310,8 @@ class XPUModelRunner(ModelRunnerBase): self.share_inputs["eos_token_id"][:] = np.array( request.eos_token_ids, dtype="int64").reshape(-1, 1) self.share_inputs["pre_ids"][idx:idx + 1] = -1 - self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7) + top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0) + top_k_buffer[idx:idx + 1] = request.get("top_k", 0) self.share_inputs["temperature"][idx:idx + 1] = request.get( "temperature", 0.95) self.share_inputs["penalty_score"][idx:idx + 1] = request.get( @@ -344,6 +360,15 @@ class XPUModelRunner(ModelRunnerBase): request.get("stop_token_ids"), dtype="int64") self.share_inputs["not_need_stop"][0] = True + if len(top_k_reqs) == 0: + self.share_inputs["top_k"] = None + else: + self.share_inputs["top_k"] = top_k_buffer + + if len(top_p_reqs) == 0: + self.share_inputs["top_p"] = None + else: + self.share_inputs["top_p"] = top_p_buffer def _init_share_inputs(self, max_num_seqs: int): """Initialize all share buffers for model inputs. 
@@ -363,8 +388,11 @@ class XPUModelRunner(ModelRunnerBase): self.share_inputs["eos_token_id"] = paddle.full( [self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64') self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], - self.model_config.top_p, - dtype='float32') + self.model_config.top_p, + dtype='float32') + self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int64') self.share_inputs["temperature"] = paddle.full( [max_num_seqs, 1], self.model_config.temperature, dtype='float32') self.share_inputs["penalty_score"] = paddle.full( @@ -514,6 +542,7 @@ class XPUModelRunner(ModelRunnerBase): self.sampling_metadata = SamplingMetadata( temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], + top_k=self.share_inputs["top_k"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"],
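Across the GPU, GCU, Iluvatar, and XPU model runners above, the convention is the same: per-batch `top_p`/`top_k` buffers default to "disabled" values (1.0 and 0), each request writes its own row, and the whole buffer is dropped to `None` in `share_inputs` when no request in the batch actually set the parameter. The sketch below restates that bookkeeping in isolation; the request dicts are hypothetical and only illustrate the buffer convention, not the runner API.

```python
import paddle

# Hedged sketch of the share_inputs buffer convention used by the model runners above.
max_num_seqs = 4
top_p_buffer = paddle.full([max_num_seqs, 1], 1.0, dtype='float32')  # 1.0 -> top-p disabled
top_k_buffer = paddle.full([max_num_seqs, 1], 0, dtype='int64')      # 0   -> top-k disabled

requests = [{"idx": 0, "top_p": 0.8, "top_k": 20},  # hypothetical prefill requests
            {"idx": 2, "top_p": 1.0, "top_k": 0}]
top_p_reqs, top_k_reqs = [], []
for req in requests:
    idx = req["idx"]
    top_p_buffer[idx:idx + 1] = req["top_p"]
    top_k_buffer[idx:idx + 1] = req["top_k"]
    if req["top_p"] < 1:
        top_p_reqs.append(idx)
    if req["top_k"] > 0:
        top_k_reqs.append(idx)

share_inputs = {
    # None signals that no request in this batch asked for the corresponding truncation
    "top_p": top_p_buffer if top_p_reqs else None,
    "top_k": top_k_buffer if top_k_reqs else None,
}
```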