mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-16 13:41:30 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
@@ -337,6 +337,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
} else if (deal_each_time == 64) { \
|
||||
constexpr size_t DEAL_EACH_TIME = 64; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the deal_each_time", deal_each_time); \
|
||||
}
|
||||
|
||||
#define DISPATCH_NUM_THREADS(num_threads, NUM_THREADS, ...) \
|
||||
@@ -346,6 +348,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
} else if (num_threads == 256) { \
|
||||
constexpr size_t NUM_THREADS = 256; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the num_threads", num_threads); \
|
||||
}
|
||||
|
||||
#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
@@ -376,6 +380,11 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
} else if (group_size == 12) { \
|
||||
constexpr size_t GROUP_SIZE = 12; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 16) { \
|
||||
constexpr size_t GROUP_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the group_size", group_size); \
|
||||
}
|
||||
|
||||
#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \
|
||||
|
||||
Reference in New Issue
Block a user