Sync v2.0 version of code to GitHub repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -337,6 +337,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (deal_each_time == 64) { \
constexpr size_t DEAL_EACH_TIME = 64; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the deal_each_time", deal_each_time); \
}
#define DISPATCH_NUM_THREADS(num_threads, NUM_THREADS, ...) \
@@ -346,6 +348,8 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (num_threads == 256) { \
constexpr size_t NUM_THREADS = 256; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the num_threads", num_threads); \
}
#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
@@ -376,6 +380,11 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (group_size == 12) { \
constexpr size_t GROUP_SIZE = 12; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size", group_size); \
}
#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \