// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "mem_util.cuh"
#include "mma_tensor_op.cuh"
#include "utils.cuh"
template <typename T>
__forceinline__ __device__ float fixed_expf(float x1, float x2) {
if constexpr (std::is_same<T, half>::value) {
if (x1 == -5e4f) {
return 0;
} else {
return __expf(x1 - x2);
}
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
if (x1 == -3.0e+30f) {
return 0;
} else {
return __expf(x1 - x2);
}
}
}
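// Per-thread online-softmax accumulator: output vector o, running max m and
// running denominator d. merge() folds in another partial result, rescaling
// both sides by exp(m_prev - m); normalize() divides o by d.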
template <size_t vec_size, typename T>
struct prefill_softmax_state_t {
AlignedVector<T, vec_size> o;
float m;
float d;
__device__ __forceinline__ void init() {
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&o) + i) = make_half2(0, 0);
}
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
}
}
d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = -5e4f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
m = -3.38953e38f;
}
}
__device__ __forceinline__ void merge(
const AlignedVector<T, vec_size>& other_o, float other_m, float other_d) {
float m_prev = m, d_prev = d;
m = m_prev > other_m ? m_prev : other_m;
const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m);
const T scale1_T = static_cast<T>(scale1),
scale2_T = static_cast<T>(scale2);
d = d_prev * scale1 + other_d * scale2;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] = o[i] * scale1_T + other_o[i] * scale2_T;
}
}
__device__ __forceinline__ void normalize() {
const T d_t = static_cast<T>(d);
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] /= d_t;
}
}
};
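// Reset the output fragments to zero, the running max m to the dtype's mask
// sentinel, and the running denominator d to 1.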
template <typename T, uint32_t num_frags_x, uint32_t num_frags_y>
__device__ __forceinline__ void init_states(float (*o_frag)[num_frags_y][8],
float (*m)[2],
float (*d)[2]) {
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
#pragma unroll
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
o_frag[fx][fy][reg_id] = 0.f;
}
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
if constexpr (std::is_same<T, half>::value) {
m[fx][j] = -5e4f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
m[fx][j] = -3.0e+30f;
}
d[fx][j] = 1.f;
}
}
}
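// Cooperatively load Q tiles from global memory into shared memory with async
// 128-bit copies. All four warps share the same num_frags_x * 16 query rows;
// each warp (j = ty below) handles one 4-row slice of every 16-row fragment.
// Rows at or beyond qo_upper_bound are skipped via the load predicate.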
template <uint32_t group_size,
uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t HEAD_DIM,
typename T>
__device__ __forceinline__ void load_q_global_smem_multi_warps(
T* q_ptr_base,
smem_t* q_smem,
uint32_t q_idx_base,
const uint32_t qo_upper_bound,
const uint32_t qo_n_stride,
const uint32_t qo_h_stride) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t q_smem_offset_w = // [NUM_WARP_Q, num_frags_x, 16, head_dim]
smem_t::get_permuted_offset<num_vecs_per_head>(ty * 4 + tx / 8,
tx % 8); // 4 * 64
const uint32_t tx_offset = tx / 8;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset;
const int j = ty;  // each of the 4 warps covers one 4-row slice of this 16-row fragment
const uint32_t offset_now = base_offset + j * 4;
const uint32_t n_offset = offset_now / group_size;
const uint32_t h_offset = offset_now % group_size;
T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
q_smem_offset_w, q_ptr, n_offset < qo_upper_bound);
q_smem_offset_w =
q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo);
q_ptr += 8 * num_elems_per_128b<T>();
}
q_smem_offset_w =
q_smem->advance_offset_by_row<16, num_vecs_per_head>(q_smem_offset_w) -
2 * num_frags_y;
}
}
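// Per-warp variant of the Q loader: each warp fills its own num_frags_x * 16
// rows of the shared-memory Q tile.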
template <uint32_t group_size,
uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t HEAD_DIM,
typename T>
__device__ __forceinline__ void load_q_global_smem(
T* q_ptr_base,
smem_t* q_smem,
uint32_t q_idx_base,
const uint32_t qo_upper_bound,
const uint32_t qo_n_stride,
const uint32_t qo_h_stride) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t q_smem_offset_w = // [NUM_WARP_Q, num_frags_x, 16, head_dim]
smem_t::get_permuted_offset<num_vecs_per_head>(
ty * num_frags_x * 16 + tx / 8, tx % 8); // 4 * 64
const uint32_t tx_offset = tx / 8;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset;
#pragma unroll
for (uint32_t j = 0; j < 4; ++j) {
const uint32_t offset_now = base_offset + j * 4;
const uint32_t n_offset = offset_now / group_size;
const uint32_t h_offset = offset_now % group_size;
T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
q_smem_offset_w, q_ptr, n_offset < qo_upper_bound);
q_smem_offset_w =
q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo);
q_ptr += 8 * num_elems_per_128b<T>();
}
q_smem_offset_w =
q_smem->advance_offset_by_row<4, num_vecs_per_head>(q_smem_offset_w) -
2 * num_frags_y; // num_frags_y / 4 * 8
}
}
}
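// Scale Q by sm_scale in place in shared memory, all warps cooperating, using
// 128-bit vectorized loads and stores.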
template <uint32_t num_frags_x, uint32_t num_frags_y, typename T>
__device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps(
smem_t* q_smem, // [num_frags_x * 16, num_frags_y * 16]
const float sm_scale) {
constexpr int vec_size = 16 / sizeof(T);
using LoadT = AlignedVector<T, vec_size>;
LoadT tmp_vec;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
#pragma unroll
for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024;
++i) {
const int offset = i * 1024 + ty * 256 + tx * 8;
Load<T, vec_size>(reinterpret_cast<T*>(q_smem->base) + offset, &tmp_vec);
#pragma unroll
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
tmp_vec[reg_id] *= sm_scale;
}
Store<T, vec_size>(tmp_vec, reinterpret_cast<T*>(q_smem->base) + offset);
}
}
template <uint32_t num_frags_x, uint32_t num_frags_y, typename T>
__device__ __forceinline__ void q_smem_inplace_multiply_sm_scale(
smem_t* q_smem, // [num_frags_x * 16, num_frags_y * 16]
const float sm_scale) {
constexpr int vec_size = 16 / sizeof(T);
using LoadT = AlignedVector<T, vec_size>;
LoadT tmp_vec;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
#pragma unroll
for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 256;
++i) { // 32 * 8 per warp
Load<T, vec_size>(
reinterpret_cast<T*>(q_smem->base +
ty * num_frags_x * 16 * num_vecs_per_head) +
i * 256 + tx * 8,
&tmp_vec);
#pragma unroll
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
tmp_vec[reg_id] *= sm_scale;
}
Store<T, vec_size>(
tmp_vec,
reinterpret_cast<T*>(q_smem->base +
ty * num_frags_x * 16 * num_vecs_per_head) +
i * 256 + tx * 8);
}
}
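// Issue async 128-bit copies of a 16-bit K/V tile from global memory into
// shared memory, predicated on kv_idx < kv_len.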
template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t block_size,
uint32_t num_frags_y,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_kv_blockwise(
smem_t smem,
uint32_t* smem_offset,
T** gptr, // [max_block_num, num_heads, block_size, head_dim]
const uint32_t kv_head_idx,
const uint32_t kv_n_stride,
const uint32_t kv_h_stride,
const uint32_t kv_b_stride,
const uint32_t kv_idx_base,
const uint32_t kv_len) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check
#pragma unroll
for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps;
++i) {
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4;
++j) {
smem.load_128b_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
*gptr += 8 * num_elems_per_128b<T>();
}
kv_idx += num_warps * 4;
*smem_offset = smem.advance_offset_by_row<num_warps * 4, num_vecs_per_head>(
*smem_offset) -
2 * num_frags_y; // num_frags_y / 4 * 8
*gptr +=
num_warps * 4 * kv_b_stride - 2 * num_frags_y * num_elems_per_128b<T>();
}
*gptr -= NUM_WARP_KV * num_frags_z * 16 * kv_b_stride;
*smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head;
}
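// Load an 8-bit quantized V tile from the paged cache into shared memory.
// V blocks are stored [head_dim, block_size] per head (hence kv_d_stride);
// the block id comes from block_table_now, with negative ids mapped to 0.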
template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t block_size,
uint32_t num_frags_y,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename CacheT>
__device__ __forceinline__ void produce_v_blockwise_c8(
smem_t smem,
uint32_t* smem_offset,
CacheT* cache_v,
const int* block_table_now,
const uint32_t kv_head_idx,
const uint32_t kv_n_stride,
const uint32_t kv_h_stride,
const uint32_t kv_d_stride,
const uint32_t kv_idx_base,
const uint32_t kv_len,
const uint32_t const_v_offset) {
constexpr uint32_t num_vecs_per_blocksize =
block_size / num_elems_per_128b<CacheT>(); // 8
constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t kv_idx =
kv_idx_base +
tx % 4 * num_elems_per_128b<CacheT>();
if constexpr (NUM_WARP_Q == 4) {
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
CacheT* cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset;
#pragma unroll
for (uint32_t i = 0; i < num_frags_y * 2 / num_warps;
++i) { // m (num_frags_y * 16 / (num_warps * 8))
#pragma unroll
for (uint32_t j = 0; j < num_frags_z / 4;
++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
*smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>(
*smem_offset, j);
cache_v_now += 4 * num_elems_per_128b<CacheT>();
}
*smem_offset =
smem.advance_offset_by_row<num_warps * 8, num_vecs_per_blocksize>(
*smem_offset) -
num_frags_z; // num_frags_z / 4 * 4
cache_v_now += num_warps * 8 * kv_d_stride -
num_frags_z * num_elems_per_128b<CacheT>();
}
*smem_offset -= num_frags_y * 16 * num_vecs_per_blocksize;
} else {
#pragma unroll
for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) {
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
CacheT* cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset;
#pragma unroll
for (uint32_t i = 0; i < num_frags_y * 2 / num_warps;
++i) { // m (num_frags_y * 16 / (num_warps * 8))
#pragma unroll
for (uint32_t j = 0; j < 2 * num_frags_z / 4;
++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
*smem_offset =
smem.advance_offset_by_column<4, num_vecs_per_blocksize>(
*smem_offset, j);
cache_v_now += 4 * num_elems_per_128b<CacheT>();
kv_idx += 4 * num_elems_per_128b<CacheT>();
}
kv_idx -= 2 * num_frags_z * num_elems_per_128b<CacheT>();
*smem_offset =
smem.advance_offset_by_row<num_warps * 8, num_vecs_per_blocksize>(
*smem_offset) -
2 * num_frags_z; // num_frags_z / 4 * 4
cache_v_now += num_warps * 8 * kv_d_stride -
2 * num_frags_z * num_elems_per_128b<CacheT>();
}
kv_idx += block_size;
}
*smem_offset -= NUM_WARP_KV / 2 * num_frags_y * 16 * num_vecs_per_blocksize;
}
}
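// Stage the per-token K dequantization scales for this KV tile through shared
// memory into registers: two scales per fragment (rows r and r + 8), matching
// the ldmatrix/MMA row ownership of each thread.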
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale(
T* k_smem_scale,
T* cache_k_reg,
const int* block_table_now,
const T* cache_k_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
if (tid < block_size) {
k_smem_scale[tid] = cache_k_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
} else {
k_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
}
}
}
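// Same as produce_k_dynamic_scale but for V: four scales per fragment,
// laid out to match how V is consumed during the P * V accumulation.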
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale(
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
if (tid < block_size) {
v_smem_scale[tid] = cache_v_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
} else {
v_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
}
}
}
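// Load an 8-bit quantized K tile from the paged cache into shared memory.
// K blocks are stored [block_size, head_dim]; the block id comes from
// block_table_now, with negative ids mapped to 0.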
template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t block_size,
uint32_t num_frags_y,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename CacheT>
__device__ __forceinline__ void produce_k_blockwise_c8(
smem_t smem,
uint32_t* smem_offset,
CacheT* cache_k,
const int* block_table_now,
const uint32_t kv_head_idx,
const uint32_t kv_n_stride,
const uint32_t kv_h_stride,
const uint32_t kv_b_stride,
const uint32_t kv_idx_base,
const uint32_t kv_len,
const uint32_t const_k_offset) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head =
head_dim / num_elems_per_128b<CacheT>(); // 8
constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8;
if constexpr (NUM_WARP_Q == 4) {
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
CacheT* cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset;
#pragma unroll
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps;
++i) { // m num_frags_z * 16 / (num_warps * 4)
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 8;
++j) { // k num_frags_y * 16 / 8 / num_elems_per_128b<CacheT>()
// smem.load_128b_async<fill_mode>(*smem_offset, *gptr, kv_idx <
// kv_len);
smem.load_128b_async<fill_mode>(*smem_offset, cache_k_now, true);
*smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>(
*smem_offset, j);
cache_k_now += 8 * num_elems_per_128b<CacheT>();
}
// kv_idx += num_warps * 4;
*smem_offset =
smem.advance_offset_by_row<num_warps * 4, num_vecs_per_head>(
*smem_offset) -
num_frags_y; // num_frags_y / 4 * 4
cache_k_now += num_warps * 4 * kv_b_stride -
num_frags_y * num_elems_per_128b<CacheT>();
}
*smem_offset -= num_frags_z * 16 * num_vecs_per_head;
} else {
#pragma unroll
for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) {
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
CacheT* cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset;
#pragma unroll
for (uint32_t i = 0; i < 2 * num_frags_z * 4 / num_warps;
++i) { // m num_frags_z * 16 / (num_warps * 4)
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 8;
++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_k_now, true);
*smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>(
*smem_offset, j);
cache_k_now += 8 * num_elems_per_128b<CacheT>();
}
kv_idx += num_warps * 4;
*smem_offset =
smem.advance_offset_by_row<num_warps * 4, num_vecs_per_head>(
*smem_offset) -
num_frags_y; // num_frags_y / 4 * 4
cache_k_now += num_warps * 4 * kv_b_stride -
num_frags_y * num_elems_per_128b<CacheT>();
}
}
*smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head;
}
}
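// Load a 4-bit quantized V tile from the paged cache into shared memory; the
// block_size axis is halved because two 4-bit values share one byte.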
template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t block_size,
uint32_t num_frags_y,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename CacheT>
__device__ __forceinline__ void produce_v_blockwise_c4(
smem_t smem,
uint32_t* smem_offset,
CacheT* cache_v,
const int* block_table_now,
const uint32_t kv_head_idx,
const uint32_t kv_n_stride,
const uint32_t kv_h_stride,
const uint32_t kv_d_stride,
const uint32_t kv_idx_base,
const uint32_t kv_len,
const uint32_t const_v_offset) {
constexpr uint32_t num_vecs_per_blocksize =
block_size / 2 / num_elems_per_128b<CacheT>(); // 2
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q;
uint32_t kv_idx =
kv_idx_base +
tx % 2 * 2 * num_elems_per_128b<CacheT>(); // kv_idx used to check
#pragma unroll
for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV; ++kv_i) {
int block_id = __ldg(&block_table_now[(kv_idx) / block_size]);
if (block_id < 0) block_id = 0;
CacheT* cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset;
#pragma unroll
for (uint32_t i = 0; i < num_frags_y / num_warps; ++i) { // m
#pragma unroll
for (uint32_t j = 0; j < num_frags_z / 4;
++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
*smem_offset = smem.advance_offset_by_column<2, num_vecs_per_blocksize>(
*smem_offset, j);
cache_v_now += 2 * num_elems_per_128b<CacheT>();
kv_idx += 4 * num_elems_per_128b<CacheT>();
}
kv_idx -= num_frags_z * num_elems_per_128b<CacheT>();
*smem_offset =
smem.advance_offset_by_row<num_warps * 16, num_vecs_per_blocksize>(
*smem_offset) -
num_frags_z / 2; // num_frags_z / 4 * 2
cache_v_now += num_warps * 16 * kv_d_stride -
num_frags_z / 2 * num_elems_per_128b<CacheT>();
}
kv_idx += block_size;
}
*smem_offset -= NUM_WARP_KV * num_frags_y * 16 * num_vecs_per_blocksize;
}
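// Load a 4-bit quantized K tile from the paged cache into shared memory; the
// head_dim axis is halved because two 4-bit values share one byte.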
template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t block_size,
uint32_t num_frags_y,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename CacheT>
__device__ __forceinline__ void produce_k_blockwise_c4(
smem_t smem,
uint32_t* smem_offset,
CacheT* cache_k,
const int* block_table_now,
const uint32_t kv_head_idx,
const uint32_t kv_n_stride,
const uint32_t kv_h_stride,
const uint32_t kv_b_stride,
const uint32_t kv_idx_base,
const uint32_t kv_len,
const uint32_t const_k_offset) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head =
head_dim / 2 / num_elems_per_128b<CacheT>(); // 4
constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t kv_idx = kv_idx_base + ty * 8 + tx / 4; // kv_idx used to check
#pragma unroll
for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV; ++kv_i) {
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
CacheT* cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset;
#pragma unroll
for (uint32_t i = 0; i < num_frags_z * 2 / num_warps;
++i) { // m num_frags_z * 16 / (num_warps * 8)
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 8;
++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_k_now, true);
*smem_offset = smem.advance_offset_by_column<4, num_vecs_per_head>(
*smem_offset, j);
cache_k_now += 4 * num_elems_per_128b<CacheT>();
}
kv_idx += num_warps * 8;
*smem_offset =
smem.advance_offset_by_row<num_warps * 8, num_vecs_per_head>(
*smem_offset) -
num_frags_y / 2; // num_frags_y / 8 * 4
cache_k_now += num_warps * 8 * kv_b_stride -
num_frags_y / 2 * num_elems_per_128b<CacheT>();
}
}
*smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head;
}
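// Load a 16-bit K/V tile from the paged cache into shared memory, resolving
// the block id from block_table per row and predicating on row_now < kv_len.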
template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t block_size,
uint32_t num_frags_y,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void block_produce_kv(
smem_t smem,
uint32_t* smem_offset,
T* gptr_base, // [max_block_num, num_heads, block_size, head_dim]
const int* block_table,
const uint32_t kv_head_idx,
const uint32_t kv_n_stride,
const uint32_t kv_h_stride,
const uint32_t kv_b_stride,
const uint32_t kv_idx_base,
const uint32_t kv_len) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
#pragma unroll
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps;
++i) { // m num_frags_z * 16 / (num_warps * 4)
const uint32_t row_now =
kv_idx_base + (i * 4 * num_warps + ty * 4 + tx / 8);
const uint32_t kv_n_idx = row_now / block_size;
const uint32_t kv_bid = row_now % block_size;
T* gptr = gptr_base + __ldg(&block_table[kv_n_idx]) * kv_n_stride +
kv_head_idx * kv_h_stride + kv_bid * kv_b_stride +
tx % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4;
++j) { // k num_frags_y * 16 / 8 / num_elems_per_128b<T>()
smem.load_128b_async<fill_mode>(*smem_offset, gptr, row_now < kv_len);
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
gptr += 8 * num_elems_per_128b<T>();
}
*smem_offset =
smem.advance_offset_by_row<num_warps * 4, num_vecs_per_head>(
*smem_offset) -
2 * num_frags_y; // num_frags_y / 4 * 8
}
*smem_offset -= num_frags_z * 16 * num_vecs_per_head;
} else {
const uint32_t row_id_per_tx = tx / 8;
const uint32_t col_id_per_tx = tx % 8;
#pragma unroll
for (uint32_t i = 0; i < num_frags_z;
++i) { // m num_warps * num_frags_z * 16
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
const uint32_t row_now = kv_idx_base + (i * 16 + j * 4 + row_id_per_tx);
const uint32_t kv_n_idx = row_now / block_size;
const uint32_t kv_bid = row_now % block_size;
T* gptr = gptr_base + __ldg(&block_table[kv_n_idx]) * kv_n_stride +
kv_head_idx * kv_h_stride + kv_bid * kv_b_stride +
col_id_per_tx * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y / 4;
++fy) { // k num_frags_y * 16 / 8 / num_elems_per_128b<T>()
smem.load_128b_async<fill_mode>(*smem_offset, gptr, row_now < kv_len);
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, fy);
gptr += 8 * num_elems_per_128b<T>();
}
*smem_offset =
smem.advance_offset_by_row<4, num_vecs_per_head>(*smem_offset) -
2 * num_frags_y; // num_frags_y / 4 * 8
}
}
*smem_offset -= num_frags_z * 16 * num_vecs_per_head;
}
}
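// Load a K/V tile from a contiguous (non-paged) layout into shared memory,
// predicated on kv_idx < kv_len.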
template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t num_frags_y,
uint32_t num_frags_z,
typename T>
__device__ __forceinline__ void produce_kv(smem_t smem,
uint32_t* smem_offset,
T** gptr,
const uint32_t kv_n_stride,
const uint32_t kv_idx_base,
const uint32_t kv_len) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check
#pragma unroll
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps;
++i) { // m num_frags_z * 16 / (num_warps * 4)
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4;
++j) { // k num_frags_y * 16 / 8 / num_elems_per_128b<T>()
smem.load_128b_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
*gptr += 8 * num_elems_per_128b<T>();
}
kv_idx += num_warps * 4;
*smem_offset = smem.advance_offset_by_row<num_warps * 4, num_vecs_per_head>(
*smem_offset) -
2 * num_frags_y; // num_frags_y / 4 * 8
*gptr +=
num_warps * 4 * kv_n_stride - 2 * num_frags_y * num_elems_per_128b<T>();
}
*smem_offset -= num_frags_z * 16 * num_vecs_per_head;
}
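// S = Q * K^T for the current tile using m16n16k16 MMAs with fp32
// accumulation; the first head_dim fragment (fy == 0) initializes the
// accumulator via MMAMode::kInit.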
template <uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t num_frags_z,
typename T>
__device__ __forceinline__ void compute_qk(smem_t* q_smem,
uint32_t* q_smem_offset_r,
smem_t* k_smem,
uint32_t* k_smem_offset_r,
float (*s_frag)[num_frags_z][8]) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
uint32_t a_frag[num_frags_x][4], b_frag[4];
// compute q*k^T
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) { // k
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m
q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx]);
*q_smem_offset_r = q_smem->advance_offset_by_row<16, num_vecs_per_head>(
*q_smem_offset_r);
}
*q_smem_offset_r =
q_smem->advance_offset_by_column<2>(*q_smem_offset_r, fy) -
num_frags_x * 16 * num_vecs_per_head;
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) { // n
k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag);
*k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head>(
*k_smem_offset_r);
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
if (fy == 0) {
mma_sync_m16n16k16_row_col_f16f16f32<T, MMAMode::kInit>(
s_frag[fx][fz], a_frag[fx], b_frag);
} else {
mma_sync_m16n16k16_row_col_f16f16f32<T>(
s_frag[fx][fz], a_frag[fx], b_frag);
}
}
}
*k_smem_offset_r =
k_smem->advance_offset_by_column<2>(*k_smem_offset_r, fy) -
num_frags_z * 16 * num_vecs_per_head;
}
*q_smem_offset_r -= num_frags_y * 2;
*k_smem_offset_r -= num_frags_y * 2;
}
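// Q * K^T with a 4-bit K cache: each K fragment is dequantized in registers
// (subtract zero-point, multiply by per-channel scale) before the MMA.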
template <uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t num_frags_z,
typename T,
typename CacheT>
__device__ __forceinline__ void compute_qk_c4(smem_t* q_smem,
uint32_t* q_smem_offset_r,
smem_t* k_smem,
uint32_t* k_smem_offset_r,
float (*s_frag)[num_frags_z][8],
T (*cache_k_scale_frag)[4],
T (*cache_k_zp_frag)[4]) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
head_dim / 2 / num_elems_per_128b<CacheT>();
uint32_t a_frag[num_frags_x][4][4], b_frag[4], b_frag_dq[4];
#pragma unroll
for (uint32_t ky = 0; ky < num_frags_y / 4; ++ky) { // k
// load q
#pragma unroll
for (uint32_t fy = 0; fy < 4; ++fy) {
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx][fy]);
*q_smem_offset_r =
q_smem->advance_offset_by_row<16, num_vecs_per_head_q>(
*q_smem_offset_r);
}
*q_smem_offset_r =
q_smem->advance_offset_by_column<2>(*q_smem_offset_r, ky * 4 + fy) -
num_frags_x * 16 * num_vecs_per_head_q;
}
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
// load
k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag);
*k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head_k>(
*k_smem_offset_r);
#pragma unroll
for (uint32_t fy = 0; fy < 4; ++fy) {
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_int4(b_frag_dq_T, b_frag[fy]);
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
const int b_offset = b_i % 4;
b_frag_dq_T[b_i] =
(b_frag_dq_T[b_i] - cache_k_zp_frag[ky * 4 + fy][b_offset]) *
cache_k_scale_frag[ky * 4 + fy][b_offset];
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
if (ky == 0 && fy == 0) {
mma_sync_m16n16k16_row_col_f16f16f32<T, MMAMode::kInit>(
s_frag[fx][fz], a_frag[fx][fy], b_frag_dq);
} else {
mma_sync_m16n16k16_row_col_f16f16f32<T>(
s_frag[fx][fz], a_frag[fx][fy], b_frag_dq);
}
}
}
}
*k_smem_offset_r = k_smem->advance_offset_by_column<2, num_vecs_per_head_k>(
*k_smem_offset_r, ky) -
num_frags_z * 16 * num_vecs_per_head_k;
}
// advance by col
*q_smem_offset_r -= num_frags_y * 2;
*k_smem_offset_r -= num_frags_y / 4 * 2;
}
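// Q * K^T with an 8-bit K cache. K fragments are converted to T in registers
// and scaled per channel (is_scale_channel_wise), per tensor, or per token
// (IsDynamicC8) before the MMA.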
template <uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t num_frags_z,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
uint32_t* q_smem_offset_r,
smem_t* k_smem,
uint32_t* k_smem_offset_r,
const T *cache_k_scale,
float (*s_frag)[num_frags_z][8]) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
head_dim / num_elems_per_128b<CacheT>();
uint32_t a_frag[num_frags_x][2][4], b_frag[4], b_frag_dq[4];
#pragma unroll
for (uint32_t ky = 0; ky < num_frags_y / 2; ++ky) { // k
// load q
#pragma unroll
for (uint32_t fy = 0; fy < 2; ++fy) {
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx][fy]);
*q_smem_offset_r =
q_smem->advance_offset_by_row<16, num_vecs_per_head_q>(
*q_smem_offset_r);
}
*q_smem_offset_r =
q_smem->advance_offset_by_column<2>(*q_smem_offset_r, ky * 2 + fy) -
num_frags_x * 16 * num_vecs_per_head_q;
}
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
// load
k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag);
*k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head_k>(
*k_smem_offset_r);
#pragma unroll
for (uint32_t fy = 0; fy < 2; ++fy) {
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
const int scale_col = (ky * 2 + fy) * 4;
b_frag_dq_T[0] *= cache_k_scale[scale_col];
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_k_scale[scale_col];
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[0];
}
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
if (ky == 0 && fy == 0) {
mma_sync_m16n16k16_row_col_f16f16f32<T, MMAMode::kInit>(
s_frag[fx][fz], a_frag[fx][fy], b_frag_dq);
} else {
mma_sync_m16n16k16_row_col_f16f16f32<T>(
s_frag[fx][fz], a_frag[fx][fy], b_frag_dq);
}
}
}
}
*k_smem_offset_r = k_smem->advance_offset_by_column<2, num_vecs_per_head_k>(
*k_smem_offset_r, ky) -
num_frags_z * 16 * num_vecs_per_head_k;
}
*q_smem_offset_r -= num_frags_y * 2;
*k_smem_offset_r -= num_frags_y / 2 * 2;
}
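// Apply the attention mask to the score fragments: positions outside the
// causal/sequence/chunk bounds, or masked by attn_mask / mask_offset, are set
// to the dtype's sentinel (-5e4 for half, -3.0e30 for bfloat16).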
template <typename T,
bool partition_kv,
bool causal,
uint32_t group_size,
uint32_t num_warps,
uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t num_frags_z,
bool IS_SYSTEM = false>
__device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t qo_idx_base,
const uint32_t kv_idx_base,
const uint32_t qo_len,
const uint32_t kv_len,
const uint32_t chunk_end,
const int32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
const uint32_t tx = threadIdx.x;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
if constexpr (!IS_SYSTEM) {
const uint32_t q_idx = (qo_idx_base + fx * 16 + tx / 4 +
8 * ((reg_id % 4) / 2)) /
group_size,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
8 * (reg_id / 4) + reg_id % 2;
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
} else {
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && attn_mask_len > 0 && q_idx < static_cast<uint32_t>(attn_mask_len)) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
out_of_boundary |= mask;
}
}
if constexpr (std::is_same<T, half>::value) {
s_frag[fx][fz][reg_id] =
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s_frag[fx][fz][reg_id] =
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
}
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
} else {
const uint32_t q_idx = qo_idx_base,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
8 * (reg_id / 4) + reg_id % 2;
const bool out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if constexpr (std::is_same<T, half>::value) {
s_frag[fx][fz][reg_id] =
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s_frag[fx][fz][reg_id] =
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
}
}
}
}
}
}
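// Online-softmax update: compute the new per-row max (reduced across the four
// threads that own a row via shuffles), rescale d and the output fragments by
// exp(m_prev - m), and exponentiate the score fragments in place.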
template <uint32_t num_frags_x, uint32_t num_frags_y, uint32_t num_frags_z>
__device__ __forceinline__ void update_mdo_states(
float (*s_frag)[num_frags_z][8],
float (*o_frag)[num_frags_y][8],
float (*m)[2],
float (*d)[2]) {
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
float m_prev = m[fx][j];
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
float m_local =
max(max(s_frag[fx][fz][j * 2 + 0], s_frag[fx][fz][j * 2 + 1]),
max(s_frag[fx][fz][j * 2 + 4], s_frag[fx][fz][j * 2 + 5]));
m[fx][j] = max(m[fx][j], m_local);
}
m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x2, 32));
m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x1, 32));
float o_scale = __expf(m_prev - m[fx][j]);
d[fx][j] *= o_scale;
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
o_frag[fx][fy][j * 2 + 0] *= o_scale;
o_frag[fx][fy][j * 2 + 1] *= o_scale;
o_frag[fx][fy][j * 2 + 4] *= o_scale;
o_frag[fx][fy][j * 2 + 5] *= o_scale;
}
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
s_frag[fx][fz][j * 2 + 0] =
__expf(s_frag[fx][fz][j * 2 + 0] - m[fx][j]);
s_frag[fx][fz][j * 2 + 1] =
__expf(s_frag[fx][fz][j * 2 + 1] - m[fx][j]);
s_frag[fx][fz][j * 2 + 4] =
__expf(s_frag[fx][fz][j * 2 + 4] - m[fx][j]);
s_frag[fx][fz][j * 2 + 5] =
__expf(s_frag[fx][fz][j * 2 + 5] - m[fx][j]);
}
}
}
}
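// O += P * V with a 4-bit V cache: cast the probabilities to T, accumulate
// their row sums into d, dequantize each V fragment (zero-point and scale),
// then run the MMAs.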
template <uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t num_frags_z,
uint32_t block_size,
typename T,
typename CacheT>
__device__ __forceinline__ void compute_sfm_v_c4(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
float (*s_frag)[num_frags_z][8],
float (*o_frag)[num_frags_y][8],
float (*d)[2],
T (*cache_v_scale_frag)[2],
T (*cache_v_zp_frag)[2]) {
constexpr uint32_t num_vecs_per_blocksize =
block_size / 2 / num_elems_per_128b<CacheT>();
T s_frag_f16[num_frags_x][num_frags_z][8];
uint32_t b_frag[4], b_frag_dq[4];
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
vec_cast<T, float, 8>(s_frag_f16[fx][fz], s_frag[fx][fz]);
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]);
}
}
#pragma unroll
for (uint32_t kz = 0; kz < num_frags_z / 4; ++kz) { // k
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
v_smem->ldmatrix_m8n8x4(*v_smem_offset_r, b_frag);
*v_smem_offset_r =
v_smem->advance_offset_by_row<16, num_vecs_per_blocksize>(
*v_smem_offset_r);
#pragma unroll
for (uint32_t fz = 0; fz < 4; ++fz) {
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_int4(b_frag_dq_T, b_frag[fz]);
// scale zp
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
const int b_offset = b_i / 4;
b_frag_dq_T[b_i] =
(b_frag_dq_T[b_i] - cache_v_zp_frag[fy][b_offset]) *
cache_v_scale_frag[fy][b_offset];
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
mma_sync_m16n16k16_row_col_f16f16f32<T>(
o_frag[fx][fy],
(uint32_t*)(s_frag_f16[fx][kz * 4 + fz]),
b_frag_dq);
}
}
}
*v_smem_offset_r =
v_smem->advance_offset_by_column<2, num_vecs_per_blocksize>(
*v_smem_offset_r, kz) -
num_frags_y * 16 * num_vecs_per_blocksize;
}
*v_smem_offset_r -= num_frags_z / 4 * 2;
}
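// O += P * V with an 8-bit V cache; V fragments are dequantized with a
// per-channel, per-tensor, or per-token scale depending on the template flags.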
template <uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t num_frags_z,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_sfm_v_c8(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
float (*s_frag)[num_frags_z][8],
float (*o_frag)[num_frags_y][8],
float (*d)[2],
const T *cache_v_scale) {
constexpr uint32_t num_vecs_per_blocksize =
block_size / num_elems_per_128b<CacheT>();
T s_frag_f16[num_frags_x][num_frags_z][8];
uint32_t b_frag[4], b_frag_dq[4];
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
vec_cast<T, float, 8>(s_frag_f16[fx][fz], s_frag[fx][fz]);
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]);
}
}
#pragma unroll
for (uint32_t kz = 0; kz < num_frags_z / 2; ++kz) { // k
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
v_smem->ldmatrix_m8n8x4(*v_smem_offset_r, b_frag);
*v_smem_offset_r =
v_smem->advance_offset_by_row<16, num_vecs_per_blocksize>(
*v_smem_offset_r);
#pragma unroll
for (uint32_t fz = 0; fz < 2; ++fz) {
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
mma_sync_m16n16k16_row_col_f16f16f32<T>(
o_frag[fx][fy],
(uint32_t*)(s_frag_f16[fx][kz * 2 + fz]),
b_frag_dq);
}
}
}
*v_smem_offset_r =
v_smem->advance_offset_by_column<2, num_vecs_per_blocksize>(
*v_smem_offset_r, kz) -
num_frags_y * 16 * num_vecs_per_blocksize;
}
*v_smem_offset_r -= num_frags_z / 2 * 2;
}
template <uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t num_frags_z,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
float (*s_frag)[num_frags_z][8],
float (*o_frag)[num_frags_y][8],
float (*d)[2],
T *cache_v_scale) {
constexpr uint32_t num_vecs_per_blocksize =
block_size / num_elems_per_128b<CacheT>();
T s_frag_f16[num_frags_x][num_frags_z][8];
uint32_t b_frag[4], b_frag_dq[4];
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
vec_cast<T, float, 8>(s_frag_f16[fx][fz], s_frag[fx][fz]);
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]);
}
}
#pragma unroll
for (uint32_t kz = 0; kz < num_frags_z / 2; ++kz) { // k
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
v_smem->ldmatrix_m8n8x4(*v_smem_offset_r, b_frag);
*v_smem_offset_r =
v_smem->advance_offset_by_row<16, num_vecs_per_blocksize>(
*v_smem_offset_r);
#pragma unroll
for (uint32_t fz = 0; fz < 2; ++fz) {
// dequant b_frag -> b_frag_dq
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
mma_sync_m16n16k16_row_col_f16f16f32<T>(
o_frag[fx][fy],
(uint32_t*)(s_frag_f16[fx][kz * 2 + fz]),
b_frag_dq);
}
}
}
*v_smem_offset_r -= num_frags_y * 16 * num_vecs_per_blocksize;
}
}
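// O += P * V for a 16-bit V cache: cast the probabilities to T, accumulate
// their row sums into d, then run the MMAs with V fragments loaded transposed
// from shared memory.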
template <uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t num_frags_z,
typename T>
__device__ __forceinline__ void compute_sfm_v(smem_t* v_smem,
uint32_t* v_smem_offset_r,
float (*s_frag)[num_frags_z][8],
float (*o_frag)[num_frags_y][8],
float (*d)[2]) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
T s_frag_f16[num_frags_x][num_frags_z][8];
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
vec_cast<T, float, 8>(s_frag_f16[fx][fz], s_frag[fx][fz]);
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]);
}
}
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z;
++fz) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t b_frag[4];
v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag);
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
mma_sync_m16n16k16_row_col_f16f16f32<T>(
o_frag[fx][fy], (uint32_t*)(s_frag_f16[fx][fz]), b_frag);
}
*v_smem_offset_r =
v_smem->advance_offset_by_column<2>(*v_smem_offset_r, fy);
}
*v_smem_offset_r =
v_smem->advance_offset_by_row<16, num_vecs_per_head>(*v_smem_offset_r) -
2 * num_frags_y;
}
*v_smem_offset_r -= 16 * num_frags_z * num_vecs_per_head;
}
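// Divide the output fragments by the softmax denominator d.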
template <uint32_t num_frags_x, uint32_t num_frags_y>
__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8],
float (*d)[2]) {
float d_rcp[num_frags_x][2];
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
d_rcp[fx][j] = 1.f / d[fx][j];
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
#pragma unroll
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
o_frag[fx][fy][reg_id] =
o_frag[fx][fy][reg_id] * d_rcp[fx][(reg_id % 4) / 2];
}
}
}
}
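// Merge the per-warp partial results (o, m, d) through shared memory into
// warp 0 using the standard log-sum-exp merge rule.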
template <uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t NUM_WARPS,
typename T>
__device__ __forceinline__ void merge_res_multi_warps(
T* o_smem, // [num_threads, num_frags_x, num_frags_y, 8]
T* md_smem, // [num_warps, num_frags_x * 16 * 2]
T (*o_frag)[num_frags_y][8],
T (*m_frag)[2],
T (*d_frag)[2]) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
const uint32_t tidx = ty * 32 + tx;
const uint32_t row_id = tx / 4;
// [num_frags_x * 16, num_frags_y * 16]
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
const int offset =
tidx * num_frags_x * num_frags_y * 8 + fx * num_frags_y * 8 + fy * 8;
*(b128_t*)(&o_smem[offset]) = *(b128_t*)&o_frag[fx][fy];
*(b128_t*)(&o_smem[offset + 4]) = *(b128_t*)&o_frag[fx][fy][4];
}
}
if (tx % 4 == 0) {
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
const int offset = ty * num_frags_x * 16 + fx * 16 + row_id;
md_smem[offset * 2] = m_frag[fx][0];
md_smem[offset * 2 + 1] = d_frag[fx][0];
md_smem[(offset + 8) * 2] = m_frag[fx][1];
md_smem[(offset + 8) * 2 + 1] = d_frag[fx][1];
}
}
__syncthreads();
if (ty == 0) {
#pragma unroll
for (uint32_t warp_id = 0; warp_id < NUM_WARPS; ++warp_id) {
const int tmp_tidx = warp_id * 32 + tx;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
const int offset = warp_id * num_frags_x * 16 + fx * 16 + row_id;
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
const int o_offset = tmp_tidx * num_frags_x * num_frags_y * 8 +
fx * num_frags_y * 8 + fy * 8;
AlignedVector<float, 8> o_now;
Load(&o_smem[o_offset], &o_now);
float m_prev = m_frag[fx][0], d_prev = d_frag[fx][0];
float m_now = md_smem[offset * 2], d_now = md_smem[offset * 2 + 1];
float tmp_m = max(m_prev, m_now);
float scale1 = __expf(m_prev - tmp_m), scale2 = __expf(m_now - tmp_m);
float tmp_d = scale1 * d_prev + scale2 * d_now;
o_frag[fx][fy][0] = scale1 * o_frag[fx][fy][0] + scale2 * o_now[0];
o_frag[fx][fy][1] = scale1 * o_frag[fx][fy][1] + scale2 * o_now[1];
o_frag[fx][fy][4] = scale1 * o_frag[fx][fy][4] + scale2 * o_now[4];
o_frag[fx][fy][5] = scale1 * o_frag[fx][fy][5] + scale2 * o_now[5];
m_frag[fx][0] = tmp_m;
d_frag[fx][0] = tmp_d;
m_prev = m_frag[fx][1], d_prev = d_frag[fx][1];
m_now = md_smem[(offset + 8) * 2],
d_now = md_smem[(offset + 8) * 2 + 1];
tmp_m = max(m_prev, m_now);
scale1 = __expf(m_prev - tmp_m), scale2 = __expf(m_now - tmp_m);
tmp_d = scale1 * d_prev + scale2 * d_now;
o_frag[fx][fy][2] = scale1 * o_frag[fx][fy][2] + scale2 * o_now[2];
o_frag[fx][fy][3] = scale1 * o_frag[fx][fy][3] + scale2 * o_now[3];
o_frag[fx][fy][6] = scale1 * o_frag[fx][fy][6] + scale2 * o_now[6];
o_frag[fx][fy][7] = scale1 * o_frag[fx][fy][7] + scale2 * o_now[7];
m_frag[fx][1] = tmp_m;
d_frag[fx][1] = tmp_d;
}
}
}
}
}
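// Write the output fragments to global memory: warp 0 stages its fragments in
// shared memory, then all warps cooperatively copy out the rows whose query
// index is below qo_upper_bound.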
template <uint32_t group_size,
uint32_t num_frags_x,
uint32_t num_frags_y,
typename T>
__device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
float (*o_frag)[num_frags_y][8],
smem_t* o_smem,
T* o_ptr_base,
uint32_t o_idx_base,
const uint32_t qo_upper_bound,
const uint32_t qo_n_stride,
const uint32_t qo_h_stride) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if (ty == 0) {
// [num_frags_x * 16, num_frags_y * 16]
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
fx * 16 + tx / 4,
fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[tx % 4] =
o_frag_f16[2];
((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[3];
}
}
}
__syncthreads();
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * 4 + tx / 8, tx % 8);
o_idx_base += (tx / 8) / group_size;
o_ptr_base += ((tx / 8) / group_size) * qo_n_stride +
((tx / 8) % group_size) * qo_h_stride;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
const int j = ty;  // each of the 4 warps writes one 4-row slice of this 16-row fragment
const uint32_t o_idx = o_idx_base + (fx * 16 + j * 4) / group_size;
T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride +
((fx * 16 + j * 4) % group_size) * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
if (o_idx < qo_upper_bound) {
// need write
o_smem->store_128b(o_smem_offset_w, o_ptr);
}
o_ptr += 8 * num_elems_per_128b<T>();
o_smem_offset_w =
o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo);
}
o_smem_offset_w =
o_smem->advance_offset_by_row<16, num_vecs_per_head>(o_smem_offset_w) -
2 * num_frags_y;
}
}
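// Output store policy. The unspecialized case is a plain cast that flags an
// unimplemented path; specializations cover int8 (shift/smooth + in_scale,
// clamped to [-127, 127]), fp8 e4m3 (scaled and clamped to the quant bounds),
// and a T -> T passthrough.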
template <typename T, int VEC_SIZE, typename OutT>
struct StoreFunc {
__device__ __forceinline__ void operator()(
const AlignedVector<T, VEC_SIZE>& ori_out_vec,
const AlignedVector<T, VEC_SIZE>& shift_bias_vec,
const AlignedVector<T, VEC_SIZE>& smooth_weight_vec,
AlignedVector<OutT, VEC_SIZE>& out_vec,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int i) {
out_vec[i] = static_cast<OutT>(ori_out_vec[i]);
printf("Fatal! Unimplemented StoreFunc for cascade append attention\n");
}
};
template <typename T, int VEC_SIZE>
struct StoreFunc<T, VEC_SIZE, int8_t> {
__device__ __forceinline__ void operator()(
const AlignedVector<T, VEC_SIZE>& ori_out_vec,
const AlignedVector<T, VEC_SIZE>& shift_bias_vec,
const AlignedVector<T, VEC_SIZE>& smooth_weight_vec,
AlignedVector<int8_t, VEC_SIZE>& out_vec,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int i) {
float quant_value =
127.0f *
static_cast<float>((ori_out_vec[i] + shift_bias_vec[i]) *
smooth_weight_vec[i]) *
in_scale;
quant_value = rintf(quant_value);
quant_value = quant_value > 127.0f ? 127.0f : quant_value;
quant_value = quant_value < -127.0f ? -127.0f : quant_value;
out_vec[i] = static_cast<int8_t>(quant_value);
}
};
template <typename T, int VEC_SIZE>
struct StoreFunc<T, VEC_SIZE, __nv_fp8_e4m3> {
__device__ __forceinline__ void operator()(
const AlignedVector<T, VEC_SIZE>& ori_out_vec,
const AlignedVector<T, VEC_SIZE>& shift_bias_vec,
const AlignedVector<T, VEC_SIZE>& smooth_weight_vec,
AlignedVector<__nv_fp8_e4m3, VEC_SIZE>& out_vec,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int i) {
float quant_value =
quant_max_bound * static_cast<float>(ori_out_vec[i]) * in_scale;
quant_value = quant_value > quant_max_bound ? quant_max_bound : quant_value;
quant_value = quant_value < quant_min_bound ? quant_min_bound : quant_value;
out_vec[i] = static_cast<__nv_fp8_e4m3>(quant_value);
}
};
template <typename T, int VEC_SIZE>
struct StoreFunc<T, VEC_SIZE, T> {
__device__ __forceinline__ void operator()(
const AlignedVector<T, VEC_SIZE>& ori_out_vec,
const AlignedVector<T, VEC_SIZE>& shift_bias_vec,
const AlignedVector<T, VEC_SIZE>& smooth_weight_vec,
AlignedVector<T, VEC_SIZE>& out_vec,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int i) {
out_vec[i] = ori_out_vec[i];
}
};
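// Multi-warp output write with optional shift/smooth and quantization via
// StoreFunc. When partition_kv is set, the tile is written out as raw 128-bit
// copies instead, to be merged across KV chunks later.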
template <uint32_t group_size,
uint32_t num_frags_x,
uint32_t num_frags_y,
bool partition_kv,
typename T,
typename OutT>
__device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
float (*o_frag)[num_frags_y][8],
smem_t* o_smem,
OutT* o_ptr_base,
const T* shift_bias,
const T* smooth_weight,
uint32_t o_idx_base,
const uint32_t q_head_idx_base,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const uint32_t qo_upper_bound,
const uint32_t qo_n_stride,
const uint32_t qo_h_stride) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr int VEC_SIZE = 16 / sizeof(T);
AlignedVector<T, VEC_SIZE> ori_out_vec;
AlignedVector<T, VEC_SIZE> shift_bias_vec;
AlignedVector<T, VEC_SIZE> smooth_weight_vec;
AlignedVector<OutT, VEC_SIZE> out_vec;
// [num_warps * num_frags_x * 16, num_frags_y * 16]
if (ty == 0) {
// [num_frags_x * 16, num_frags_y * 16]
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
fx * 16 + tx / 4,
fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[tx % 4] =
o_frag_f16[2];
((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[3];
}
}
}
__syncthreads();
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * 4 + tx / 8, tx % 8);
const uint32_t tx_offset = tx / 8;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
const uint32_t base_offset = o_idx_base + fx * 16 + tx_offset;
const int j = ty;  // each of the 4 warps writes one 4-row slice of this 16-row fragment
const uint32_t offset_now = base_offset + j * 4;
const uint32_t n_offset = offset_now / group_size;
const uint32_t h_offset = offset_now % group_size;
OutT* o_ptr = o_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim +
tx % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
if (n_offset < qo_upper_bound) {
if constexpr (!partition_kv) {
Load<T, VEC_SIZE>(
reinterpret_cast<T*>(o_smem->base + o_smem_offset_w),
&ori_out_vec);
if (in_scale > 0.0) {
if (shift_bias) {
Load<T, VEC_SIZE>(shift_bias + shift_smooth_offset,
&shift_bias_vec);
Load<T, VEC_SIZE>(smooth_weight + shift_smooth_offset,
&smooth_weight_vec);
}
}
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
StoreFunc<T, VEC_SIZE, OutT>()(ori_out_vec,
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
}
Store<OutT, VEC_SIZE>(out_vec, o_ptr);
} else {
o_smem->store_128b(o_smem_offset_w, o_ptr);
}
}
o_ptr += 8 * num_elems_per_128b<T>();
shift_smooth_offset += 8 * num_elems_per_128b<T>();
o_smem_offset_w =
o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo);
}
o_smem_offset_w =
o_smem->advance_offset_by_row<16, num_vecs_per_head>(o_smem_offset_w) -
2 * num_frags_y;
}
}
template <uint32_t group_size,
uint32_t num_frags_x,
uint32_t num_frags_y,
bool partition_kv,
typename T,
typename OutT>
__device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
float (*o_frag)[num_frags_y][8],
smem_t* o_smem,
OutT* o_ptr_base,
const T* shift_bias,
const T* smooth_weight,
uint32_t o_idx_base,
const uint32_t q_head_idx_base,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const uint32_t qo_upper_bound,
const uint32_t qo_n_stride,
const uint32_t qo_h_stride) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr int VEC_SIZE = 8;
AlignedVector<T, VEC_SIZE> ori_out_vec;
AlignedVector<T, VEC_SIZE> shift_bias_vec;
AlignedVector<T, VEC_SIZE> smooth_weight_vec;
AlignedVector<OutT, VEC_SIZE> out_vec;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4,
fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[tx % 4] =
o_frag_f16[2];
((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[3];
}
}
__syncthreads();
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * num_frags_x * 16 + tx / 8,
tx % 8);
const uint32_t tx_offset = tx / 8;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
const uint32_t base_offset = o_idx_base + fx * 16 + tx_offset;
#pragma unroll
for (uint32_t j = 0; j < 4; ++j) { // 4 * 4 = 16
const uint32_t offset_now = base_offset + j * 4;
const uint32_t n_offset = offset_now / group_size;
const uint32_t h_offset = offset_now % group_size;
OutT* o_ptr =
o_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim +
tx % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
if (n_offset < qo_upper_bound) {
          if constexpr (!partition_kv) {
Load<T, VEC_SIZE>(
reinterpret_cast<T*>(o_smem->base + o_smem_offset_w),
&ori_out_vec);
if (in_scale > 0.0) {
if (shift_bias) {
Load<T, VEC_SIZE>(shift_bias + shift_smooth_offset,
&shift_bias_vec);
Load<T, VEC_SIZE>(smooth_weight + shift_smooth_offset,
&smooth_weight_vec);
}
}
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
StoreFunc<T, VEC_SIZE, OutT>()(ori_out_vec,
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
}
Store<OutT, VEC_SIZE>(out_vec, o_ptr);
} else {
o_smem->store_128b(o_smem_offset_w, o_ptr);
}
}
o_ptr += 8 * num_elems_per_128b<T>();
shift_smooth_offset += 8 * num_elems_per_128b<T>();
o_smem_offset_w =
o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo);
}
o_smem_offset_w =
o_smem->advance_offset_by_row<4, num_vecs_per_head>(o_smem_offset_w) -
2 * num_frags_y;
}
}
}
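
// Plain output writer without the shift/smooth/quant path: fragments are
// staged through shared memory exactly as above and copied to global memory
// as 128-bit stores, skipping rows at or beyond qo_upper_bound.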
template <uint32_t group_size,
uint32_t num_frags_x,
uint32_t num_frags_y,
typename T>
__device__ __forceinline__ void write_o_reg_gmem(
float (*o_frag)[num_frags_y][8],
smem_t* o_smem,
T* o_ptr_base,
uint32_t o_idx_base,
const uint32_t qo_upper_bound,
const uint32_t qo_n_stride,
const uint32_t qo_h_stride) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4,
fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[tx % 4] =
o_frag_f16[2];
((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[3];
}
}
__syncthreads();
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * num_frags_x * 16 + tx / 8,
tx % 8);
o_idx_base += (tx / 8) / group_size;
o_ptr_base += ((tx / 8) / group_size) * qo_n_stride +
((tx / 8) % group_size) * qo_h_stride;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 4; ++j) { // 4 * 4 = 16
const uint32_t o_idx = o_idx_base + (fx * 16 + j * 4) / group_size;
T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride +
((fx * 16 + j * 4) % group_size) * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
if (o_idx < qo_upper_bound) {
o_smem->store_128b(o_smem_offset_w, o_ptr);
}
o_ptr += 8 * num_elems_per_128b<T>();
o_smem_offset_w =
o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo);
}
o_smem_offset_w =
o_smem->advance_offset_by_row<4, num_vecs_per_head>(o_smem_offset_w) -
2 * num_frags_y;
}
}
}
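
// Builds the flattened (batch_id, tile_id) work list for the attention
// kernels: batch bid contributes div_up(seq_lens_q[bid] * GROUP_SIZE,
// num_rows_per_block) tiles, and the total tile count is written to
// num_blocks_x. Only thread 0 performs the scan, so a single small block is
// sufficient.
//
// Hedged launch sketch (illustrative only; the launch shape and stream name
// are assumptions, not taken from this file):
//   split_q_block<GROUP_SIZE><<<1, 32, 0, stream>>>(
//       seq_lens_q, batch_ids, tile_ids_per_batch, num_blocks_x, bsz,
//       num_rows_per_block);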
template <uint32_t GROUP_SIZE>
__global__ void split_q_block(const int* __restrict__ seq_lens_q,
int* __restrict__ batch_ids,
int* __restrict__ tile_ids_per_batch,
int* __restrict__ num_blocks_x,
const uint32_t bsz,
const uint32_t num_rows_per_block) {
if (threadIdx.x == 0) {
int gridx = 0;
int index = 0;
for (uint32_t bid = 0; bid < bsz; bid++) {
const int seq_len = seq_lens_q[bid];
const int loop_times = div_up(seq_len * GROUP_SIZE, num_rows_per_block);
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
}
gridx += loop_times;
}
*num_blocks_x = gridx;
}
}
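
// Merges the per-chunk partial results produced under kv partitioning. For
// each (token, head) pair the chunks are folded with the usual online-softmax
// update (matching the loop body below):
//   m_new = max(m, m_i)
//   d_new = d * exp(m - m_new) + d_i * exp(m_i - m_new)
//   o_new = o * exp(m - m_new) + o_i * exp(m_i - m_new)
// and o is divided by d at the end.
//
// Hedged launch sketch (shapes inferred from the indexing below, not taken
// from this file): one block per query token, lanes over the head dimension
// and heads:
//   dim3 grid(token_num);
//   dim3 block(head_dim / vec_size, num_heads);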
template <typename T, int vec_size>
__global__ void merge_multi_chunks_kernel(
const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads,
// head_dim]
const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int* __restrict__ seq_lens_q,
const int* __restrict__ seq_lens_kv,
const int* __restrict__ batch_id_per_token,
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
T* __restrict__ out,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_seq_len,
const int num_chunks,
const int num_heads,
const int chunk_size,
const int head_dim) {
const int vid = threadIdx.x, hid = threadIdx.y;
const int qid = blockIdx.x;
const uint32_t bid = batch_id_per_token[qid];
if (seq_lens_q[bid] <= 0 || seq_lens_kv[bid] <= 0) {
return;
}
const int seq_len_kv = seq_lens_kv[bid];
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
using LoadT = AlignedVector<T, vec_size>;
LoadT load_vec;
LoadT res_vec;
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&res_vec) + i) = make_half2(0, 0);
}
} else {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
}
}
float m;
float d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = -5e4f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
m = -3.0e+30f;
}
#pragma unroll 2
for (int i = 0; i < num_chunks_this_seq; ++i) {
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
float m_prev = m;
float d_prev = d;
const float m_now = multi_m[offset];
const float d_now = multi_d[offset];
m = max(m_prev, m_now);
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
vid * vec_size;
Load<T, vec_size>(&multi_out[offset], &load_vec);
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
const T scale1_T = static_cast<T>(scale1),
scale2_T = static_cast<T>(scale2);
d = d * scale1 + d_now * scale2;
for (int j = 0; j < vec_size; j++) {
res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T;
}
}
#pragma unroll
for (int j = 0; j < vec_size; j++) {
res_vec[j] /= d;
}
Store<T, vec_size>(res_vec,
&out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
}
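
// Cross-warp reduction of the output fragments within one thread block: the
// four warps publish their (m, d) pairs through shared memory, a combined
// (m, d) is computed with the same online-softmax update as above, and each
// fragment is rescaled by exp(m_warp - m_new) before being summed. The md
// buffer is reused fragment by fragment, hence the extra __syncthreads()
// inside the inner loop.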
template <uint32_t num_frags_x, uint32_t num_frags_y, typename T>
__device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8],
float* md_smem,
float (*m)[2],
float (*d)[2],
const uint32_t wid,
const uint32_t tid) {
float2* smem_md = reinterpret_cast<float2*>(md_smem);
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
smem_md[((wid * num_frags_x + fx) * 2 + j) * 32 + tid] =
make_float2(m[fx][j], d[fx][j]);
}
}
__syncthreads();
float o_scale[4][num_frags_x][2];
// deal md/scale
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
float m_new;
float d_new = 1.f;
if constexpr (std::is_same<T, half>::value) {
m_new = -5e4f;
} else {
m_new = -3.0e+30f;
}
#pragma unroll
for (uint32_t i = 0; i < 4; ++i) {
float2 md = smem_md[((i * num_frags_x + fx) * 2 + j) * 32 + tid];
float m_prev = m_new, d_prev = d_new;
m_new = max(m_new, md.x);
d_new = d_prev * __expf(m_prev - m_new) + md.y * __expf(md.x - m_new);
}
#pragma unroll
for (uint32_t i = 0; i < 4; ++i) {
float2 md = smem_md[((i * num_frags_x + fx) * 2 + j) * 32 + tid];
o_scale[i][fx][j] = __expf(md.x - m_new);
}
m[fx][j] = m_new;
d[fx][j] = d_new;
}
}
__syncthreads();
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
AlignedVector<float, 8> o_new;
#pragma unroll
for (uint32_t o_id = 0; o_id < 8; ++o_id) {
o_new[o_id] = 0.f;
}
*(reinterpret_cast<float4*>(md_smem + (wid * 32 + tid) * 8)) =
*(reinterpret_cast<float4*>(&o_frag[fx][fy][0]));
*(reinterpret_cast<float4*>(md_smem + (wid * 32 + tid) * 8 + 4)) =
*(reinterpret_cast<float4*>(&o_frag[fx][fy][4]));
__syncthreads();
#pragma unroll
for (uint32_t i = 0; i < 4; ++i) {
AlignedVector<float, 8> oi;
Load<float, 8>(md_smem + (i * 32 + tid) * 8, &oi);
#pragma unroll
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
o_new[reg_id] += oi[reg_id] * o_scale[i][fx][(reg_id % 4) / 2];
}
}
*(reinterpret_cast<float4*>(&o_frag[fx][fy][0])) =
*(reinterpret_cast<float4*>(&o_new[0]));
*(reinterpret_cast<float4*>(&o_frag[fx][fy][4])) =
*(reinterpret_cast<float4*>(&o_new[4]));
__syncthreads();
}
}
}
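
// Variant of merge_block_res that stages every warp's fragments in shared
// memory up front (num_frags_x * num_frags_y * 4 warps * 32 lanes * 8 floats)
// and keeps the (m, d) pairs behind that buffer, so a single __syncthreads()
// suffices before the weighted accumulation.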
template <uint32_t num_frags_x, uint32_t num_frags_y, typename T>
__device__ __forceinline__ void merge_block_res_v2(
float (*o_frag)[num_frags_y][8],
float* md_smem,
float (*m)[2],
float (*d)[2],
const uint32_t wid,
const uint32_t tid) {
float2* smem_md = reinterpret_cast<float2*>(
md_smem + num_frags_x * num_frags_y * 1024); // 4 * 32 * 8
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
smem_md[((wid * num_frags_x + fx) * 2 + j) * 32 + tid] =
make_float2(m[fx][j], d[fx][j]);
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
*(reinterpret_cast<float4*>(
md_smem + (((wid * num_frags_x + fx) * num_frags_y + fy) * 32 + tid) *
8)) = *(reinterpret_cast<float4*>(&o_frag[fx][fy][0]));
*(reinterpret_cast<float4*>(
md_smem +
(((wid * num_frags_x + fx) * num_frags_y + fy) * 32 + tid) * 8 + 4)) =
*(reinterpret_cast<float4*>(&o_frag[fx][fy][4]));
}
}
__syncthreads();
float o_scale[4][num_frags_x][2];
// deal md/scale
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
float m_new;
float d_new = 1.f;
if constexpr (std::is_same<T, half>::value) {
m_new = -5e4f;
} else {
m_new = -3.0e+30f;
}
#pragma unroll
for (uint32_t i = 0; i < 4; ++i) {
float2 md = smem_md[((i * num_frags_x + fx) * 2 + j) * 32 + tid];
float m_prev = m_new, d_prev = d_new;
m_new = max(m_new, md.x);
d_new = d_prev * __expf(m_prev - m_new) + md.y * __expf(md.x - m_new);
}
#pragma unroll
for (uint32_t i = 0; i < 4; ++i) {
float2 md = smem_md[((i * num_frags_x + fx) * 2 + j) * 32 + tid];
o_scale[i][fx][j] = __expf(md.x - m_new);
}
m[fx][j] = m_new;
d[fx][j] = d_new;
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
// num_warps * 32 * 8 each time
AlignedVector<float, 8> o_new;
#pragma unroll
for (uint32_t o_id = 0; o_id < 4; ++o_id) {
*(reinterpret_cast<float2*>(&o_new[o_id * 2])) = make_float2(0.f, 0.f);
}
#pragma unroll
for (uint32_t i = 0; i < 4; ++i) {
AlignedVector<float, 8> oi;
Load<float, 8>(
md_smem +
(((i * num_frags_x + fx) * num_frags_y + fy) * 32 + tid) * 8,
&oi);
#pragma unroll
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
o_new[reg_id] += oi[reg_id] * o_scale[i][fx][(reg_id % 4) / 2];
}
}
*(reinterpret_cast<float4*>(&o_frag[fx][fy][0])) =
*(reinterpret_cast<float4*>(&o_new[0]));
*(reinterpret_cast<float4*>(&o_frag[fx][fy][4])) =
*(reinterpret_cast<float4*>(&o_new[4]));
}
}
}
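
// Decode-phase chunk merge: one block per (batch, head). Each threadIdx.y row
// accumulates a strided subset of the chunks with the online-softmax update,
// parks its (o, m, d) in shared memory, and row 0 folds the bdy partial states
// together via prefill_softmax_state_t before the optional shift/smooth/quant
// store. Requests still in the encoder phase (seq_lens_encoder > 0) or with a
// single chunk are skipped.
//
// Hedged launch sketch (shapes inferred from the indexing below, not taken
// from this file):
//   dim3 grid(bsz, num_heads);
//   dim3 block(HEAD_DIM / vec_size, bdy);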
template <typename T,
int vec_size,
uint32_t bdy,
uint32_t HEAD_DIM,
typename OutT = T,
bool ENABLE_PREFILL = true>
__global__ void merge_multi_chunks_decoder_kernel(
const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads,
// head_dim]
const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
OutT *__restrict__ out,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_seq_len,
const int num_chunks,
const int num_heads,
const int chunk_size,
const int head_dim) {
const int vid = threadIdx.x, ty = threadIdx.y;
const int bid = blockIdx.x, hid = blockIdx.y;
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
const int start_token_idx = cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) return;
int seq_len_kv = seq_lens_kv[bid];
if (ENABLE_PREFILL) {
seq_len_kv += seq_len_q;
if (seq_len_kv == 0) return;
} else {
if (seq_len_kv == 0) return;
seq_len_kv += seq_len_q;
}
const int seq_len_enc = seq_lens_encoder[bid];
if (seq_len_enc > 0) {
return;
}
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
return;
}
using LoadT = AlignedVector<T, vec_size>;
LoadT load_vec;
LoadT res_vec;
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2 *)(&res_vec) + i) = make_half2(0, 0);
}
} else {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0);
}
}
float m;
float d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = -5e4f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
m = -3.0e+30f;
}
#pragma unroll 2
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
uint32_t offset = (bid * num_chunks + i) * num_heads + hid;
float m_prev = m;
float d_prev = d;
const float m_now = multi_m[offset];
const float d_now = multi_d[offset];
m = max(m_prev, m_now);
offset = (bid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
vid * vec_size;
Load<T, vec_size>(&multi_out[offset], &load_vec);
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
const T scale1_T = static_cast<T>(scale1),
scale2_T = static_cast<T>(scale2);
d = d * scale1 + d_now * scale2;
#pragma unroll
for (int j = 0; j < vec_size; j++) {
res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T;
}
}
// store ty res
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
md_smem[2 * ty] = m;
md_smem[2 * ty + 1] = d;
__syncthreads();
if (ty == 0) {
// merge bdy
prefill_softmax_state_t<vec_size, T> st;
st.init();
#pragma unroll
for (int i = 0; i < bdy; i++) {
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
st.merge(load_vec, m_tmp, d_tmp);
}
st.normalize();
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
AlignedVector<T, vec_size> shift_bias_vec;
AlignedVector<T, vec_size> smooth_weight_vec;
AlignedVector<OutT, vec_size> out_vec;
if (shift_bias) {
Load<T, vec_size>(shift_bias + shift_smooth_offset, &shift_bias_vec);
Load<T, vec_size>(smooth_weight + shift_smooth_offset,
&smooth_weight_vec);
}
#pragma unroll
for (int i = 0; i < vec_size; ++i) {
      StoreFunc<T, vec_size, OutT>()(st.o,
                                     shift_bias_vec,
                                     smooth_weight_vec,
                                     out_vec,
                                     quant_max_bound,
                                     quant_min_bound,
                                     in_scale,
                                     i);
}
Store<OutT, vec_size>(
out_vec,
&out[(start_token_idx * num_heads + hid) * head_dim + vid * vec_size]);
}
}
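
// Grid-stride variant of the chunk merge: blockIdx.x strides over all query
// tokens, and the partial buffers are indexed either by token (ENABLE_PREFILL)
// or by (batch, draft-token slot) bounded by speculate_max_draft_token_num for
// speculative decoding. The per-block fold of the bdy partial states mirrors
// the decoder kernel above.
//
// Hedged launch sketch (shapes inferred from the indexing below; grid_dim_x is
// a hypothetical cap, since the qid loop strides over token_num):
//   dim3 grid(grid_dim_x, num_heads);
//   dim3 block(HEAD_DIM / vec_size, bdy);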
template <typename T,
int vec_size,
uint32_t bdy,
uint32_t HEAD_DIM,
typename OutT = T,
bool ENABLE_PREFILL = true>
__global__ void merge_multi_chunks_v2_kernel(
const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads,
// head_dim]
const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ batch_id_per_token,
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
OutT *__restrict__ out,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_seq_len,
const int num_chunks,
const int num_heads,
const int chunk_size,
const int head_dim,
const int token_num,
const int speculate_max_draft_token_num = 5) {
const int vid = threadIdx.x, ty = threadIdx.y;
const int hid = blockIdx.y;
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) continue;
int seq_len_kv = seq_lens_kv[bid];
if (ENABLE_PREFILL) {
seq_len_kv += seq_len_q;
if (seq_len_kv == 0) continue;
const int seq_len_enc = seq_lens_encoder[bid];
if (seq_len_enc <= 0) {
continue;
}
} else {
if (seq_len_kv == 0) continue;
seq_len_kv += seq_len_q;
}
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
continue;
}
using LoadT = AlignedVector<T, vec_size>;
LoadT load_vec;
LoadT res_vec;
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2 *)(&res_vec) + i) = make_half2(0, 0);
}
} else {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0);
}
}
float m;
float d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = -5e4f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
m = -3.0e+30f;
}
#pragma unroll 2
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
uint32_t offset;
if (ENABLE_PREFILL) {
offset = (qid * num_chunks + i) * num_heads + hid;
} else {
offset =
((bid * speculate_max_draft_token_num + local_seq_id) * num_chunks +
i) *
num_heads +
hid;
}
float m_prev = m;
float d_prev = d;
const float m_now = multi_m[offset];
const float d_now = multi_d[offset];
m = max(m_prev, m_now);
if (ENABLE_PREFILL) {
offset =
(qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
vid * vec_size;
} else {
offset = ((bid * speculate_max_draft_token_num + local_seq_id) *
num_chunks * num_heads +
i * num_heads + hid) *
head_dim +
vid * vec_size;
}
Load<T, vec_size>(&multi_out[offset], &load_vec);
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
const T scale1_T = static_cast<T>(scale1),
scale2_T = static_cast<T>(scale2);
d = d * scale1 + d_now * scale2;
#pragma unroll
for (int j = 0; j < vec_size; j++) {
res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T;
}
}
// store ty res
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
md_smem[2 * ty] = m;
md_smem[2 * ty + 1] = d;
__syncthreads();
if (ty == 0) {
// merge bdy
prefill_softmax_state_t<vec_size, T> st;
st.init();
#pragma unroll
for (int i = 0; i < bdy; i++) {
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
st.merge(load_vec, m_tmp, d_tmp);
}
st.normalize();
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
AlignedVector<T, vec_size> shift_bias_vec;
AlignedVector<T, vec_size> smooth_weight_vec;
AlignedVector<OutT, vec_size> out_vec;
if (shift_bias) {
Load<T, vec_size>(shift_bias + shift_smooth_offset, &shift_bias_vec);
Load<T, vec_size>(smooth_weight + shift_smooth_offset,
&smooth_weight_vec);
}
#pragma unroll
for (int i = 0; i < vec_size; ++i) {
        StoreFunc<T, vec_size, OutT>()(st.o,
                                       shift_bias_vec,
                                       smooth_weight_vec,
                                       out_vec,
                                       quant_max_bound,
                                       quant_min_bound,
                                       in_scale,
                                       i);
}
Store<OutT, vec_size>(
out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
}
__syncthreads();
}
}