From 5035dd82ed9a9d6d46a6ea73ee5bb3ac06421ec2 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Sat, 11 Oct 2025 11:27:35 +0800 Subject: [PATCH] [MTP]support more branches in topp kernel (#4353) --- .../speculate_decoding/top_p_candidates.cu | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu b/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu index a9e66862f..9caf30a32 100644 --- a/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu +++ b/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu @@ -75,12 +75,25 @@ struct BlockPrefixCallbackOp { } break #define FIXED_TOPK(...) \ + FIXED_TOPK_BASE(1, ##__VA_ARGS__); \ FIXED_TOPK_BASE(2, ##__VA_ARGS__); \ FIXED_TOPK_BASE(3, ##__VA_ARGS__); \ FIXED_TOPK_BASE(4, ##__VA_ARGS__); \ FIXED_TOPK_BASE(5, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(6, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(7, ##__VA_ARGS__); \ FIXED_TOPK_BASE(8, ##__VA_ARGS__); \ - FIXED_TOPK_BASE(10, ##__VA_ARGS__) + FIXED_TOPK_BASE(9, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(10, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(20, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(30, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(40, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(50, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(60, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(70, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(80, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(90, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(100, ##__VA_ARGS__); struct SegmentOffsetIter { explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {} @@ -502,8 +515,9 @@ void DispatchTopK(const T* src, max_seq_len)); default: PD_THROW( - "the input data shape has error in the topp_beam_topk " - "kernel."); + "Invalid max_candidate_len. Please set a value in [1,10] (step=1) " + "or [10,100] (step=10)." + ); }); default: PD_THROW("the input topk is not implemented.");