[MTP]support more branchs in topp kernel (#4353)

This commit is contained in:
freeliuzc
2025-10-11 11:27:35 +08:00
committed by GitHub
parent 28aa18bfc1
commit 5035dd82ed

View File

@@ -75,12 +75,25 @@ struct BlockPrefixCallbackOp {
} break } break
#define FIXED_TOPK(...) \ #define FIXED_TOPK(...) \
FIXED_TOPK_BASE(1, ##__VA_ARGS__); \
FIXED_TOPK_BASE(2, ##__VA_ARGS__); \ FIXED_TOPK_BASE(2, ##__VA_ARGS__); \
FIXED_TOPK_BASE(3, ##__VA_ARGS__); \ FIXED_TOPK_BASE(3, ##__VA_ARGS__); \
FIXED_TOPK_BASE(4, ##__VA_ARGS__); \ FIXED_TOPK_BASE(4, ##__VA_ARGS__); \
FIXED_TOPK_BASE(5, ##__VA_ARGS__); \ FIXED_TOPK_BASE(5, ##__VA_ARGS__); \
FIXED_TOPK_BASE(6, ##__VA_ARGS__); \
FIXED_TOPK_BASE(7, ##__VA_ARGS__); \
FIXED_TOPK_BASE(8, ##__VA_ARGS__); \ FIXED_TOPK_BASE(8, ##__VA_ARGS__); \
FIXED_TOPK_BASE(10, ##__VA_ARGS__) FIXED_TOPK_BASE(9, ##__VA_ARGS__); \
FIXED_TOPK_BASE(10, ##__VA_ARGS__); \
FIXED_TOPK_BASE(20, ##__VA_ARGS__); \
FIXED_TOPK_BASE(30, ##__VA_ARGS__); \
FIXED_TOPK_BASE(40, ##__VA_ARGS__); \
FIXED_TOPK_BASE(50, ##__VA_ARGS__); \
FIXED_TOPK_BASE(60, ##__VA_ARGS__); \
FIXED_TOPK_BASE(70, ##__VA_ARGS__); \
FIXED_TOPK_BASE(80, ##__VA_ARGS__); \
FIXED_TOPK_BASE(90, ##__VA_ARGS__); \
FIXED_TOPK_BASE(100, ##__VA_ARGS__);
struct SegmentOffsetIter { struct SegmentOffsetIter {
explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {} explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
@@ -502,8 +515,9 @@ void DispatchTopK(const T* src,
max_seq_len)); max_seq_len));
default: default:
PD_THROW( PD_THROW(
"the input data shape has error in the topp_beam_topk " "Invalid max_candidate_len. Please set a value in [1,10] (step=1) "
"kernel."); "or [10,100] (step=10)."
);
}); });
default: default:
PD_THROW("the input topk is not implemented."); PD_THROW("the input topk is not implemented.");