Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
format code (#4720)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
@@ -142,7 +142,6 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps(
const uint32_t tx_offset = tx / 8;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {

const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset;
#pragma unroll
const int j = ty;
@@ -151,8 +150,7 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps(
const uint32_t h_offset = offset_now % group_size;
T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
q_smem_offset_w, q_ptr, n_offset < qo_upper_bound);
q_smem_offset_w =
@@ -171,7 +169,7 @@ template <uint32_t group_size,
uint32_t HEAD_DIM,
typename T>
__device__ __forceinline__ void load_q_global_smem(
T* q_ptr_base,
const T* q_ptr_base,
smem_t* q_smem,
uint32_t q_idx_base,
const uint32_t qo_upper_bound,
@@ -194,10 +192,10 @@ __device__ __forceinline__ void load_q_global_smem(
const uint32_t offset_now = base_offset + j * 4;
const uint32_t n_offset = offset_now / group_size;
const uint32_t h_offset = offset_now % group_size;
T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
const T* q_ptr =
q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
q_smem_offset_w, q_ptr, n_offset < qo_upper_bound);
q_smem_offset_w =
@@ -223,8 +221,7 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps(
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();

#pragma unroll
for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024;
++i) {
for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; ++i) {
const int offset = i * 1024 + ty * 256 + tx * 8;
Load<T, vec_size>(reinterpret_cast<T*>(q_smem->base) + offset, &tmp_vec);
#pragma unroll
@@ -289,11 +286,9 @@ __device__ __forceinline__ void produce_kv_blockwise(
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check
#pragma unroll
for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps;
++i) {
for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; ++i) {
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4;
++j) {
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
*gptr += 8 * num_elems_per_128b<T>();
@@ -332,9 +327,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
block_size / num_elems_per_128b<CacheT>(); // 8
constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t kv_idx =
kv_idx_base +
tx % 4 * num_elems_per_128b<CacheT>();
uint32_t kv_idx = kv_idx_base + tx % 4 * num_elems_per_128b<CacheT>();
if constexpr (NUM_WARP_Q == 4) {
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
@@ -343,8 +336,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
for (uint32_t i = 0; i < num_frags_y * 2 / num_warps;
++i) { // m (num_frags_y * 16 / (num_warps * 8))
#pragma unroll
for (uint32_t j = 0; j < num_frags_z / 4;
++j) {
for (uint32_t j = 0; j < num_frags_z / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
*smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>(
*smem_offset, j);
@@ -369,8 +361,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
for (uint32_t i = 0; i < num_frags_y * 2 / num_warps;
++i) { // m (num_frags_y * 16 / (num_warps * 8))
#pragma unroll
for (uint32_t j = 0; j < 2 * num_frags_z / 4;
++j) {
for (uint32_t j = 0; j < 2 * num_frags_z / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
*smem_offset =
smem.advance_offset_by_column<4, num_vecs_per_blocksize>(
@@ -392,27 +383,28 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
}
}
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
template <uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale(
T* k_smem_scale,
T* cache_k_reg,
const int* block_table_now,
const T* cache_k_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
T* k_smem_scale,
T* cache_k_reg,
const int* block_table_now,
const T* cache_k_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const T* cache_k_scale_now = cache_k_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
if (tid < block_size) {
k_smem_scale[tid] = cache_k_scale_now[tid];
}
@@ -427,10 +419,12 @@ __device__ __forceinline__ void produce_k_dynamic_scale(
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const T* cache_k_scale_now = cache_k_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
} else {
k_smem_scale[ty * 32 + tx] = 0;
}
@@ -443,20 +437,19 @@ __device__ __forceinline__ void produce_k_dynamic_scale(
}
}

template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
template <uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale(
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;

if constexpr (NUM_WARP_Q == 4) {
@@ -464,7 +457,9 @@ __device__ __forceinline__ void produce_v_dynamic_scale(
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const T* cache_v_scale_now = cache_v_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
if (tid < block_size) {
v_smem_scale[tid] = cache_v_scale_now[tid];
}
@@ -481,10 +476,12 @@ __device__ __forceinline__ void produce_v_dynamic_scale(
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const T* cache_v_scale_now = cache_v_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
} else {
v_smem_scale[ty * 32 + tx] = 0;
}
@@ -560,8 +557,7 @@ __device__ __forceinline__ void produce_k_blockwise_c8(
for (uint32_t i = 0; i < 2 * num_frags_z * 4 / num_warps;
++i) { // m num_frags_z * 16 / (num_warps * 4)
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 8;
++j) {
for (uint32_t j = 0; j < num_frags_y / 8; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_k_now, true);
*smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>(
*smem_offset, j);
@@ -614,8 +610,7 @@ __device__ __forceinline__ void produce_v_blockwise_c4(
#pragma unroll
for (uint32_t i = 0; i < num_frags_y / num_warps; ++i) { // m
#pragma unroll
for (uint32_t j = 0; j < num_frags_z / 4;
++j) {
for (uint32_t j = 0; j < num_frags_z / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
*smem_offset = smem.advance_offset_by_column<2, num_vecs_per_blocksize>(
*smem_offset, j);
@@ -671,8 +666,7 @@ __device__ __forceinline__ void produce_k_blockwise_c4(
for (uint32_t i = 0; i < num_frags_z * 2 / num_warps;
++i) { // m num_frags_z * 16 / (num_warps * 8)
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 8;
++j) {
for (uint32_t j = 0; j < num_frags_y / 8; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_k_now, true);
*smem_offset = smem.advance_offset_by_column<4, num_vecs_per_head>(
*smem_offset, j);
@@ -937,7 +931,7 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
uint32_t* q_smem_offset_r,
smem_t* k_smem,
uint32_t* k_smem_offset_r,
const T *cache_k_scale,
const T* cache_k_scale,
float (*s_frag)[num_frags_z][8]) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b<T>();
@@ -973,8 +967,8 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
#pragma unroll
for (uint32_t fy = 0; fy < 2; ++fy) {
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
convert_c8<T, IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T, IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
@@ -1036,7 +1030,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t chunk_end,
const uint32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr,
const int* mask_offset = nullptr,
const int sliding_window = 0) {
const uint32_t tx = threadIdx.x;
#pragma unroll
@@ -1053,24 +1047,25 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
8 * (reg_id / 4) + reg_id % 2;
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
}
else if (sliding_window > 0)
{
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - (int)qo_len - sliding_window;
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || out_of_window || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
}
else
{
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
out_of_boundary = q_idx < qo_len
? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
kv_idx < mask_offset[q_idx * 2])
: true;
} else if (sliding_window > 0) {
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx -
(int)qo_len -
sliding_window;
out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len ||
out_of_window || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
} else {
out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len ||
(kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len &&
kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx =
q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
out_of_boundary |= mask;
}
@@ -1236,7 +1231,7 @@ __device__ __forceinline__ void compute_sfm_v_c8(
float (*s_frag)[num_frags_z][8],
float (*o_frag)[num_frags_y][8],
float (*d)[2],
const T *cache_v_scale) {
const T* cache_v_scale) {
constexpr uint32_t num_vecs_per_blocksize =
block_size / num_elems_per_128b<CacheT>();
T s_frag_f16[num_frags_x][num_frags_z][8];
@@ -1268,8 +1263,8 @@ __device__ __forceinline__ void compute_sfm_v_c8(
#pragma unroll
for (uint32_t fz = 0; fz < 2; ++fz) {
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
convert_c8<T, IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T, IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
@@ -1300,7 +1295,6 @@ __device__ __forceinline__ void compute_sfm_v_c8(
o_frag[fx][fy],
(uint32_t*)(s_frag_f16[fx][kz * 2 + fz]),
b_frag_dq);

}
}
}
@@ -1328,7 +1322,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
float (*s_frag)[num_frags_z][8],
float (*o_frag)[num_frags_y][8],
float (*d)[2],
T *cache_v_scale) {
T* cache_v_scale) {
constexpr uint32_t num_vecs_per_blocksize =
block_size / num_elems_per_128b<CacheT>();

@@ -1362,8 +1356,8 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
for (uint32_t fz = 0; fz < 2; ++fz) {
// dequant b_frag -> b_frag_dq
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
convert_c8<T, IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T, IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
@@ -1372,7 +1366,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
@@ -1431,8 +1425,7 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem,
}

#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z;
++fz) {
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t b_frag[4];
@@ -1611,10 +1604,9 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
fx * 16 + tx / 4,
fy * 2);
uint32_t o_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head>(fx * 16 + tx / 4,
fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
@@ -1627,8 +1619,8 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
}
__syncthreads();

uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * 4 + tx / 8, tx % 8);
uint32_t o_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head>(ty * 4 + tx / 8, tx % 8);

o_idx_base += (tx / 8) / group_size;
o_ptr_base += ((tx / 8) / group_size) * qo_n_stride +
@@ -1642,8 +1634,7 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride +
((fx * 16 + j * 4) % group_size) * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
if (o_idx < qo_upper_bound) {
// need write
o_smem->store_128b(o_smem_offset_w, o_ptr);
@@ -1658,7 +1649,6 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
}
}
template <typename T, int VEC_SIZE, typename OutT>
struct StoreFunc {
__device__ __forceinline__ void operator()(
@@ -1717,7 +1707,6 @@ struct StoreFunc<T, VEC_SIZE, __nv_fp8_e4m3> {
}
};

template <typename T, int VEC_SIZE>
struct StoreFunc<T, VEC_SIZE, T> {
__device__ __forceinline__ void operator()(
@@ -1770,10 +1759,9 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
fx * 16 + tx / 4,
fy * 2);
uint32_t o_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head>(fx * 16 + tx / 4,
fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
@@ -1786,8 +1774,8 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
}
__syncthreads();

uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * 4 + tx / 8, tx % 8);
uint32_t o_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head>(ty * 4 + tx / 8, tx % 8);

const uint32_t tx_offset = tx / 8;
#pragma unroll
@@ -1804,8 +1792,7 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim +
tx % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
if (n_offset < qo_upper_bound) {
if constexpr (!partition_kv) {
Load<T, VEC_SIZE>(
@@ -1881,10 +1868,8 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4,
fy * 2);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4, fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
@@ -1897,8 +1882,7 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
__syncthreads();

uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * num_frags_x * 16 + tx / 8,
tx % 8);
ty * num_frags_x * 16 + tx / 8, tx % 8);

const uint32_t tx_offset = tx / 8;
#pragma unroll
@@ -1914,13 +1898,12 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim +
tx % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
if (n_offset < qo_upper_bound) {
if (!partition_kv) {
Load<T, VEC_SIZE>(
reinterpret_cast<T*>(o_smem->base + o_smem_offset_w),
&ori_out_vec);
reinterpret_cast<T*>(o_smem->base + o_smem_offset_w),
&ori_out_vec);
if (in_scale > 0.0) {
if (shift_bias) {
Load<T, VEC_SIZE>(shift_bias + shift_smooth_offset,
@@ -1929,16 +1912,16 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
&smooth_weight_vec);
}
}
#pragma unroll
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
StoreFunc<T, VEC_SIZE, OutT>()(ori_out_vec,
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
}
Store<OutT, VEC_SIZE>(out_vec, o_ptr);
} else {
@@ -1979,10 +1962,8 @@ __device__ __forceinline__ void write_o_reg_gmem(
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4,
fy * 2);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4, fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
@@ -1995,8 +1976,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
__syncthreads();

uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * num_frags_x * 16 + tx / 8,
tx % 8);
ty * num_frags_x * 16 + tx / 8, tx % 8);

o_idx_base += (tx / 8) / group_size;
o_ptr_base += ((tx / 8) / group_size) * qo_n_stride +
@@ -2009,8 +1989,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride +
((fx * 16 + j * 4) % group_size) * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
if (o_idx < qo_upper_bound) {
o_smem->store_128b(o_smem_offset_w, o_ptr);
}
@@ -2125,7 +2104,6 @@ __global__ void merge_multi_chunks_kernel(
&out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
}

template <uint32_t num_frags_x, uint32_t num_frags_y, typename T>
__device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8],
float* md_smem,
@@ -2307,18 +2285,18 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true>
__global__ void merge_multi_chunks_decoder_kernel(
const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads,
const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads,
// head_dim]
const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
OutT *__restrict__ out,
const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int* __restrict__ seq_lens_q,
const int* __restrict__ seq_lens_kv,
const int* __restrict__ seq_lens_encoder,
const int* __restrict__ cu_seqlens_q,
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T* __restrict__ sinks, // [q_num_heads]
OutT* __restrict__ out,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -2419,8 +2397,14 @@ __global__ void merge_multi_chunks_decoder_kernel(
}
#pragma unroll
for (int i = 0; i < vec_size; ++i) {
StoreFunc<T, vec_size, OutT>()(
st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i);
StoreFunc<T, vec_size, OutT>()(st.o,
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
}
Store<OutT, vec_size>(
out_vec,
@@ -2435,19 +2419,19 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true>
__global__ void merge_multi_chunks_v2_kernel(
const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads,
const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads,
// head_dim]
const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ batch_id_per_token,
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
OutT *__restrict__ out,
const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int* __restrict__ seq_lens_q,
const int* __restrict__ seq_lens_kv,
const int* __restrict__ seq_lens_encoder,
const int* __restrict__ batch_id_per_token,
const int* __restrict__ cu_seqlens_q,
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T* __restrict__ sinks, // [q_num_heads]
OutT* __restrict__ out,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -2464,7 +2448,7 @@ __global__ void merge_multi_chunks_v2_kernel(
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
if(bid == -1){
if (bid == -1) {
continue;
}
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
@@ -2486,7 +2470,7 @@ __global__ void merge_multi_chunks_v2_kernel(
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
continue;
}else if (!ENABLE_PREFILL){
} else if (!ENABLE_PREFILL) {
continue;
}

@@ -2496,12 +2480,12 @@ __global__ void merge_multi_chunks_v2_kernel(
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2 *)(&res_vec) + i) = make_half2(0, 0);
*((half2*)(&res_vec) + i) = make_half2(0, 0);
}
} else {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0);
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
}
}
float m;
@@ -2581,10 +2565,17 @@ __global__ void merge_multi_chunks_v2_kernel(
Load<T, vec_size>(smooth_weight + shift_smooth_offset,
&smooth_weight_vec);
}

#pragma unroll
for (int i = 0; i < vec_size; ++i) {
StoreFunc<T, vec_size, OutT>()(
st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i);
StoreFunc<T, vec_size, OutT>()(st.o,
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
}
Store<OutT, vec_size>(
out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
@@ -20,12 +20,11 @@
#include "utils.cuh"

template <int THREADBLOCK_SIZE>
__global__ void
GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
const int *seq_lens_encoder,
const int *seq_lens_this_time_merged,
const int *seq_lens_encoder_merged, const int *seq_mapping,
const int *system_lens, int *max_lens, const int batch_size) {
__global__ void GetMaxLenKernel(const int *seq_lens_decoder,
const int *seq_lens_this_time,
const int *seq_lens_encoder,
int *max_lens,
const int batch_size) {
const int tid = threadIdx.x;

typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
@@ -36,9 +35,6 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
int max_len_decoder_this_thread = 0;
int max_len_this_thread = 0;
int max_just_dec_len_this_thread = 0;
int max_just_dec_merged_len_this_time_this_thread = 0;
int max_system_len_this_thread = 0;
int max_dec_len_without_system_this_thread = 0;
int max_len_kv_this_thread = 0;
for (int i = tid; i < batch_size; i += blockDim.x) {
const int seq_len_this_time = seq_lens_this_time[i];
@@ -47,17 +43,17 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
max(seq_len_this_time, max_len_this_time_this_thread);
max_len_encoder_this_thread =
max(seq_lens_encoder[i], max_len_encoder_this_thread);
max_len_decoder_this_thread = max(seq_len_decoder, max_len_decoder_this_thread);
if (seq_len_this_time <= 0)
continue;
const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder;
max_len_decoder_this_thread =
max(seq_len_decoder, max_len_decoder_this_thread);
if (seq_len_this_time <= 0) continue;
const int max_just_dec_len_now =
seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder;
max_len_this_thread =
max(seq_len_decoder + seq_len_this_time, max_len_this_thread);
max_just_dec_len_this_thread =
max(max_just_dec_len_this_thread, max_just_dec_len_now);

if (seq_len_decoder == 0)
continue;
if (seq_len_decoder == 0) continue;
max_len_kv_this_thread =
max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread);
}
@@ -74,14 +70,6 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
int total_just_dec = BlockReduce(temp_storage)
.Reduce(max_just_dec_len_this_thread, MaxOp<int>());
int total_just_dec_merged =
BlockReduce(temp_storage)
.Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp<int>());
int total_system_len = BlockReduce(temp_storage)
.Reduce(max_system_len_this_thread, MaxOp<int>());
int total_dec_len_without_system =
BlockReduce(temp_storage)
.Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
int total_max_len_kv =
BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp<int>());
if (tid == 0) {
@@ -90,24 +78,10 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
max_lens[2] = total_max_len_decoder;
max_lens[3] = total;
max_lens[4] = total_just_dec;
max_lens[5] = total_just_dec_merged;
max_lens[6] = total_system_len;
max_lens[7] = total_dec_len_without_system;
max_lens[8] = total_max_len_kv;
}
}

void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
paddle::Tensor &max_len_tensor, const int batch_size) {
constexpr int blockSize = 1024;
GetMaxLenKernel<blockSize><<<1, blockSize, 0, seq_lens_encoder.stream()>>>(
seq_lens_tensor.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(), nullptr, nullptr, nullptr, nullptr,
max_len_tensor.data<int>(), batch_size);
}

template <uint32_t config_size>
__global__ void search_chunk_size_for_mla(
const int *__restrict__ seq_lens_q,
@@ -154,11 +128,11 @@ __global__ void search_chunk_size_for_mla(
uint32_t res_id = 0;
uint32_t max_last_wave_block = 0;
for (uint32_t i = 1; i < config_size; i++) {
uint32_t last_wave_block = gridx_shared[i] % sm_cout;
if (last_wave_block >= max_last_wave_block) {
res_id = i;
max_last_wave_block = last_wave_block;
}
uint32_t last_wave_block = gridx_shared[i] % sm_cout;
if (last_wave_block >= max_last_wave_block) {
res_id = i;
max_last_wave_block = last_wave_block;
}
}
*num_blocks_x = gridx_shared[res_id];
*res_chunk_size = block_size << res_id;
@@ -185,11 +159,11 @@ __global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
if (seq_len_encoder > 0) {
loop_times = 0;
loop_times = 0;
}
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
}
}
}
@@ -255,8 +229,10 @@ __global__ void split_kv_block(const int *__restrict__ seq_lens_decoder,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
int *__restrict__ num_blocks_x, const int bsz,
const int pad_len, const int num_row_per_block) {
int *__restrict__ num_blocks_x,
const int bsz,
const int pad_len,
const int num_row_per_block) {
if (threadIdx.x == 0) {
int gridx = 0;
int index = 0;
@@ -281,31 +257,39 @@ void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_device, // Inplace
paddle::Tensor &decoder_chunk_size_device, // Inplace
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_device, // Inplace
paddle::Tensor &decoder_chunk_size_device, // Inplace
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num)
{
const int decoder_step_token_num) {
auto stream = seq_lens_encoder.stream();
int bsz = seq_lens_this_time.shape()[0];

paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place());
GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
max_len_tensor_gpu, bsz);
max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
paddle::Tensor max_len_tensor_gpu =
GetEmptyTensor({max_len_tensor_cpu.shape()[0]},
paddle::DataType::INT32,
seq_lens_this_time.place());

GetMaxLenKernel<1024><<<1, 1024, 0, stream>>>(seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
max_len_tensor_gpu.data<int>(),
bsz);

max_len_tensor_cpu.copy_(
max_len_tensor_gpu, max_len_tensor_cpu.place(), false);

auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
int max_len_this_time = max_len_cpu_ptr[0];
@@ -320,7 +304,6 @@ void GetBlockShapeAndSplitKVBlock(

// decoder
if (max_dec_len_this_time > 0) {

const bool mla_backend = checkAttentionBackend();
if (mla_backend && group_size <= 64) {
const int set_chunk_size = get_mla_dec_chunk_size(bsz);
@@ -356,8 +339,9 @@ void GetBlockShapeAndSplitKVBlock(
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];

// NOTE: (changwenbin) When using auto_chunk,
// decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K.
// const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024;
// decode_max_tile_size must take into account the maximum case, where *
// 1024 can cover 128K. const uint32_t decoder_batch_shape =
// seq_lens_decoder.dims()[0] * 1024;

const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
@@ -375,7 +359,6 @@ void GetBlockShapeAndSplitKVBlock(
decoder_batch_shape * sizeof(int32_t),
stream));

split_block_for_mla<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
@@ -419,49 +402,72 @@ void GetBlockShapeAndSplitKVBlock(
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
}

// encoder
if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_dec_len_this_time, block_size);
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(),
0,
kv_batch_shape * sizeof(int32_t),
stream));
auto kv_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());

split_kv_block<<<1, 32, 0, seq_lens_encoder.stream()>>>(
seq_lens_decoder.data<int>(),
// sequence_lengths->data<int>(),
seq_lens_encoder.data<int>(), kv_batch_ids.data<int>(),
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
block_size, block_size);
seq_lens_encoder.data<int>(),
kv_batch_ids.data<int>(),
kv_tile_ids_per_batch.data<int>(),
kv_num_blocks_x.data<int>(),
bsz,
block_size,
block_size);

kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
kv_num_blocks_x_cpu.copy_(
kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
// Clear buffer
const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
const uint32_t encoder_max_tile_size_per_bs_q =
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(encoder_batch_ids.data<int>(),
0,
encoder_batch_shape * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(),
0,
encoder_batch_shape * sizeof(int32_t),
stream));
auto encoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(),
nullptr,
encoder_batch_ids.data<int>(),
encoder_tile_ids_per_batch.data<int>(),
encoder_num_blocks_x.data<int>(), bsz,
encoder_block_shape_q, group_size);
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
encoder_num_blocks_x.data<int>(),
bsz,
encoder_block_shape_q,
group_size);
encoder_num_blocks_x_cpu.copy_(
encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
}

}
std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
@@ -472,8 +478,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num
) {
const int decoder_step_token_num) {
return {};
}

@@ -485,39 +490,36 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num
) {
const int decoder_step_token_num) {
return {};
}

PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
.Inputs({
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_cpu",
"decoder_num_blocks_device",
"decoder_chunk_size_device",
"max_len_tensor_cpu",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks_x_cpu",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks_x_cpu",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_cpu",
"decoder_num_blocks_device",
"decoder_chunk_size_device",
"max_len_tensor_cpu",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks_x_cpu",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks_x_cpu",
})
.Outputs({

})
.Attrs({
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"group_size: int",
"block_size: int",
"decoder_step_token_num: int"
})
.Attrs({"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"group_size: int",
"block_size: int",
"decoder_step_token_num: int"})
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));
@@ -21,7 +21,6 @@ template <typename T,
|
||||
bool CAUSAL,
|
||||
uint32_t NUM_WARPS,
|
||||
uint32_t NUM_WARP_Q,
|
||||
uint32_t NUM_WARP_KV,
|
||||
uint32_t HEAD_DIM,
|
||||
uint32_t BLOCK_SIZE,
|
||||
uint32_t num_frags_x,
|
||||
@@ -30,13 +29,14 @@ template <typename T,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true>
|
||||
__global__ void multi_query_append_attention_kernel(
|
||||
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
|
||||
T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
|
||||
const T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) *
|
||||
// head_dim]
|
||||
T *__restrict__ cache_v,
|
||||
const T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
|
||||
// head_dim]
|
||||
const T *__restrict__ cache_v,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
const int *__restrict__ seq_lens,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
@@ -45,7 +45,6 @@ __global__ void multi_query_append_attention_kernel(
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
const float scale,
|
||||
const float quant_max_bound,
|
||||
@@ -72,12 +71,10 @@ __global__ void multi_query_append_attention_kernel(
|
||||
const uint32_t batch_id = batch_ids[btid];
|
||||
const uint32_t tile_id = tile_ids_per_batch[btid];
|
||||
const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16;
|
||||
const int *block_table_now = nullptr;
|
||||
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
// When cudagraph capture prefill, may launch more gridDim.x
|
||||
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -85,8 +82,7 @@ __global__ void multi_query_append_attention_kernel(
if (q_len <= 0) {
return;
}
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));

uint32_t kv_len = seq_lens_kv[batch_id];
if (ENABLE_PREFILL) {
kv_len += q_len;
@@ -111,6 +107,9 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t chunk_len = chunk_end - chunk_start;

extern __shared__ uint8_t smem[];
static_assert(num_frags_y * 16 == HEAD_DIM);
static_assert(num_frags_z * 16 == BLOCK_SIZE);

float s_frag[num_frags_x][num_frags_z][8];
float o_frag[num_frags_x][num_frags_y][8];
float m_frag[num_frags_x][2];
@@ -131,7 +130,7 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t o_offset = q_start_seq_id * q_n_stride +
q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
T *q_base_ptr = q + q_offset;
const T *q_base_ptr = q + q_offset;
T *o_base_ptr_T = nullptr;
OutT *o_base_ptr_int8 = nullptr;
if constexpr (partition_kv) {
@@ -149,11 +148,16 @@ __global__ void multi_query_append_attention_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq =
mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);

uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16

const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));

load_q_global_smem<GROUP_SIZE, num_frags_x, num_frags_y, HEAD_DIM, T>(
q_base_ptr,
&qo_smem,
@@ -172,7 +176,6 @@ __global__ void multi_query_append_attention_kernel(
v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM *
sizeof(T));


const uint32_t num_iterations = div_up(
CAUSAL
? (min(chunk_len,
@@ -183,12 +186,13 @@ __global__ void multi_query_append_attention_kernel(
: chunk_len,
num_frags_z * 16);
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: mask_offset ? 0
: chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
8 * (tid / 16) + tid % 8, (tid % 16) / 8);
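The mask_check_iteration expression reflowed above counts how many 16-row KV tiles (num_frags_z * 16) sit entirely below the causal boundary and can therefore skip the per-element mask test; when an explicit mask_offset table is supplied without causal masking, every tile is checked, so the count is 0. A rough host-side sketch of that arithmetic, with sub_if_greater_or_zero re-declared locally and causal_boundary standing in for kv_len - q_len (plus the tile offset used by this kernel); both names are assumptions for illustration only:

#include <algorithm>
#include <cstdint>

// Assumed behaviour of the kernel's saturating subtraction helper.
static uint32_t sub_if_greater_or_zero(uint32_t a, uint32_t b) { return a > b ? a - b : 0; }

// Number of KV tiles in this chunk that never cross the causal boundary,
// i.e. tiles whose scores need no per-element mask check.
uint32_t mask_check_tiles(bool causal, bool has_mask_offset, uint32_t chunk_len,
                          uint32_t chunk_start, uint32_t causal_boundary,
                          uint32_t tile_rows) {
  const uint32_t unmasked_len =
      causal ? std::min(chunk_len, sub_if_greater_or_zero(causal_boundary, chunk_start))
             : (has_mask_offset ? 0u : chunk_len);
  return unmasked_len / tile_rows;  // mirrors the trailing "/ (num_frags_z * 16)"
}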
@@ -204,8 +208,8 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t const_offset = kv_head_idx * kv_h_stride +
(wid * 4 + tid / 8) * kv_b_stride +
tid % 8 * num_elems_per_128b<T>();
T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset;
T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset;
const T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset;
const T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset;

produce_kv_blockwise<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -264,7 +268,6 @@ __global__ void multi_query_append_attention_kernel(
s_frag,
mask_offset_this_seq,
sliding_window);

}

// update m,d
@@ -321,18 +324,22 @@ __global__ void multi_query_append_attention_kernel(
wait_group<0>();
__syncthreads();

if constexpr (!partition_kv ) {
if constexpr (!partition_kv) {
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
const uint32_t h_offset =
(q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) %
GROUP_SIZE;
current_sinks[fx][j] =
static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
normalize_d<num_frags_x, num_frags_y>(
o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
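The reflowed block above only changes formatting, but it is where attention sinks enter the epilogue: each query head's sink logit is gathered into current_sinks and passed to normalize_d, which (as sink attention is commonly implemented) folds exp(sink - m) into the softmax denominator without contributing a value vector. A scalar sketch of that idea, under the assumption that normalize_d behaves this way; normalize_row_sketch and its arguments are illustrative names, not the kernel's API:

#include <cmath>

// Scalar sketch: divide an accumulated output row by the softmax denominator,
// optionally augmented by a per-head sink logit. m is the running row maximum,
// d the running sum of exp(score - m) produced by the online softmax.
void normalize_row_sketch(float *o_row, int head_dim, float m, float d,
                          const float *sink /* nullptr when no sinks */) {
  float denom = d;
  if (sink != nullptr) {
    denom += std::exp(*sink - m);  // the sink competes for probability mass only
  }
  const float inv = denom > 0.f ? 1.f / denom : 0.f;
  for (int i = 0; i < head_dim; ++i) {
    o_row[i] *= inv;
  }
}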
@@ -375,7 +382,6 @@ __global__ void multi_query_append_attention_kernel(
HEAD_DIM);
}


if constexpr (partition_kv) {
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -421,13 +427,13 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true>
__global__ void multi_query_append_attention_warp1_4_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
T *__restrict__ cache_v,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
const T *__restrict__ sinks, // [q_num_heads]
const int *__restrict__ seq_lens,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
@@ -435,9 +441,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
const float scale,
const float quant_max_bound,
@@ -469,8 +474,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;

//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
// When cudagraph capture prefill, may launch more gridDim.x
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}

@@ -478,8 +483,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
if (q_len <= 0) {
return;
}
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));

uint32_t kv_len = seq_lens_kv[batch_id];
if (ENABLE_PREFILL) {
kv_len += q_len;
@@ -540,11 +544,16 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq =
mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);

uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
tid % 16, tid / 16); // 16 * 16

const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));

load_q_global_smem_multi_warps<GROUP_SIZE,
num_frags_x,
num_frags_y,
@@ -576,11 +585,10 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
: chunk_len,
NUM_WARP_KV * num_frags_z * 16);
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(kv_len - q_len, chunk_start)))
: mask_offset ? 0
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);

uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -648,16 +656,18 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq,
sliding_window);
num_frags_z>(
attn_mask ? attn_mask + batch_id * attn_mask_len * attn_mask_len
: nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq,
sliding_window);
}

// update m,d
@@ -720,15 +730,19 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
if (num_chunks_this_seq <= 1) {
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
const uint32_t h_offset =
(q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) %
GROUP_SIZE;
current_sinks[fx][j] =
static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
normalize_d<num_frags_x, num_frags_y>(
o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
@@ -876,7 +890,6 @@ void MultiQueryAppendAttention(
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
@@ -908,7 +921,6 @@ void MultiQueryAppendAttention(
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
@@ -933,8 +945,8 @@ void MultiQueryAppendAttention(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -943,7 +955,6 @@ void MultiQueryAppendAttention(
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
scale,
quant_max_bound,
@@ -996,8 +1007,8 @@ void MultiQueryAppendAttention(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1006,7 +1017,6 @@ void MultiQueryAppendAttention(
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
scale,
quant_max_bound,
@@ -1048,8 +1058,8 @@ void MultiQueryAppendAttention(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1087,8 +1097,8 @@ void MultiQueryAppendAttention(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1138,9 +1148,9 @@ void MultiQueryAppendAttention(

uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
attn_mask_len = -1;
}

const int num_chunks = div_up(max_seq_len, chunk_size);
@@ -1179,8 +1189,8 @@ void MultiQueryAppendAttention(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1189,9 +1199,8 @@ void MultiQueryAppendAttention(
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
scale,
quant_max_bound,
@@ -1250,14 +1259,14 @@ void MultiQueryAppendAttention(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1266,9 +1275,8 @@ void MultiQueryAppendAttention(
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
scale,
quant_max_bound,
@@ -1306,14 +1314,14 @@ void MultiQueryAppendAttention(
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1326,15 +1334,14 @@ void MultiQueryAppendAttention(
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
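For the partitioned-KV path, the reflowed launch above shows how the cross-chunk merge is shaped: gridDim.x is capped at a few blocks per SM via min(sm_count * 4, token_num), gridDim.y covers heads, and each block pairs HEAD_DIM / vec_size lanes with roughly 128 / blockx rows. A small host-side sketch of that sizing; the concrete numbers (80 SMs, head_dim 128, vec_size 8, 32 heads, 8192 tokens) are assumptions for illustration only:

#include <algorithm>
#include <cstdio>

int main() {
  const int sm_count = 80;     // assumed SMs on the target GPU
  const int token_num = 8192;  // assumed number of query tokens
  const int num_heads = 32;    // assumed number of query heads
  const int head_dim = 128;
  const int vec_size = 8;      // elements per 128-bit vectorized access

  const int blockx = head_dim / vec_size;                // 16 lanes across the head dim
  const int blocky = (128 + blockx - 1) / blockx;        // ~8 rows, ~128 threads per block
  const int gridx = std::min(sm_count * 4, token_num);   // cap the grid at 4x the SM count

  std::printf("grid = (%d, %d), block = (%d, %d)\n", gridx, num_heads, blockx, blocky);
  return 0;
}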
@@ -1345,14 +1352,14 @@ void MultiQueryAppendAttention(
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,