From a389bb7c5c896d358ce9e36b1bd59978507c2be7 Mon Sep 17 00:00:00 2001
From: chen <103103266+ckl117@users.noreply.github.com>
Date: Fri, 12 Dec 2025 17:10:17 +0800
Subject: [PATCH] [Feature][Optimization] Qwen Support Dynamic block_wise_fp8
 cache (#5486)

---
 .../decoder_write_cache_with_rope_impl.cuh    | 309 ++++++++++++++++++
 .../decoder_write_cache_with_rope_kernel.cu   | 109 ++++--
 2 files changed, 382 insertions(+), 36 deletions(-)

diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh
index 5c141d7e3..5d1daed91 100644
--- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh
@@ -849,6 +849,315 @@ __global__ void append_decode_cache_T_quant_neox_rope_kernel(
 #endif
 }
 
+template <typename T, int VecSize, int HeadDim, bool IsDynamic>
+__global__ void append_decode_cache_T_int8_neox_rope_kernel(
+    const T* __restrict__ quant_qkv,    // [bsz, num_heads + 2 * kv_num_heads,
+                                        // head_size]
+    uint8_t* __restrict__ key_cache,    // [num_blocks, kv_num_heads,
+                                        // block_size, head_size // 2]
+    uint8_t* __restrict__ value_cache,  // [num_blocks, kv_num_heads,
+                                        // block_size, head_size // 2]
+    T* __restrict__ qkv_out,
+    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
+    const int* __restrict__ cu_seqlens_q,
+    const int* __restrict__ seq_lens,          // [bsz]
+    const int* __restrict__ seq_lens_encoder,  // [bsz]
+    const float* __restrict__ cos_emb,
+    const float* __restrict__ sin_emb,
+    T* __restrict__ cache_k_scale,
+    T* __restrict__ cache_v_scale,
+    const int max_seq_len,
+    const int max_blocks_per_seq,
+    const int num_heads,
+    const int block_size,
+    const float max_bound,
+    const float min_bound,
+    const int kv_num_heads,
+    const bool rope_3d,
+    const float rms_norm_eps) {
+  static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
+  static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
+  constexpr int NUM_WARPS = 4;
+  const int tid = threadIdx.x;
+  const int wid = tid / 32;
+  const int lane_id = tid % 32;
+  const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid;
+  int q_head_idx, k_head_idx, v_idx;
+  const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
+  constexpr int half_head_size = HeadDim / 2;
+  const int start_token_idx = cu_seqlens_q[bid];
+  if (seq_lens_encoder[bid] > 0) return;
+  const int write_seq_id = seq_lens[bid];
+  if (write_seq_id == 0) return;
+  const int* block_table_now = nullptr;
+
+  block_table_now = block_tables + bid * max_blocks_per_seq;
+  const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
+  const int block_offset = write_seq_id % block_size;
+
+  float thread_m2 = 0.0f;
+  float warp_m2 = 0.0f;
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaGridDependencySynchronize();
+#endif
+  if (head_idx < num_heads) {
+    // q
+    using LoadT = AlignedVector<T, VecSize>;
+    using LoadBiasT = AlignedVector<T, VecSize>;
+    constexpr int HalfVecSize = VecSize / 2;
+    using LoadEmbT = AlignedVector<float, VecSize>;
+
+    LoadT src_vec;
+    LoadT src_vec_right;
+    LoadBiasT out_vec;
+    LoadBiasT out_vec_right;
+    LoadEmbT cos_emb_vec;
+    LoadEmbT sin_emb_vec;
+    const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
+    T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
+#pragma unroll
+    for (uint32_t head_bias = lane_id * VecSize; head_bias < half_head_size;
+         head_bias += 32 * VecSize) {
+      const int bias_idx = head_idx * HeadDim + head_bias;
+      Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
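+      // NeoX-style rope pairs each element with the one half a head-dim away,
+      // so the matching element from the second half of the head is loaded too.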
+      Load<T, VecSize>(&qkv_now[bias_idx + half_head_size], &src_vec_right);
+      // q rope
+      const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
+      const uint32_t new_emb_idx =
+          rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
+#pragma unroll
+      for (int i = 0; i < VecSize; i++) {
+        // dequant + add_bias + rope
+        float input_left = static_cast<float>(src_vec[i]);
+        float input_right = static_cast<float>(src_vec_right[i]);
+
+        const float cos_tmp = cos_emb_vec[i];
+        const float sin_tmp = sin_emb_vec[i];
+        float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
+        float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
+        thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
+        out_vec[i] = static_cast<T>(tmp1);
+        out_vec_right[i] = static_cast<T>(tmp2);
+      }
+      Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
+      Store<T, VecSize>(out_vec_right, &qkv_out_now[bias_idx + half_head_size]);
+    }
+  } else if (head_idx < num_heads + 2 * kv_num_heads) {
+    // k
+    constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t);  // 16
+    using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
+    const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads;
+    if (block_offset == 0) {
+      // pad zero for this kv_head_idx for this block
+      LoadPadKVT pad_cache_vec;
+      *(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
+      if (head_idx < num_heads + kv_num_heads) {
+        constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE;
+        constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
+        const uint32_t tgt_idx =
+            (block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim +
+            lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
+        for (int block_i = lane_id / num_vecs_per_head_dim;
+             block_i < block_size;
+             block_i += num_token_each_time) {
+          Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
+                                      &key_cache[tgt_idx + block_i * HeadDim]);
+        }
+      } else {
+        const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE;
+        const int num_token_each_time = 32 / num_vecs_per_head_dim;
+        const uint32_t tgt_idx =
+            (block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size +
+            lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
+        for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
+             block_i += num_token_each_time) {
+          Store<uint8_t, KV_VEC_SIZE>(
+              pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]);
+        }
+      }
+      __syncwarp();
+    }
+
+    constexpr int K_VEC_SIZE = 4;
+    constexpr int HALF_K_VEC_SIZE = 2;
+    using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
+    using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
+    using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
+    using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
+    using LoadEmbT = AlignedVector<float, HALF_K_VEC_SIZE>;
+    LoadKVResT cache_vec;
+    LoadT src_vec1, src_vec1_right, src_vec2, src_vec2_right;
+    LoadBiasT out_vec1, out_vec2;
+    LoadEmbT cos_emb_vec1, cos_emb_vec2;
+    LoadEmbT sin_emb_vec1, sin_emb_vec2;
+
+    const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
+    const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
+    const int bias_idx = head_idx * HeadDim + head_bias;
+    Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
+    Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
+    T scale = T(1.0f);
+    const int k_head_idx = head_idx - num_heads;
+    const int v_head_idx = head_idx - num_heads - kv_num_heads;
+    if (head_idx < num_heads + kv_num_heads) {
+      Load<T, HALF_K_VEC_SIZE>(
+          &qkv_now[head_idx * HeadDim + (head_bias + half_head_size) % HeadDim],
+          &src_vec1_right);
+      Load<T, HALF_K_VEC_SIZE>(
+          &qkv_now[head_idx * HeadDim +
+                   (head_bias + 8 + half_head_size) % HeadDim],
+          &src_vec2_right);
+
+      const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
+      const uint32_t new_emb_idx =
+          rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
+      Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
+      Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
+      Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
+    }
+
+    if (head_idx < num_heads + kv_num_heads) {
+      float input_left = static_cast<float>(src_vec1[0]);
+      float input_right = static_cast<float>(src_vec1_right[0]);
+      float cos_tmp = cos_emb_vec1[0];
+      float sin_tmp = sin_emb_vec1[0];
+      float tmp1 = 0;
+      if (head_bias < half_head_size) {
+        tmp1 = input_left * cos_tmp - input_right * sin_tmp;
+      } else {
+        tmp1 = input_left * cos_tmp + input_right * sin_tmp;
+      }
+      out_vec1[0] = static_cast<T>(tmp1);
+      input_left = static_cast<float>(src_vec1[1]);
+      input_right = static_cast<float>(src_vec1_right[1]);
+      cos_tmp = cos_emb_vec1[1];
+      sin_tmp = sin_emb_vec1[1];
+      if (head_bias < half_head_size) {
+        tmp1 = input_left * cos_tmp - input_right * sin_tmp;
+      } else {
+        tmp1 = input_left * cos_tmp + input_right * sin_tmp;
+      }
+      out_vec1[1] = static_cast<T>(tmp1);
+    } else {
+      out_vec1[0] = src_vec1[0];
+      out_vec1[1] = src_vec1[1];
+    }
+
+    // rope
+    if (head_idx < num_heads + kv_num_heads) {
+      float input_left = static_cast<float>(src_vec2[0]);
+      float input_right = static_cast<float>(src_vec2_right[0]);
+      float cos_tmp = cos_emb_vec2[0];
+      float sin_tmp = sin_emb_vec2[0];
+      float tmp1 = 0;
+      if (head_bias < half_head_size) {
+        tmp1 = input_left * cos_tmp - input_right * sin_tmp;
+      } else {
+        tmp1 = input_left * cos_tmp + input_right * sin_tmp;
+      }
+      out_vec2[0] = static_cast<T>(tmp1);
+      input_left = static_cast<float>(src_vec2[1]);
+      input_right = static_cast<float>(src_vec2_right[1]);
+      cos_tmp = cos_emb_vec2[1];
+      sin_tmp = sin_emb_vec2[1];
+      if (head_bias < half_head_size) {
+        tmp1 = input_left * cos_tmp - input_right * sin_tmp;
+      } else {
+        tmp1 = input_left * cos_tmp + input_right * sin_tmp;
+      }
+      out_vec2[1] = static_cast<T>(tmp1);
+    } else {
+      out_vec2[0] = src_vec2[0];
+      out_vec2[1] = src_vec2[1];
+    }
+    if constexpr (IsDynamic) {
+      // reduce max, 1 head per warp
+      T local_max = -INFINITY;
+#pragma unroll
+      for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
+        local_max = __hmax(local_max, __habs(out_vec1[i]));
+        local_max = __hmax(local_max, __habs(out_vec2[i]));
+      }
+#pragma unroll
+      for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
+        local_max =
+            __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
+      }
+      // dynamic per-token scale; 448 is the max representable FP8 E4M3 value
+      scale = __hdiv(448, local_max);
+
+      int cache_offset;
+      if (head_idx < num_heads) {
+        cache_offset = 0;
+      } else if (head_idx < num_heads + 2 * kv_num_heads) {
+        cache_offset = block_idx * kv_num_heads * block_size +
+                       (head_idx - num_heads) % kv_num_heads * block_size +
+                       block_offset;
+      }
+      T* cache_k_scale_now = cache_k_scale + cache_offset;
+      T* cache_v_scale_now = cache_v_scale + cache_offset;
+      if (lane_id == 0) {
+        if (head_idx < num_heads + kv_num_heads) {
+          cache_k_scale_now[0] = __hdiv(1, scale);
+        } else {
+          cache_v_scale_now[0] = __hdiv(1, scale);
+        }
+      }
+    } else {
+      if (head_idx < num_heads + kv_num_heads) {
+        scale = __ldg(&cache_k_scale[kv_head_idx]);
+      } else {
+        scale = __ldg(&cache_v_scale[kv_head_idx]);
+      }
+    }
+
+#pragma unroll
+    for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
+      cache_vec[i] = QuantToC8(scale, out_vec1[i], max_bound, min_bound);
+      cache_vec[i + HALF_K_VEC_SIZE] =
+          QuantToC8(scale, out_vec2[i], max_bound, min_bound);
+    }
+    if (head_idx < num_heads + kv_num_heads) {
+      const int start_block_16 =
+          block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
+      const uint32_t tgt_cache_idx =
+          block_idx * kv_num_heads * block_size * HeadDim +
+          kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
+          lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
+      Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
+    } else {
+      const uint32_t base_tgt_cache_idx =
+          block_idx * kv_num_heads * HeadDim * block_size +
+          kv_head_idx * HeadDim * block_size +
+          (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
+          block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
+      const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
+                                      block_offset % 8 / 2 * 4     // per 4
+                                      + block_offset % 16 / 8 * 2  // per 2
+                                      + block_offset % 2;          // per 1
+      const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
+      const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
+      const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
+      value_cache[tgt_cache_idx1] = cache_vec[0];
+      value_cache[tgt_cache_idx2] = cache_vec[1];
+      value_cache[tgt_cache_idx3] = cache_vec[2];
+      value_cache[tgt_cache_idx4] = cache_vec[3];
+    }
+  }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaTriggerProgrammaticLaunchCompletion();
+#endif
+}
+
diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu
--- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu
+++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu
-    launchWithPdlWhenEnabled(
-        append_decode_cache_int8_rope_qk_norm_kernel,
-        grids,
-        num_warps * 32,
-        0,
-        stream,
-        reinterpret_cast(qkv_ptr),
-        key_cache_out->data(),
-        value_cache_out->data(),
-        reinterpret_cast(qkv_out->data()),
-        block_tables.data(),
-        cu_seqlens_q.data(),
-        seq_lens.data(),
-        seq_lens_encoder.data(),
-        cos_emb,
-        sin_emb,
-        const_cast(reinterpret_cast(
-            cache_k_scale.get().data())),
-        const_cast(reinterpret_cast(
-            (cache_v_scale.get().data()))),
-        nullptr,
-        nullptr,
-        max_seq_len,
-        max_blocks_per_seq,
-        num_heads,
-        block_size,
-        127.0f,
-        -127.0f,
-        kv_num_heads,
-        rope_3d,
-        rms_norm_eps);
+    if (use_neox_rotary_style) {
+      launchWithPdlWhenEnabled(
+          append_decode_cache_T_int8_neox_rope_kernel,
+          grids,
+          num_warps * 32,
+          0,
+          stream,
+          reinterpret_cast(qkv_ptr),
+          key_cache_out->data(),
+          value_cache_out->data(),
+          reinterpret_cast(qkv_out->data()),
+          block_tables.data(),
+          cu_seqlens_q.data(),
+          seq_lens.data(),
+          seq_lens_encoder.data(),
+          cos_emb,
+          sin_emb,
+          const_cast(reinterpret_cast(
+              cache_k_scale.get().data())),
+          const_cast(reinterpret_cast(
+              (cache_v_scale.get().data()))),
+          max_seq_len,
+          max_blocks_per_seq,
+          num_heads,
+          block_size,
+          127.0f,
+          -127.0f,
+          kv_num_heads,
+          rope_3d,
+          rms_norm_eps);
+    } else {
+      launchWithPdlWhenEnabled(
+          append_decode_cache_int8_rope_qk_norm_kernel,
+          grids,
+          num_warps * 32,
+          0,
+          stream,
+          reinterpret_cast(qkv_ptr),
+          key_cache_out->data(),
+          value_cache_out->data(),
+          reinterpret_cast(qkv_out->data()),
+          block_tables.data(),
+          cu_seqlens_q.data(),
+          seq_lens.data(),
+          seq_lens_encoder.data(),
+          cos_emb,
+          sin_emb,
+          const_cast(reinterpret_cast(
+              cache_k_scale.get().data())),
+          const_cast(reinterpret_cast(
+              (cache_v_scale.get().data()))),
+          nullptr,
+          nullptr,
+          max_seq_len,
+          max_blocks_per_seq,
+          num_heads,
+          block_size,
+          127.0f,
+          -127.0f,
+          kv_num_heads,
+          rope_3d,
+          rms_norm_eps);
+    }
   } else if (cache_quant_type_str == "cache_int4_zp") {
     append_decode_cache_int4_rope(
         reinterpret_cast(qkv_ptr),