[MTP]support more branchs in topp kernel (#4353)

This commit is contained in:
freeliuzc
2025-10-11 11:27:35 +08:00
committed by GitHub
parent 28aa18bfc1
commit 5035dd82ed

View File

@@ -75,12 +75,25 @@ struct BlockPrefixCallbackOp {
} break
#define FIXED_TOPK(...) \
FIXED_TOPK_BASE(1, ##__VA_ARGS__); \
FIXED_TOPK_BASE(2, ##__VA_ARGS__); \
FIXED_TOPK_BASE(3, ##__VA_ARGS__); \
FIXED_TOPK_BASE(4, ##__VA_ARGS__); \
FIXED_TOPK_BASE(5, ##__VA_ARGS__); \
FIXED_TOPK_BASE(6, ##__VA_ARGS__); \
FIXED_TOPK_BASE(7, ##__VA_ARGS__); \
FIXED_TOPK_BASE(8, ##__VA_ARGS__); \
FIXED_TOPK_BASE(10, ##__VA_ARGS__)
FIXED_TOPK_BASE(9, ##__VA_ARGS__); \
FIXED_TOPK_BASE(10, ##__VA_ARGS__); \
FIXED_TOPK_BASE(20, ##__VA_ARGS__); \
FIXED_TOPK_BASE(30, ##__VA_ARGS__); \
FIXED_TOPK_BASE(40, ##__VA_ARGS__); \
FIXED_TOPK_BASE(50, ##__VA_ARGS__); \
FIXED_TOPK_BASE(60, ##__VA_ARGS__); \
FIXED_TOPK_BASE(70, ##__VA_ARGS__); \
FIXED_TOPK_BASE(80, ##__VA_ARGS__); \
FIXED_TOPK_BASE(90, ##__VA_ARGS__); \
FIXED_TOPK_BASE(100, ##__VA_ARGS__);
struct SegmentOffsetIter {
explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
@@ -502,8 +515,9 @@ void DispatchTopK(const T* src,
max_seq_len));
default:
PD_THROW(
"the input data shape has error in the topp_beam_topk "
"kernel.");
"Invalid max_candidate_len. Please set a value in [1,10] (step=1) "
"or [10,100] (step=10)."
);
});
default:
PD_THROW("the input topk is not implemented.");