Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[Sync] Update to latest code (#2679)
* [Sync] Update to latest code
* Add new code files
* Add new code files
* update code
* Try to fix build.sh
* Try to fix build.sh
* Update code
* Update requirements.txt
* Update code

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
@@ -25,6 +25,7 @@ struct AppendAttnMetaData {
  int kv_num_heads;
  int token_nums;
  int head_dims;
  int head_dims_v;
  int max_blocks_per_seq;
};
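The one-line change in this hunk adds head_dims_v, so the metadata can describe attention variants whose value heads have a different width from the query/key heads (the MLA head dims dispatched below). A minimal sketch of filling the struct for such a case; the field values are illustrative, not taken from this commit:

// Illustrative only: an MLA-style configuration in which the value head
// width (512) differs from the QK head width (576 = 512 latent + 64 rope).
AppendAttnMetaData meta;
meta.kv_num_heads = 1;          // assumed: a single compressed latent KV head
meta.token_nums = 1024;         // illustrative token count
meta.head_dims = 576;           // QK head dim, matches the 576 case below
meta.head_dims_v = 512;         // V head dim, now tracked separately
meta.max_blocks_per_seq = 128;  // illustrative paged-KV block budget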
@@ -309,10 +310,56 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
    } \
  }

#define DISPATCH_GQA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
  switch (head_dim) { \
    case 128: { \
      constexpr size_t HEAD_DIM = 128; \
      __VA_ARGS__ \
      break; \
    } \
    case 192: { \
      constexpr size_t HEAD_DIM = 192; \
      __VA_ARGS__ \
      break; \
    } \
    default: { \
      PD_THROW("not support the head_dim: ", head_dim); \
    } \
  }

#define DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
  switch (head_dim) { \
    case 128: { \
      constexpr size_t HEAD_DIM = 128; \
      __VA_ARGS__ \
      break; \
    } \
    case 192: { \
      constexpr size_t HEAD_DIM = 192; \
      __VA_ARGS__ \
      break; \
    } \
    case 512: { \
      constexpr size_t HEAD_DIM = 512; \
      __VA_ARGS__ \
      break; \
    } \
    case 576: { \
      constexpr size_t HEAD_DIM = 576; \
      __VA_ARGS__ \
      break; \
    } \
    default: { \
      PD_THROW("not support the head_dim: ", head_dim); \
    } \
  }

#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
  if (num_stage == 2) { \
    constexpr size_t NUM_STAGE = 2; \
    __VA_ARGS__ \
  } else { \
    PD_THROW("not support the num_stage: ", num_stage); \
  }
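All of these macros share one pattern: compare a runtime value, rebind it as a constexpr with the caller-chosen name, and splice the caller's code back in through __VA_ARGS__ so it can use the constant as a template argument. A hedged usage sketch; the launcher name and its parameters are hypothetical, not part of this diff:

// Hypothetical launcher: HEAD_DIM must be a compile-time constant here.
template <size_t HEAD_DIM>
void LaunchAppendAttn(const AppendAttnMetaData& meta);

void DispatchAppendAttn(const AppendAttnMetaData& meta) {
  // head_dims is only known at runtime; the macro turns each supported
  // value into a separate compile-time instantiation.
  DISPATCH_MLA_HEAD_DIM(meta.head_dims, HEAD_DIM, {
    LaunchAppendAttn<HEAD_DIM>(meta);
  })
}

Because the block is forwarded as __VA_ARGS__, any top-level commas inside it (multiple template arguments, launch parameters) are re-joined on expansion rather than splitting the macro call.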
#define DISPATCH_CACHE_TYPE(cache_type, cache_type_now, cache_bytes, ...) \

@@ -328,10 +375,13 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
    constexpr CacheType cache_type_now = CacheType::CacheInt4CwZp; \
    constexpr size_t cache_bytes = 4; \
    __VA_ARGS__ \
  } else { \
    PD_THROW("not support the cache_type: ", cache_type); \
  }
#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \
  if (deal_each_time == 32) { \
    constexpr size_t DEAL_EACH_TIME = 32; \
    __VA_ARGS__ \
  } else if (deal_each_time == 64) { \
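These dispatchers compose by nesting, each level binding one more compile-time constant; note that DISPATCH_CACHE_TYPE binds two (the CacheType enum value and its byte width) from a single runtime tag. A sketch under assumptions: the kernel, the launch shape, and the type of the runtime cache_type tag are all hypothetical here, and the dispatch macros are assumed to be in scope from this header.

// Hypothetical kernel templated on the constants the macros produce.
template <CacheType CACHE_TYPE, size_t CACHE_BYTES, size_t DEAL_EACH_TIME>
__global__ void AppendAttnKernel();

void Launch(int cache_type,  // assumed tag type; the macro defines the comparison
            int deal_each_time, cudaStream_t stream) {
  dim3 grid(1), block(128);  // illustrative launch shape
  DISPATCH_CACHE_TYPE(cache_type, CACHE_TYPE, CACHE_BYTES, {
    DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, {
      AppendAttnKernel<CACHE_TYPE, CACHE_BYTES, DEAL_EACH_TIME>
          <<<grid, block, 0, stream>>>();
    })
  })
}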
@@ -387,6 +437,20 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
    PD_THROW("not support the group_size", group_size); \
  }

#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
  if (group_size == 8) { \
    constexpr size_t GROUP_SIZE = 8; \
    __VA_ARGS__ \
  } else if (group_size == 16) { \
    constexpr size_t GROUP_SIZE = 16; \
    __VA_ARGS__ \
  } else if (group_size == 128) { \
    constexpr size_t GROUP_SIZE = 128; \
    __VA_ARGS__ \
  } else { \
    PD_THROW("not support the group_size: ", group_size); \
  }

#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \
  if (block_shape_q <= 16) { \
    constexpr size_t BLOCK_SHAPE_Q = 16; \
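The view is cut off at this point, but the final hunk follows the same scheme: DISPATCH_MLA_GROUP_SIZE covers the MLA ratios of query heads per KV head (8, 16, 128), and DISPATCH_BLOCKSHAPE_Q, judging by its signature, binds NUM_WARP_Q alongside BLOCK_SHAPE_Q from one runtime value. A hedged sketch of composing the group-size and head-dim dispatches; MLADecodeLauncher is hypothetical:

// Hypothetical MLA decode entry point: both constants become template args.
template <size_t GROUP_SIZE, size_t HEAD_DIM>
void MLADecodeLauncher(const AppendAttnMetaData& meta);

void DispatchMLADecode(const AppendAttnMetaData& meta, int group_size) {
  DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, {
    DISPATCH_MLA_HEAD_DIM(meta.head_dims, HEAD_DIM, {
      MLADecodeLauncher<GROUP_SIZE, HEAD_DIM>(meta);
    })
  })
}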