diff --git a/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu b/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu index a9e66862f..9caf30a32 100644 --- a/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu +++ b/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu @@ -75,12 +75,25 @@ struct BlockPrefixCallbackOp { } break #define FIXED_TOPK(...) \ + FIXED_TOPK_BASE(1, ##__VA_ARGS__); \ FIXED_TOPK_BASE(2, ##__VA_ARGS__); \ FIXED_TOPK_BASE(3, ##__VA_ARGS__); \ FIXED_TOPK_BASE(4, ##__VA_ARGS__); \ FIXED_TOPK_BASE(5, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(6, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(7, ##__VA_ARGS__); \ FIXED_TOPK_BASE(8, ##__VA_ARGS__); \ - FIXED_TOPK_BASE(10, ##__VA_ARGS__) + FIXED_TOPK_BASE(9, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(10, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(20, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(30, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(40, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(50, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(60, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(70, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(80, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(90, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(100, ##__VA_ARGS__); struct SegmentOffsetIter { explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {} @@ -502,8 +515,9 @@ void DispatchTopK(const T* src, max_seq_len)); default: PD_THROW( - "the input data shape has error in the topp_beam_topk " - "kernel."); + "Invalid max_candidate_len. Please set a value in [1,10] (step=1) " + "or [10,100] (step=10)." + ); }); default: PD_THROW("the input topk is not implemented.");