mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-26 18:10:32 +08:00
[MTP]support more branches in topp kernel (#4353)
This commit is contained in:
@@ -75,12 +75,25 @@ struct BlockPrefixCallbackOp {
|
||||
} break
|
||||
|
||||
#define FIXED_TOPK(...) \
|
||||
FIXED_TOPK_BASE(1, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(2, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(3, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(4, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(5, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(6, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(7, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(8, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(10, ##__VA_ARGS__)
|
||||
FIXED_TOPK_BASE(9, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(10, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(20, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(30, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(40, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(50, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(60, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(70, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(80, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(90, ##__VA_ARGS__); \
|
||||
FIXED_TOPK_BASE(100, ##__VA_ARGS__);
|
||||
|
||||
struct SegmentOffsetIter {
|
||||
explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
|
||||
@@ -502,8 +515,9 @@ void DispatchTopK(const T* src,
|
||||
max_seq_len));
|
||||
default:
|
||||
PD_THROW(
|
||||
"the input data shape has error in the topp_beam_topk "
|
||||
"kernel.");
|
||||
"Invalid max_candidate_len. Please set a value in [1,10] (step=1) "
|
||||
"or [10,100] (step=10)."
|
||||
);
|
||||
});
|
||||
default:
|
||||
PD_THROW("the input topk is not implemented.");
|
||||
|
||||
Reference in New Issue
Block a user