From 6c3d1da62f1fef75010374967d4b757c6e6c52af Mon Sep 17 00:00:00 2001 From: carryyu <569782149@qq.com> Date: Thu, 13 Nov 2025 18:17:44 +0800 Subject: [PATCH] fix conflicts --- .../append_attn/append_attention_func.cuh | 135 ++--- .../multiquery_attention_c8_impl.cuh | 547 ++++++++++-------- 2 files changed, 381 insertions(+), 301 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index f09dbb99d..4c1c51dd4 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -383,56 +383,45 @@ __device__ __forceinline__ void produce_v_blockwise_c8( } } -template -__device__ __forceinline__ void produce_k_dynamic_scale( - T* k_smem_scale, - T* cache_k_reg, +__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async( + smem_t kv_scale_smem, const int* block_table_now, - const T* cache_k_scale, + const T* cache_kv_scale, const uint32_t kv_idx, const uint32_t kv_num_heads, const uint32_t kv_head_idx, const uint32_t chunk_end) { const uint32_t tx = threadIdx.x, ty = threadIdx.y; + const uint32_t tid = ty * 32 + tx; if constexpr (NUM_WARP_Q == 4) { // 4 warps shared block_size - const uint32_t tid = ty * 32 + tx; int block_id = __ldg(&block_table_now[kv_idx / block_size]); if (block_id < 0) block_id = 0; - const T* cache_k_scale_now = cache_k_scale + - block_id * kv_num_heads * block_size + - kv_head_idx * block_size; - if (tid < block_size) { - k_smem_scale[tid] = cache_k_scale_now[tid]; - } - __syncthreads(); - const uint32_t row_id = tx / 4; - for (uint32_t fz = 0; fz < num_frags_z; fz++) { - cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id]; - cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8]; + if (tid < block_size / 8) { + const T* cache_k_scale_now = cache_kv_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size + tid * 8; + const int kv_idx_this_thread = kv_idx + tid * 8; + kv_scale_smem.load_128b_async( + tid, cache_k_scale_now, kv_idx_this_thread < chunk_end); } } else { // 1 warp 32 tokens - const uint32_t kv_idx_now = kv_idx + block_size * ty / 2; - int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); - if (block_id < 0) block_id = 0; - const T* cache_k_scale_now = cache_k_scale + - block_id * kv_num_heads * block_size + - kv_head_idx * block_size; - const int kv_idx_this_thread = kv_idx + ty * 32 + tx; - if (kv_idx_this_thread < chunk_end) { - k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx]; - } else { - k_smem_scale[ty * 32 + tx] = 0; - } - __syncwarp(); - const uint32_t row_id = tx / 4; - for (uint32_t fz = 0; fz < num_frags_z; fz++) { - cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id]; - cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8]; + if (tid < block_size / 8 * 2) { + const uint32_t kv_idx_now = kv_idx + block_size * tid / 8; + int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); + if (block_id < 0) block_id = 0; + const int kv_idx_this_thread = kv_idx + tid * 8; + const T* cache_k_scale_now = cache_kv_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size + tid % 8 * 8; + kv_scale_smem.load_128b_async( + tid, cache_k_scale_now, kv_idx_this_thread < chunk_end); } } } @@ -441,57 +430,55 @@ template -__device__ __forceinline__ void produce_v_dynamic_scale( - T* v_smem_scale, - T* cache_v_reg, - const int* block_table_now, - const T* cache_v_scale, - const uint32_t kv_idx, - 
const uint32_t kv_num_heads, - const uint32_t kv_head_idx, - const uint32_t chunk_end) { +__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg( + T* k_smem_scale, T* cache_k_reg) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + if constexpr (NUM_WARP_Q == 4) { + // 4 warps shared block_size + const uint32_t row_id = tx / 4; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = fz * 16 + row_id; + cache_k_reg[fz * 2] = k_smem_scale[scale_idx]; + cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8]; + } + } else { + // 1 warp 32 tokens + const uint32_t row_id = tx / 4; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_k_reg[fz * 2] = k_smem_scale[scale_idx]; + cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8]; + } + } +} + +template +__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg( + T* v_smem_scale, T* cache_v_reg) { const uint32_t tx = threadIdx.x, ty = threadIdx.y; if constexpr (NUM_WARP_Q == 4) { // 4 warps shared block_size - const uint32_t tid = ty * 32 + tx; - int block_id = __ldg(&block_table_now[kv_idx / block_size]); - if (block_id < 0) block_id = 0; - const T* cache_v_scale_now = cache_v_scale + - block_id * kv_num_heads * block_size + - kv_head_idx * block_size; - if (tid < block_size) { - v_smem_scale[tid] = cache_v_scale_now[tid]; - } - __syncthreads(); const uint32_t row_id = tx % 4 * 2; for (uint32_t fz = 0; fz < num_frags_z; fz++) { - cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id]; - cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1]; - cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8]; - cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9]; + const uint32_t scale_idx = fz * 16 + row_id; + cache_v_reg[fz * 4] = v_smem_scale[scale_idx]; + cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1]; + cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8]; + cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9]; } } else { // 1 warp 32 tokens - const uint32_t kv_idx_now = kv_idx + block_size * ty / 2; - int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); - if (block_id < 0) block_id = 0; - const T* cache_v_scale_now = cache_v_scale + - block_id * kv_num_heads * block_size + - kv_head_idx * block_size; - const int kv_idx_this_thread = kv_idx + ty * 32 + tx; - if (kv_idx_this_thread < chunk_end) { - v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx]; - } else { - v_smem_scale[ty * 32 + tx] = 0; - } - __syncwarp(); const uint32_t row_id = tx % 4 * 2; for (uint32_t fz = 0; fz < num_frags_z; fz++) { - cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id]; - cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1]; - cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8]; - cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9]; + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_v_reg[fz * 4] = v_smem_scale[scale_idx]; + cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1]; + cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8]; + cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9]; } } } diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh index 8d61cb847..01adb2130 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh @@ -34,15 +34,17 @@ 
template __global__ void multi_query_append_attention_c8_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, // head_dim] CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] - const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] + const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, + // num_heads, block_size] + const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, + // num_heads, block_size] const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const T *__restrict__ sinks, // [q_num_heads] + const T *__restrict__ sinks, // [q_num_heads] const int *__restrict__ seq_lens, const int *__restrict__ seq_lens_kv, const int *__restrict__ batch_ids, @@ -89,8 +91,8 @@ __global__ void multi_query_append_attention_c8_kernel( block_table_now = block_table + batch_id * max_block_num_per_seq; - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ + // When cudagraph capture prefill, may launch more gridDim.x + if (btid >= static_cast(num_blocks_x_cpu)) { return; } @@ -190,7 +192,8 @@ __global__ void multi_query_append_attention_c8_kernel( } else { o_base_ptr_int8 = out + o_offset; } - const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const int *mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; smem_t qo_smem(smem); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( @@ -211,15 +214,19 @@ __global__ void multi_query_append_attention_c8_kernel( smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); - T* k_smem_scale = nullptr; - T* v_smem_scale = nullptr; + T *k_smem_scale_ptr = nullptr; + T *v_smem_scale_ptr = nullptr; + smem_t k_scale_smem; + smem_t v_scale_smem; if constexpr (IsDynamicC8) { - k_smem_scale = reinterpret_cast(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); - v_smem_scale = k_smem_scale + num_frags_z * 16; + k_smem_scale_ptr = reinterpret_cast( + smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); + v_smem_scale_ptr = k_smem_scale_ptr + num_frags_z * 16; + k_scale_smem.base = reinterpret_cast(k_smem_scale_ptr); + v_scale_smem.base = reinterpret_cast(v_smem_scale_ptr); } - const uint32_t num_iterations = div_up( CAUSAL ? (min(chunk_len, @@ -230,12 +237,13 @@ __global__ void multi_query_append_attention_c8_kernel( : chunk_len, num_frags_z * 16); const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, + (CAUSAL ? (min(chunk_len, sub_if_greater_or_zero( kv_len - q_len + tile_id * num_rows_per_block / GROUP_SIZE, chunk_start))) - : mask_offset ? 0 : chunk_len) / + : mask_offset ? 
0 + : chunk_len) / (num_frags_z * 16); uint32_t k_smem_offset_r = @@ -248,8 +256,7 @@ __global__ void multi_query_append_attention_c8_kernel( uint32_t k_smem_offset_w = smem_t::get_permuted_offset( - wid * 4 + tid / 8, - tid % 8); + wid * 4 + tid / 8, tid % 8); uint32_t v_smem_offset_w = smem_t::get_permuted_offset( wid * 8 + tid / 4, tid % 4); // 4 * 128 / 8 = 64 @@ -278,6 +285,18 @@ __global__ void multi_query_append_attention_c8_kernel( kv_idx_base, chunk_end, const_k_offset); + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } commit_group(); produce_v_blockwise_c8(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } commit_group(); #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { - if constexpr (IsDynamicC8) { - produce_k_dynamic_scale( - k_smem_scale, - cache_k_scale_reg, - block_table_now, - cache_k_scale, - kv_idx_base, - kv_num_heads, - kv_head_idx, - chunk_end - ); - } wait_group<1>(); __syncthreads(); + if constexpr (IsDynamicC8) { + produce_k_dynamic_scale_smem2reg( + k_smem_scale_ptr, cache_k_scale_reg); + } // s = qk - compute_qk_c8( - &qo_smem, - &q_smem_offset_r, - &k_smem, - &k_smem_offset_r, - cache_k_scale_reg, - s_frag); + compute_qk_c8(&qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + cache_k_scale_reg, + s_frag); // mask according to kv_idx and q_idx if (iter >= mask_check_iteration || sliding_window > 0) { @@ -366,21 +395,25 @@ __global__ void multi_query_append_attention_c8_kernel( kv_idx_base, chunk_end, const_k_offset); - commit_group(); if constexpr (IsDynamicC8) { - produce_v_dynamic_scale( - v_smem_scale, - cache_v_scale_reg, - block_table_now, - cache_v_scale, - ori_kv_idx_base, - kv_num_heads, - kv_head_idx, - chunk_end - ); + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); } + commit_group(); wait_group<1>(); __syncthreads(); + if constexpr (IsDynamicC8) { + produce_v_dynamic_scale_smem2reg( + v_smem_scale_ptr, cache_v_scale_reg); + } // compute sfm*v compute_sfm_v_c8(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } commit_group(); - } wait_group<0>(); __syncthreads(); @@ -420,15 +464,19 @@ __global__ void multi_query_append_attention_c8_kernel( if constexpr (!partition_kv) { if (sinks) { float current_sinks[num_frags_x][2]; - #pragma unroll +#pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - #pragma unroll +#pragma unroll for (uint32_t j = 0; j < 2; ++j) { - const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE; - current_sinks[fx][j] = static_cast(sinks[q_head_idx + h_offset]); + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); } } - normalize_d(o_frag, d_frag, m_frag, current_sinks); + normalize_d( + o_frag, d_frag, m_frag, current_sinks); } else { normalize_d(o_frag, d_frag); } @@ -474,7 +522,6 @@ __global__ void multi_query_append_attention_c8_kernel( HEAD_DIM); } - if constexpr (partition_kv) { #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { @@ -520,16 +567,18 @@ template __global__ void multi_query_append_attention_c8_warp1_4_kernel( - T *__restrict__ q, // [token_num, 
(num_heads + 2* kv_num_head) * head_dim] + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, // head_dim] CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] - const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size] + const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, + // num_heads, block_size] + const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, + // num_heads, block_size] const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] const T *__restrict__ sinks, // [q_num_heads] @@ -540,7 +589,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] const int *__restrict__ mask_offset, - const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask + const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, @@ -580,8 +629,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( const uint32_t num_rows_per_block = num_frags_x * 16; const int *block_table_now = block_table + batch_id * max_block_num_per_seq; - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ + // When cudagraph capture prefill, may launch more gridDim.x + if (btid >= static_cast(num_blocks_x_cpu)) { return; } @@ -678,7 +727,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( tid % 8 * num_elems_per_128b(); } } - const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const int *mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; smem_t qo_smem(smem); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( @@ -703,12 +753,17 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); - T* k_smem_scale = nullptr; - T* v_smem_scale = nullptr; + T *k_smem_scale_ptr = nullptr; + T *v_smem_scale_ptr = nullptr; + smem_t k_scale_smem; + smem_t v_scale_smem; if constexpr (IsDynamicC8) { - k_smem_scale = reinterpret_cast(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); - v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16; + k_smem_scale_ptr = reinterpret_cast( + smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); + v_smem_scale_ptr = k_smem_scale_ptr + NUM_WARP_KV * num_frags_z * 16; + k_scale_smem.base = reinterpret_cast(k_smem_scale_ptr); + v_scale_smem.base = reinterpret_cast(v_smem_scale_ptr); } const uint32_t num_iterations = div_up( @@ -721,12 +776,13 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( : chunk_len, NUM_WARP_KV * num_frags_z * 16); const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, + (CAUSAL ? (min(chunk_len, sub_if_greater_or_zero( kv_len - q_len + tile_id * num_rows_per_block / GROUP_SIZE, chunk_start))) - : mask_offset ? 0 : chunk_len) / + : mask_offset ? 
0 + : chunk_len) / (NUM_WARP_KV * num_frags_z * 16); uint32_t k_smem_offset_r = @@ -740,9 +796,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( uint32_t k_smem_offset_w = smem_t::get_permuted_offset( - wid * 4 + tid / 8, - tid % - 8); + wid * 4 + tid / 8, tid % 8); uint32_t v_smem_offset_w = smem_t::get_permuted_offset( wid * 8 + tid / 4, tid % 4); @@ -772,6 +826,18 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( kv_idx_base, chunk_end, const_k_offset); + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } commit_group(); produce_v_blockwise_c8(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } commit_group(); #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { - if constexpr (IsDynamicC8) { - produce_k_dynamic_scale( - k_smem_scale, - cache_k_scale_reg, - block_table_now, - cache_k_scale, - kv_idx_base, - kv_num_heads, - kv_head_idx, - chunk_end - ); - } wait_group<1>(); __syncthreads(); + if constexpr (IsDynamicC8) { + produce_k_dynamic_scale_smem2reg( + k_smem_scale_ptr, cache_k_scale_reg); + } // s = qk - compute_qk_c8( - &qo_smem, - &q_smem_offset_r, - &k_smem, - &k_smem_offset_r, - cache_k_scale_reg, - s_frag); + compute_qk_c8(&qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + cache_k_scale_reg, + s_frag); // mask according to kv_idx and q_idx if (iter >= mask_check_iteration || sliding_window > 0) { mask_s(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr, - q_base_seq_id_this_block, - kv_idx_base + wid * num_frags_z * 16, - q_len, - kv_len, - chunk_end, - attn_mask_len, - s_frag, - mask_offset_this_seq, - sliding_window); - + num_frags_z>( + attn_mask ? 
attn_mask + batch_id * attn_mask_len * attn_mask_len + : nullptr, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + attn_mask_len, + s_frag, + mask_offset_this_seq, + sliding_window); } // update m,d @@ -860,21 +937,25 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( kv_idx_base, chunk_end, const_k_offset); - commit_group(); if constexpr (IsDynamicC8) { - produce_v_dynamic_scale( - v_smem_scale, - cache_v_scale_reg, - block_table_now, - cache_v_scale, - ori_kv_idx_base, - kv_num_heads, - kv_head_idx, - chunk_end - ); + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); } + commit_group(); wait_group<1>(); __syncthreads(); + if constexpr (IsDynamicC8) { + produce_v_dynamic_scale_smem2reg( + v_smem_scale_ptr, cache_v_scale_reg); + } // compute sfm * v compute_sfm_v_c8_iter_sq_bvec(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } commit_group(); } wait_group<0>(); @@ -916,15 +1009,19 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( if (num_chunks_this_seq <= 1) { if (sinks) { float current_sinks[num_frags_x][2]; - #pragma unroll +#pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - #pragma unroll +#pragma unroll for (uint32_t j = 0; j < 2; ++j) { - const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE; - current_sinks[fx][j] = static_cast(sinks[q_head_idx + h_offset]); + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); } } - normalize_d(o_frag, d_frag, m_frag, current_sinks); + normalize_d( + o_frag, d_frag, m_frag, current_sinks); } else { normalize_d(o_frag, d_frag); } @@ -981,7 +1078,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; if (qo_idx - q_start_seq_id < q_len) { - uint32_t offset; if (ENABLE_PREFILL) { offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + @@ -1095,25 +1191,24 @@ void MultiQueryAppendC8Attention( IsFP8, IsDynamicC8>; if (is_scale_channel_wise) { - split_kv_kernel = - multi_query_append_attention_c8_kernel; + split_kv_kernel = multi_query_append_attention_c8_kernel; } if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(split_kv_kernel, @@ -1152,24 +1247,24 @@ void MultiQueryAppendC8Attention( IsDynamicC8>; if (is_scale_channel_wise) { nosplit_kv_kernel = - multi_query_append_attention_c8_kernel; + multi_query_append_attention_c8_kernel; } if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(nosplit_kv_kernel, @@ -1190,8 +1285,8 @@ void MultiQueryAppendC8Attention( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1253,8 +1348,8 @@ void MultiQueryAppendC8Attention( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1305,8 +1400,8 @@ void MultiQueryAppendC8Attention( smooth_weight.get().data())) : nullptr, sinks ? 
reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, reinterpret_cast(out->data()), quant_max_bound, quant_min_bound, @@ -1319,8 +1414,7 @@ void MultiQueryAppendC8Attention( } else { constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); dim3 blocks_merge(blockx, blocky); merge_multi_chunks_v2_kernel())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, reinterpret_cast(out->data()), quant_max_bound, quant_min_bound, @@ -1386,24 +1480,24 @@ void MultiQueryAppendC8Attention( IsDynamicC8>; if (is_scale_channel_wise) { split_kv_kernel = - multi_query_append_attention_c8_warp1_4_kernel; + multi_query_append_attention_c8_warp1_4_kernel; } if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(split_kv_kernel, @@ -1421,9 +1515,9 @@ void MultiQueryAppendC8Attention( const int num_chunks = div_up(max_seq_len, chunk_size); uint32_t attn_mask_len; if (attn_mask) { - attn_mask_len = attn_mask.get().shape()[1]; + attn_mask_len = attn_mask.get().shape()[1]; } else { - attn_mask_len = -1; + attn_mask_len = -1; } dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); @@ -1450,24 +1544,24 @@ void MultiQueryAppendC8Attention( IsDynamicC8>; if (is_scale_channel_wise) { nosplit_kv_kernel = - multi_query_append_attention_c8_warp1_4_kernel; + multi_query_append_attention_c8_warp1_4_kernel; } if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(nosplit_kv_kernel, @@ -1488,8 +1582,8 @@ void MultiQueryAppendC8Attention( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1498,7 +1592,7 @@ void MultiQueryAppendC8Attention( block_table.data(), meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) - : nullptr, + : nullptr, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1561,14 +1655,14 @@ void MultiQueryAppendC8Attention( reinterpret_cast(const_cast(cache_k_scale.data())), reinterpret_cast(const_cast(cache_v_scale.data())), shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, + const_cast(shift_bias.get().data())) + : nullptr, smooth_weight ? reinterpret_cast( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1577,7 +1671,7 @@ void MultiQueryAppendC8Attention( block_table.data(), meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) - : nullptr, + : nullptr, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1611,14 +1705,14 @@ void MultiQueryAppendC8Attention( seq_lens_encoder.data(), cu_seqlens_q.data(), shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, + const_cast(shift_bias.get().data())) + : nullptr, smooth_weight ? reinterpret_cast(const_cast( smooth_weight.get().data())) : nullptr, sinks ? 
reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, reinterpret_cast(out->data()), quant_max_bound, quant_min_bound, @@ -1631,15 +1725,14 @@ void MultiQueryAppendC8Attention( } else { constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); dim3 blocks_merge(blockx, blocky); merge_multi_chunks_v2_kernel + vec_size, + blocky, + HEAD_DIM, + OUT_NV_TYPE, + ENABLE_PREFILL> <<>>( reinterpret_cast(tmp_workspace->ptr()), static_cast(tmp_m->ptr()), @@ -1650,14 +1743,14 @@ void MultiQueryAppendC8Attention( batch_id_per_token.data(), cu_seqlens_q.data(), shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, + const_cast(shift_bias.get().data())) + : nullptr, smooth_weight ? reinterpret_cast(const_cast( smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, reinterpret_cast(out->data()), quant_max_bound, quant_min_bound,
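
Note (illustrative, not part of the patch): the refactor above splits the dynamic C8 scale handling into two stages. The per-token K/V dequant scales are now fetched gmem->smem with the same asynchronous 128-bit loads as the K/V tiles (produce_kv_dynamic_scale_gmem2smem_async, sharing the commit group), and are only unpacked into registers after the corresponding wait (produce_k/v_dynamic_scale_smem2reg), instead of being read synchronously inside the main loop. The sketch below reproduces that two-stage shape with the standard CUDA pipeline primitives rather than the repo's smem_t helpers; all names, sizes, and the float scale type are assumptions made for the example, not the kernel's real layout.

    // Minimal sketch of the async-stage-then-consume pattern (assumptions noted above).
    #include <cuda_runtime.h>
    #include <cuda_pipeline_primitives.h>
    #include <cstdio>

    constexpr int kBlockSize = 64;   // tokens per KV cache block (assumption)
    constexpr int kThreads   = 128;

    __global__ void staged_scale_demo(const float* __restrict__ scales_gmem,
                                      float* __restrict__ out,
                                      int num_tokens) {
      // Shared staging buffer for one block of per-token scales.
      __shared__ __align__(16) float scale_smem[kBlockSize];
      const int tid = threadIdx.x;

      // Stage 1 (producer): the first kBlockSize/4 threads each issue one 16-byte
      // asynchronous copy (4 scales), analogous to load_128b_async in the patch.
      if (tid < kBlockSize / 4) {
        const int token = tid * 4;
        if (token < num_tokens) {
          __pipeline_memcpy_async(&scale_smem[token], &scales_gmem[token],
                                  4 * sizeof(float));
        }
      }
      __pipeline_commit();

      // The real kernel issues its K/V tile copies here, so the scale copy rides
      // in the same commit group and is covered by the same wait_group<N>().

      __pipeline_wait_prior(0);   // stand-in for wait_group<N>() before compute
      __syncthreads();

      // Stage 2 (consumer): smem -> registers, analogous to the *_smem2reg helpers.
      if (tid < kBlockSize && tid < num_tokens) {
        const float scale_reg = scale_smem[tid];
        out[tid] = scale_reg;     // stand-in for dequantizing the int8 tile with it
      }
    }

    int main() {
      const int n = kBlockSize;
      float h_in[kBlockSize], h_out[kBlockSize] = {};
      for (int i = 0; i < n; ++i) h_in[i] = 0.5f + 0.01f * i;

      float *d_in = nullptr, *d_out = nullptr;
      cudaMalloc(&d_in, n * sizeof(float));
      cudaMalloc(&d_out, n * sizeof(float));
      cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

      staged_scale_demo<<<1, kThreads>>>(d_in, d_out, n);
      cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
      printf("scale[0]=%.2f scale[63]=%.2f\n", h_out[0], h_out[n - 1]);

      cudaFree(d_in);
      cudaFree(d_out);
      return 0;
    }

Built with, e.g., nvcc -arch=sm_80 demo.cu (cp.async needs sm_80+; on older targets the pipeline primitives fall back to synchronous copies). The point of the split mirrored here is that the scale fetch overlaps with the tile fetch instead of adding a separate synchronous read plus __syncthreads/__syncwarp inside the iteration loop.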