mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Sync] Update to latest code (#2679)
* [Sync] Update to latest code
* Add new code files
* Add new code files
* update code
* Try to fix build.sh
* Try to fix build.sh
* Update code
* Update requirements.txt
* Update code
---------
Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
This commit is contained in:
236
custom_ops/gpu_ops/append_attn/decode_attention_func.cuh
Normal file
@@ -0,0 +1,236 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "multi_head_latent_attention_kernel.h"

template <size_t vec_size, typename T>
struct softmax_state_t {
  AlignedVector<T, vec_size> o;
  T m;
  T d;

  __device__ __forceinline__ void init() {
    if constexpr (std::is_same<T, half>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((half2*)(&o) + i) = make_half2(0, 0);
      }
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
      }
    }
    d = 1.f;
    if constexpr (std::is_same<T, half>::value) {
      m = __float2half(-5e4f);
    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
      m = __float2bfloat16(-3.38953e38f);
    }
  }

  __device__ __forceinline__ softmax_state_t() {
    init();
  }

  __device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
                                        T other_m,
                                        T other_d) {
    // using kType = typename cascade_attn_nv_type2_traits<T>::type;
    T m_prev = m, d_prev = d;
    m = m_prev > other_m ? m_prev : other_m;
    T scale1 = hexp(m_prev - m), scale2 = hexp(other_m - m);

    d = d_prev * scale1 + other_d * scale2;

#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      o[i] = o[i] * scale1 + other_o[i] * scale2;
    }
  }

  __device__ __forceinline__ void normalize() {

#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      o[i] /= d;
    }
  }

};
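// ---------------------------------------------------------------------------
// Note (editorial sketch, not part of the original commit): softmax_state_t
// keeps the running online-softmax statistics for one head: the rescaled
// accumulator o, the running maximum m, and the running denominator d.
// Merging two partial states (m1, d1, o1) and (m2, d2, o2) computed over
// disjoint key ranges follows the usual identities:
//   m_new = max(m1, m2)
//   d_new = d1 * exp(m1 - m_new) + d2 * exp(m2 - m_new)
//   o_new = o1 * exp(m1 - m_new) + o2 * exp(m2 - m_new)
// so that o_new / d_new equals the softmax-weighted sum over the union of the
// two ranges. merge() implements exactly this update; normalize() applies the
// final division by d.
// ---------------------------------------------------------------------------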

template <size_t vec_size, typename T, uint32_t num_tiles = 0>
struct softmax_state_ts {
  uint32_t num_tiles_ = num_tiles;
  AlignedVector<T, vec_size> o[num_tiles];
  float m;
  float d;

  __device__ __forceinline__ void init() {
#pragma unroll
    for (uint32_t tile_id = 0; tile_id < num_tiles_; ++tile_id) {
      if constexpr (std::is_same<T, half>::value) {
#pragma unroll
        for (int i = 0; i < vec_size / 2; ++i) {
          *((half2*)(&o[tile_id]) + i) = make_half2(0, 0);
        }
      } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
        for (int i = 0; i < vec_size / 2; ++i) {
          *((nv_bfloat162*)(&o[tile_id]) + i) = make_bfloat162(0, 0);
        }
      }
    }
    d = 1.f;
    if constexpr (std::is_same<T, half>::value) {
      m = -5e4f;
    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
      m = -3.38953e38f;
    }
  }

  __device__ __forceinline__ softmax_state_ts() {
    init();
  }

  __device__ __forceinline__ void normalize(const uint32_t tile_id) {

#pragma unroll
    for (size_t i = 0; i < vec_size; i++) {
      o[tile_id][i] /= d;
    }
  }

};

template <SharedMemFillMode fill_mode, uint32_t HEAD_DIM_QK, uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t BLOCK_SIZE, uint32_t CACHE_VEC_SIZE, typename CacheT>
__device__ __forceinline__ void produce_kv(CacheT *smem,
                                           CacheT *kv_base_gptr,
                                           const int * block_table_smem,
                                           const uint32_t seq_offset_gmem,
                                           const uint32_t seq_offset_smem,
                                           const uint32_t kv_head_idx,
                                           const uint32_t kv_num_heads,
                                           const uint32_t tidx,
                                           const uint32_t chunk_start,
                                           const uint32_t chunk_end) {
  int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]);
  if (block_id < 0) {
    block_id = 0;
  }
  const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE;
  // 8/16 T/int8 each time
  const uint32_t k_offset_base = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK;
  const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK;
  for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
    pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>(
        smem + smem_offset_base + vid * CACHE_VEC_SIZE,
        kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE,
        seq_offset_gmem < chunk_end
    );
  }
}
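// ---------------------------------------------------------------------------
// Note (editorial sketch, not part of the original commit): produce_kv
// resolves the paged KV-cache address for one token. The logical position
// seq_offset_gmem is split into a page lookup and an in-page offset:
//   block_id     = block_table[seq_offset_gmem / BLOCK_SIZE]   (physical page)
//   block_offset = seq_offset_gmem % BLOCK_SIZE                (row in page)
//   src          = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE
//                   + block_offset) * HEAD_DIM_QK
// Threads then copy HEAD_DIM_QK elements in CACHE_VEC_SIZE-wide 128-bit
// pieces into shared memory; the predicate seq_offset_gmem < chunk_end masks
// loads that would fall beyond the current chunk.
// ---------------------------------------------------------------------------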

template <uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t bdy, uint32_t HEAD_DIM, uint32_t DEAL_EACH_TIME, uint32_t num_tile_v, typename T, typename CacheT>
__device__ __forceinline__ void compute_qk(const T* cu_q_smem,
                                           const CacheT* k_smem,
                                           const uint32_t kv_idx_base,
                                           const uint32_t stage_idx,
                                           const uint32_t iter_base,
                                           const uint32_t iter_bound,
                                           const uint32_t tidx,
                                           const uint32_t gid,
                                           const float scale,
                                           float *s,
                                           softmax_state_ts<vec_size, T, num_tile_v>& st) {
  const CacheT* smem;
  AlignedVector<T, vec_size> q_vec;
  AlignedVector<T, vec_size> k_vec;
  float m_prev = st.m;
  // smem = base_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * HEAD_DIM;
  smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM;
#pragma unroll
  for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
    if (iter_base + j < iter_bound) {
      if constexpr (std::is_same<T, half>::value) {
        s[j] = 0.f;
      } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
        s[j] = 0.f;
      }
#pragma unroll
      for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
        Load<T, vec_size>(cu_q_smem + vid * vec_size, &q_vec);
        Load<CacheT, vec_size>(smem + j * HEAD_DIM + vid * vec_size, &k_vec);
        for (uint32_t i = 0; i < vec_size; ++i) {
          s[j] += static_cast<float>(q_vec[i] * k_vec[i]);
        }
      }
#pragma unroll
      for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
        s[j] += __shfl_xor_sync(-1, s[j], offset, 32);
      }
      __syncthreads();
    } else {
      if constexpr (std::is_same<T, half>::value) {
        s[j] = -5e4f;
      } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
        s[j] = -3.38953e38f;
      }
    }
    st.m = st.m > s[j] ? st.m : s[j];
  }

  // T o_scale = hexp(m_prev - st.m);
  float o_scale = __expf(m_prev - st.m);
  st.d *= o_scale;

#pragma unroll
  for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
    // s[j] = hexp(s[j] - st.m);
    s[j] = __expf(s[j] - st.m);
    st.d += s[j];
  }
#pragma unroll
  for (uint32_t tile_id = 0; tile_id < num_tile_v; ++tile_id) {
    for (uint32_t i = 0; i < vec_size; ++i) {
      st.o[tile_id][i] *= o_scale;
    }
  }
}
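// ---------------------------------------------------------------------------
// Note (editorial sketch, not part of the original commit): in compute_qk each
// thread of a bdx-wide slice owns NUM_VEC_PER_HEAD / bdx vectors of the query
// and key, accumulates its partial dot product into s[j], and the per-thread
// partials are combined with a butterfly __shfl_xor_sync reduction so that
// every lane ends up holding the full q.k score. The tail of the function is
// the online-softmax bookkeeping: the previous accumulator and denominator
// are rescaled by exp(m_prev - m_new) before the new exp(s[j] - m_new) terms
// are added, matching the merge identities documented above for
// softmax_state_t.
// ---------------------------------------------------------------------------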

template<uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t DEAL_EACH_TIME, uint32_t HEAD_DIM_QK, uint32_t num_tile, typename T, typename CacheT>
__device__ __forceinline__ void compute_sv(const float *s,
                                           const CacheT *base_v_smem,
                                           const uint32_t stage_idx,
                                           const uint32_t iter_base,
                                           const uint32_t iter_bound,
                                           const uint32_t tidx,
                                           softmax_state_ts<vec_size, T, num_tile>& st) {
  const CacheT* v_smem;
  AlignedVector<T, vec_size> v_vec;
#pragma unroll
  for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) {
    v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + j * HEAD_DIM_QK;
    for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
      Load<T, vec_size>(v_smem + vid * vec_size, &v_vec);
      uint32_t tile_id = vid / bdx;
#pragma unroll
      for (int reg_id = 0; reg_id < vec_size; ++reg_id) {
        st.o[tile_id][reg_id] += static_cast<T>(s[j]) * v_vec[reg_id];
      }
    }
  }
}
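// ---------------------------------------------------------------------------
// Note (editorial sketch, not part of the original commit): compute_sv
// consumes the un-normalized probabilities s[j] produced by compute_qk and
// accumulates s[j] * V[j] into the per-tile accumulators st.o. Division by
// the running denominator st.d is deferred to softmax_state_ts::normalize(),
// which the caller invokes once all chunks have been processed (or skips when
// the partial results are merged later by the split-KV merge kernel).
// ---------------------------------------------------------------------------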
560
custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu
Normal file
@@ -0,0 +1,560 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "decode_attention_func.cuh"
|
||||
|
||||
#define CHECK(call) \
|
||||
do \
|
||||
{ \
|
||||
const cudaError_t error_code = call; \
|
||||
if (error_code != cudaSuccess) \
|
||||
{ \
|
||||
printf("CUDA Error:\n"); \
|
||||
printf(" File: %s\n", __FILE__); \
|
||||
printf(" Line %d:\n", __LINE__); \
|
||||
printf(" Error code:%d\n", error_code); \
|
||||
printf(" Error text:%s\n", cudaGetErrorString(error_code)); \
|
||||
exit(1); \
|
||||
} \
|
||||
}while(0)
|
||||
|
||||
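// ---------------------------------------------------------------------------
// Note (editorial sketch, not part of the original commit): the kernel below
// is the second pass of split-KV decoding. Each block handles one
// (query, head) pair (grid = {bsz, num_heads}); threadIdx.y strides over the
// per-chunk partial results (o, m, d) written by the attention kernel, merges
// them with the online-softmax identities, reduces the bdy partial merges
// through shared memory, normalizes, and writes the final head_dim vector to
// `out`.
// ---------------------------------------------------------------------------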
template <typename T, typename OutT, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
|
||||
__global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi_out, // [bsz, num_chunks, num_heads, head_dim]
|
||||
const T * __restrict__ multi_m, // [bsz, num_chunks, num_heads]
|
||||
const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads]
|
||||
const int * __restrict__ seq_lens_q,
|
||||
const int * __restrict__ seq_lens_kv,
|
||||
const int * __restrict__ cum_offsets,
|
||||
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
OutT * __restrict__ out, // [token_num, num_heads, head_dim]
|
||||
const float in_scale,
|
||||
const int num_chunks,
|
||||
const int chunk_size,
|
||||
const int max_seq_len,
|
||||
const int num_heads,
|
||||
const int head_dim) {
|
||||
const int vid = threadIdx.x, ty = threadIdx.y;
|
||||
const int qid = blockIdx.x, hid = blockIdx.y;
|
||||
const int seq_len_q = seq_lens_q[qid];
|
||||
if (seq_len_q == 0) return;
|
||||
int seq_len_kv = seq_lens_kv[qid];
|
||||
if (seq_len_kv == 0) return;
|
||||
seq_len_kv += seq_len_q;
|
||||
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
|
||||
if (num_chunks_this_seq == 1 || ty >= num_chunks_this_seq) {
|
||||
return;
|
||||
}
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ T md_smem[bdy * 2];
|
||||
|
||||
const int start_token_ids = qid * max_seq_len - __ldg(&cum_offsets[qid]);
|
||||
using LoadT = AlignedVector<T, vec_size>;
|
||||
LoadT load_vec;
|
||||
LoadT res_vec;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&res_vec) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
T m;
|
||||
T d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = __float2half(-5e4f);
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = __float2bfloat16(-3.38953e38f);
|
||||
}
|
||||
// merge per ty
|
||||
#pragma unroll 2
|
||||
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
|
||||
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
|
||||
T m_prev = m;
|
||||
T d_prev = d;
|
||||
const T m_now = multi_m[offset];
|
||||
const T d_now = multi_d[offset];
|
||||
m = m_prev > m_now ? m_prev : m_now;
|
||||
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + vid * vec_size;
|
||||
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
||||
const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m);
|
||||
d = d * scale1 + d_now * scale2;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < vec_size; j++) {
|
||||
res_vec[j] = res_vec[j] * scale1 + load_vec[j] * scale2;
|
||||
}
|
||||
}
|
||||
// store ty res
|
||||
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
|
||||
md_smem[2 * ty] = m;
|
||||
md_smem[2 * ty + 1] = d;
|
||||
__syncthreads();
|
||||
|
||||
// merge bdy
|
||||
softmax_state_t<vec_size, T> st{};
|
||||
const uint32_t iter_num = min(num_chunks_this_seq, bdy);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < iter_num; i++) {
|
||||
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
|
||||
const T m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||
st.merge(load_vec, m_tmp, d_tmp);
|
||||
}
|
||||
st.normalize();
|
||||
|
||||
AlignedVector<OutT, vec_size> out_vec;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size; ++i) {
|
||||
out_vec[i] = static_cast<OutT>(st.o[i]);
|
||||
}
|
||||
Store<OutT, vec_size>(out_vec, &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]);
|
||||
}
|
||||
|
||||
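// ---------------------------------------------------------------------------
// Note (editorial sketch, not part of the original commit): outline of the
// kernel below. Each block serves one (request, kv-chunk, kv-head) triple.
// The query rows of the GROUP_SIZE heads that share this kv head are staged
// and pre-scaled in shared memory; NUM_STAGES tiles of the paged key cache
// are prefetched with produce_kv; the main loop alternates compute_qk /
// compute_sv with prefetching the next tile. At the end the block either
// normalizes and writes the output directly (single chunk) or stores the
// partial (o, m, d) into tmp_workspace / tmp_m / tmp_d for the merge kernel.
// ---------------------------------------------------------------------------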
template <bool partition_kv, typename T, typename OutT, typename CacheT, uint32_t NUM_STAGES, uint32_t DEAL_EACH_TIME, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V,
|
||||
uint32_t BLOCK_SIZE, uint32_t VEC_SIZE, uint32_t CACHE_VEC_SIZE, uint32_t bdx, uint32_t bdy>
|
||||
__global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [token_num, num_heads, head_dim]
|
||||
CacheT * __restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||
CacheT * __restrict__ cache_v,
|
||||
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const int * __restrict__ seq_lens_q,
|
||||
const int * __restrict__ seq_lens_kv,
|
||||
const int * __restrict__ cum_offsets,
|
||||
const int * __restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
const float scale,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
T * __restrict__ tmp_workspace, // [batch_size, num_chunks, num_heads, head_dim]
|
||||
T * __restrict__ tmp_m, // [batch_size, num_chunks, num_heads]
|
||||
T * __restrict__ tmp_d, // [batch_size, num_chunks, num_heads]
|
||||
OutT * __restrict__ out) {
|
||||
const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z;
|
||||
const uint32_t bid = bidx, gid = threadIdx.y;
|
||||
const uint32_t tidx = threadIdx.x;
|
||||
constexpr uint32_t num_vec_per_head_qk = HEAD_DIM_QK / VEC_SIZE;
|
||||
constexpr uint32_t num_vec_per_head_v = HEAD_DIM_V / VEC_SIZE;
|
||||
constexpr uint32_t num_tile_v = (num_vec_per_head_v + bdx - 1) / bdx;
|
||||
|
||||
const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE + gid;
|
||||
const uint32_t kv_num_heads = gridDim.z;
|
||||
const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;
|
||||
|
||||
const int *block_table_now = block_table + bid * max_block_num_per_seq;
|
||||
|
||||
const uint32_t num_chunks = gridDim.y;
|
||||
const uint32_t chunk_id = blockIdx.y;
|
||||
const uint32_t q_len = seq_lens_q[bid];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
}
|
||||
uint32_t kv_len = seq_lens_kv[bid];  // cached (decoder) length; current q tokens are added below
|
||||
if (kv_len <= 0) {
|
||||
return;
|
||||
}
|
||||
kv_len += q_len;
|
||||
const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size);
|
||||
const uint32_t q_start_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
const uint32_t q_write_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
if (chunk_id >= num_chunk_this_seq) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0;
|
||||
const uint32_t chunk_end = partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len;
|
||||
const uint32_t chunk_len = chunk_end - chunk_start;
|
||||
|
||||
extern __shared__ uint8_t smem[];
|
||||
const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK;
|
||||
T *q_smem = reinterpret_cast<T*>(smem); // [HEAD_DIM_QK * sizeof(T)]
|
||||
T *cu_q_smem = q_smem + gid * HEAD_DIM_QK;
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
|
||||
((float4*)(&cu_q_smem[vid * VEC_SIZE]))[0] = ((float4*)(&q_now[vid * VEC_SIZE]))[0];
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
using VecT = AlignedVector<T, VEC_SIZE>;
|
||||
VecT q_vec;
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
|
||||
Load<T, VEC_SIZE>(cu_q_smem + vid * VEC_SIZE, &q_vec);
|
||||
for (uint32_t i = 0; i < VEC_SIZE; ++i) {
|
||||
q_vec[i] *= scale;
|
||||
}
|
||||
Store<T, VEC_SIZE>(q_vec, cu_q_smem + vid * VEC_SIZE);
|
||||
}
|
||||
|
||||
|
||||
CacheT *kv_smem = reinterpret_cast<CacheT*>(smem + GROUP_SIZE * HEAD_DIM_QK * sizeof(CacheT));
|
||||
uint32_t stage_idx = 0;
|
||||
constexpr int loop_times = DEAL_EACH_TIME / bdy;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_STAGES; ++i) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < loop_times; ++j) {
|
||||
const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid;
|
||||
const uint32_t k_seq_id = chunk_start + k_seq_offset;
|
||||
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
|
||||
kv_smem,
|
||||
cache_k,
|
||||
block_table_now,
|
||||
k_seq_id,
|
||||
k_seq_offset,
|
||||
kv_head_idx,
|
||||
kv_num_heads,
|
||||
tidx,
|
||||
chunk_start,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
stage_idx = (stage_idx + 1) % NUM_STAGES;
|
||||
}
|
||||
|
||||
|
||||
softmax_state_ts<VEC_SIZE, T, num_tile_v> st;
|
||||
float s[DEAL_EACH_TIME];
|
||||
|
||||
const uint32_t num_iters = div_up(chunk_len, DEAL_EACH_TIME);
|
||||
for (int iter = 0; iter < num_iters; ++iter) {
|
||||
wait_group<NUM_STAGES - 1>();
|
||||
__syncthreads();
|
||||
// compute qk
|
||||
compute_qk<VEC_SIZE, num_vec_per_head_qk, bdx, bdy, HEAD_DIM_QK, DEAL_EACH_TIME, num_tile_v>(
|
||||
cu_q_smem,
|
||||
kv_smem,
|
||||
chunk_start + iter * DEAL_EACH_TIME,
|
||||
stage_idx,
|
||||
iter * DEAL_EACH_TIME,
|
||||
chunk_len,
|
||||
tidx,
|
||||
gid,
|
||||
scale,
|
||||
s,
|
||||
st
|
||||
);
|
||||
__syncthreads();
|
||||
|
||||
// compute sv
|
||||
compute_sv<VEC_SIZE, num_vec_per_head_v, bdx, DEAL_EACH_TIME, HEAD_DIM_QK, num_tile_v>(
|
||||
s,
|
||||
kv_smem,
|
||||
stage_idx,
|
||||
iter * DEAL_EACH_TIME,
|
||||
chunk_len,
|
||||
tidx,
|
||||
st
|
||||
);
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < loop_times; ++j) {
|
||||
const uint32_t k_seq_offset = j * bdy + gid;
|
||||
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
|
||||
kv_smem,
|
||||
cache_k,
|
||||
block_table_now,
|
||||
chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME,
|
||||
stage_idx * DEAL_EACH_TIME + k_seq_offset,
|
||||
kv_head_idx,
|
||||
kv_num_heads,
|
||||
tidx,
|
||||
chunk_start,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
stage_idx = (stage_idx + 1) % NUM_STAGES;
|
||||
}
|
||||
wait_group<0>();
|
||||
__syncthreads();
|
||||
|
||||
// normalize if not partition_kv
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) {
|
||||
const uint32_t tile_id = vid / bdx;
|
||||
if (!partition_kv || num_chunk_this_seq == 1) {
|
||||
st.normalize(tile_id);
|
||||
}
|
||||
if (partition_kv && num_chunk_this_seq > 1) {
|
||||
const uint32_t head_idx = (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx;
|
||||
Store<T, VEC_SIZE>(st.o[tile_id], tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE);
|
||||
tmp_m[head_idx] = st.m;
|
||||
tmp_d[head_idx] = st.d;
|
||||
} else {
|
||||
Store<OutT, VEC_SIZE>(st.o[tile_id], out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + vid * VEC_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
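// ---------------------------------------------------------------------------
// Note (editorial sketch, not part of the original commit): this launcher
// derives the launch shape from the template configuration
// (block = {num_vec_per_head capped at 32, GROUP_SIZE},
//  grid = {bsz, num_chunks, kv_num_heads}), sizes the dynamic shared memory
// for the staged query plus NUM_STAGE key tiles, raises the shared-memory
// limit when it exceeds 48 KB, and then either runs the single-pass kernel
// (num_chunks <= 1) or the split-KV kernel followed by
// merge_varlen_multi_chunks_v2_kernel on temporary (o, m, d) workspaces.
// ---------------------------------------------------------------------------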
template <typename T, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V, uint32_t BLOCK_SIZE, bool CAUSAL, uint32_t NUM_STAGE, uint32_t cache_bytes, uint32_t DEAL_EACH_TIME>
|
||||
void MultiQueryDecoderAttention(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
cudaStream_t &stream,
|
||||
const paddle::Tensor &q,
|
||||
const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim]
|
||||
const paddle::Tensor &cache_v, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float rope_scale,
|
||||
const float rope_theta,
|
||||
const float softmax_scale,
|
||||
const float in_scale,
|
||||
paddle::Tensor *out) {
|
||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
||||
|
||||
auto num_heads = meta_data.q_num_heads;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
auto token_num = meta_data.token_nums;
|
||||
auto bsz = meta_data.batch_size;
|
||||
auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
|
||||
constexpr int num_stages = NUM_STAGE;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T); // 8 16 32
|
||||
constexpr int cache_vec_size = 128 / cache_bytes; // 8 16 32
|
||||
constexpr int blockxc = HEAD_DIM_QK / cache_vec_size;
|
||||
constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size;
|
||||
constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32;
|
||||
|
||||
constexpr int blocky = GROUP_SIZE;
|
||||
const int gridx = bsz;
|
||||
|
||||
constexpr int num_threads = blockx * blocky;
|
||||
|
||||
auto splitkv_kernel = multi_query_decode_attention_kernel<true, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V,
|
||||
BLOCK_SIZE, vec_size, cache_vec_size, blockx, blocky>;
|
||||
uint32_t cache_smem_bytes = 0;
|
||||
|
||||
const T *shift_bias_ptr = shift_bias ? shift_bias.get().data<T>() : nullptr;
|
||||
const T *smooth_weight_ptr = smooth_weight ? smooth_weight.get().data<T>() : nullptr;
|
||||
cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T);
|
||||
|
||||
const uint32_t chunk_size = get_max_partition_size(bsz);
|
||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
size_t smem_size = cache_smem_bytes + GROUP_SIZE * HEAD_DIM_QK * sizeof(T);
|
||||
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
const int dev_id = 0;
|
||||
int sm_count;
|
||||
int act_blocks_per_sm;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, splitkv_kernel, num_threads, smem_size);
|
||||
assert(act_blocks_per_sm > 1);
|
||||
|
||||
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
|
||||
const int num_blocks_need = gridx * num_chunks * kv_num_heads;
|
||||
const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need);
|
||||
const float ratio = static_cast<float>(num_blocks_need) / static_cast<float>(num_blocks_per_wave);
|
||||
|
||||
dim3 grids(gridx, num_chunks, kv_num_heads);
|
||||
dim3 blocks(blockx, blocky);
|
||||
if (num_chunks <= 1) {
|
||||
auto no_splitkv_kernel = multi_query_decode_attention_kernel<false, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, vec_size,
|
||||
cache_vec_size, blockx, blocky>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
no_splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
no_splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
softmax_scale,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
|
||||
);
|
||||
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
} else {
|
||||
auto *allocator = paddle::GetAllocator(q.place());
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
tmp_workspace = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM_V));
|
||||
tmp_m = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads));
|
||||
tmp_d = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads));
|
||||
|
||||
splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
softmax_scale,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
|
||||
);
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
|
||||
constexpr int mblockx = HEAD_DIM_V / vec_size;
|
||||
constexpr int bdy = 256 / mblockx;
|
||||
dim3 grids_merge(bsz, num_heads);
|
||||
dim3 blocks_merge(mblockx, bdy);
|
||||
merge_varlen_multi_chunks_v2_kernel<NV_TYPE, NV_TYPE, vec_size, bdy, HEAD_DIM_V><<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())),
|
||||
in_scale,
|
||||
num_chunks,
|
||||
chunk_size,
|
||||
max_seq_len,
|
||||
num_heads,
|
||||
HEAD_DIM_V
|
||||
);
|
||||
}
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DecodeMLAAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out) {
|
||||
const auto token_num = meta_data.token_nums;
|
||||
const auto block_size = meta_data.block_size;
|
||||
const auto bsz = meta_data.batch_size;
|
||||
const auto num_heads = meta_data.q_num_heads;
|
||||
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
|
||||
const auto head_dim_qk = meta_data.head_dims;
|
||||
const auto head_dim_v = meta_data.head_dims_v;
|
||||
const float rope_scale = 0.0;
|
||||
const float rope_theta = 0.0;
|
||||
const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
|
||||
const uint32_t num_stage = get_cascade_attention_num_stages();
|
||||
const uint32_t num_threads = get_cascade_attention_num_threads();
|
||||
|
||||
DISPATCH_CAUSAL(causal, CAUSAL,
|
||||
{DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE,
|
||||
{DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
|
||||
{DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V,
|
||||
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
|
||||
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
|
||||
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
|
||||
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cum_offsets,
|
||||
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
|
||||
}
|
||||
|
||||
template void DecodeMLAAttentionKernel<paddle::bfloat16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
||||
|
||||
template void DecodeMLAAttentionKernel<paddle::float16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
||||
291
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu
Normal file
@@ -0,0 +1,291 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include "mla_cache_kernel.cuh"
|
||||
|
||||
template <paddle::DataType T>
|
||||
std::vector<paddle::Tensor> PrefillMLAWriteCache(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* kv_cache) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
||||
auto num_tokens = meta_data.token_nums;
|
||||
auto block_size = meta_data.block_size;
|
||||
auto nope_size = meta_data.head_dims_v;
|
||||
auto all_size = meta_data.head_dims;
|
||||
int pe_size = all_size - nope_size;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
const uint32_t elem_nums = num_tokens * kv_num_heads * all_size;
|
||||
|
||||
constexpr int PackSize = 16 / sizeof(DataType_);
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
|
||||
prefill_absorb_cache_kernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
nope_size,
|
||||
pe_size,
|
||||
block_size,
|
||||
elem_nums);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len) {
|
||||
cudaStream_t stream = kv_pe.stream();
|
||||
AppendAttnMetaData meta_data;
|
||||
const auto& kv_nope_dims = kv_nope.dims();
|
||||
const auto& kv_pe_dims = kv_pe.dims();
|
||||
const auto& kv_cache_dims = kv_cache.dims();
|
||||
meta_data.kv_num_heads = kv_cache_dims[1];
|
||||
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
|
||||
meta_data.token_nums = kv_nope_dims[0];
|
||||
meta_data.head_dims = kv_cache_dims[3];
|
||||
meta_data.head_dims_v = nope_size;
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = kv_cache_dims[2];
|
||||
meta_data.batch_size = cum_offsets.dims()[0];
|
||||
switch (kv_pe.dtype()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
template <paddle::DataType T>
|
||||
std::vector<paddle::Tensor> DecodeMLAWriteCache(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* kv_cache) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
||||
auto bsz = meta_data.batch_size;
|
||||
auto token_num = meta_data.token_nums;
|
||||
auto block_size = meta_data.block_size;
|
||||
auto nope_size = meta_data.head_dims_v;
|
||||
auto all_size = meta_data.head_dims;
|
||||
int pe_size = all_size - nope_size;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
constexpr int PackSize = 16 / sizeof(DataType_);
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
|
||||
|
||||
if (speculate_decoder) {
|
||||
const uint32_t elem_nums = token_num * kv_num_heads * all_size;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
speculate_decode_absorb_cache_kernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
nope_size,
|
||||
pe_size,
|
||||
block_size,
|
||||
elem_nums);
|
||||
} else {
|
||||
const uint32_t elem_nums = bsz * kv_num_heads * all_size;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
decode_absorb_cache_kernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
nope_size,
|
||||
pe_size,
|
||||
block_size,
|
||||
elem_nums);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder) {
|
||||
cudaStream_t stream = kv_pe.stream();
|
||||
AppendAttnMetaData meta_data;
|
||||
const auto& kv_nope_dims = kv_nope.dims();
|
||||
const auto& kv_pe_dims = kv_pe.dims();
|
||||
const auto& kv_cache_dims = kv_cache.dims();
|
||||
meta_data.kv_num_heads = kv_cache_dims[1];
|
||||
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
|
||||
meta_data.token_nums = kv_nope_dims[0];
|
||||
meta_data.head_dims = kv_cache_dims[3];
|
||||
meta_data.head_dims_v = nope_size;
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = kv_cache_dims[2];
|
||||
meta_data.batch_size = cum_offsets.dims()[0];
|
||||
switch (kv_pe.dtype()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
speculate_decoder,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
speculate_decoder,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(prefill_mla_write_cache)
|
||||
.Inputs({"kv_nope",
|
||||
"kv_pe",
|
||||
"kv_cache",
|
||||
"seq_lens",
|
||||
"seq_lens_decoder",
|
||||
"padding_offsets",
|
||||
"cum_offsets",
|
||||
"block_tables"})
|
||||
.Outputs({"kv_cache_out"})
|
||||
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
||||
.Attrs({"cache_quant_type_str: std::string",
|
||||
"max_seq_len: int"})
|
||||
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
|
||||
|
||||
PD_BUILD_OP(decode_mla_write_cache)
|
||||
.Inputs({"kv_nope",
|
||||
"kv_pe",
|
||||
"kv_cache",
|
||||
"seq_lens",
|
||||
"seq_lens_encoder",
|
||||
"padding_offsets",
|
||||
"cum_offsets",
|
||||
"block_tables"})
|
||||
.Outputs({"kv_cache_out"})
|
||||
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
||||
.Attrs({"cache_quant_type_str: std::string",
|
||||
"max_seq_len: int",
|
||||
"speculate_decoder: bool"})
|
||||
.SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel));
|
||||
242
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh
Normal file
@@ -0,0 +1,242 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include "helper.h"
|
||||
#include "mem_util.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
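// ---------------------------------------------------------------------------
// Note (editorial sketch, not part of the original commit): the kernels below
// scatter the compressed MLA KV states into the paged cache. Each cache entry
// stores the concatenation [nope | pe] of size nope_size + pe_size per kv
// head, so for a flattened element index the destination is
//   tgt_idx = block_idx * kv_num_heads * block_size * all_size
//           + hi * block_size * all_size
//           + block_offset * all_size
//           + (h_bias for the nope part, or nope_size + h_bias for the pe part)
// where block_idx / block_offset come from the per-request block table and
// the token's write position, mirroring the read-side addressing in
// produce_kv.
// ---------------------------------------------------------------------------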
template <typename T, int VecSize = 1>
|
||||
__global__ void decode_absorb_cache_kernel(
|
||||
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, nope_size] 512
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, pe_size] 64
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size + pe_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const int nope_size,
|
||||
const int pe_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
LoadT src_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
|
||||
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
|
||||
const uint32_t all_size = nope_size + pe_size;
|
||||
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int ori_bi = linear_index / hidden_size;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
if (bias < nope_hidden_size) { // nope part
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
start_token_idx * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
start_token_idx * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void speculate_decode_absorb_cache_kernel(
|
||||
const T* __restrict__ kv_nope, // [token_num, kv_num_heads, nope_size] 512
const T* __restrict__ kv_pe, // [token_num, kv_num_heads, pe_size] 64
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size + pe_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const int nope_size,
|
||||
const int pe_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
LoadT src_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
|
||||
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
|
||||
const uint32_t all_size = nope_size + pe_size;
|
||||
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_id = linear_index / hidden_size;
|
||||
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
if (block_idx < 0) {
|
||||
printf(
|
||||
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
|
||||
"%d %d %d %d\n",
|
||||
block_idx,
|
||||
write_seq_id,
|
||||
ori_bi,
|
||||
seq_lens[ori_bi],
|
||||
token_id,
|
||||
cum_offsets[ori_bi]);
|
||||
}
|
||||
if (bias < nope_hidden_size) { // nope part
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_id * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_id * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void prefill_absorb_cache_kernel(
|
||||
const T* __restrict__ kv_nope, // [token_num, kv_num_heads, nope_size] 512
const T* __restrict__ kv_pe, // [token_num, kv_num_heads, pe_size] 64
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size + pe_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const int nope_size,
|
||||
const int pe_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
|
||||
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
|
||||
const uint32_t all_size = nope_size + pe_size;
|
||||
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const uint32_t token_idx = linear_index / hidden_size;
|
||||
const uint32_t bias = linear_index % hidden_size;
|
||||
const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
|
||||
const uint32_t ori_bi = ori_token_idx / max_seq_len;
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const uint32_t ori_seq_id =
|
||||
ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
|
||||
const uint32_t block_offset = ori_seq_id % block_size;
|
||||
|
||||
if (bias < nope_hidden_size) { // nope part
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_idx * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_idx * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include "helper.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T>
|
||||
void DecodeMLAAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
||||
@@ -25,6 +25,7 @@ struct AppendAttnMetaData {
|
||||
int kv_num_heads;
|
||||
int token_nums;
|
||||
int head_dims;
|
||||
int head_dims_v;
|
||||
int max_blocks_per_seq;
|
||||
};
|
||||
|
||||
@@ -309,10 +310,56 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_GQA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
switch (head_dim) { \
|
||||
case 128: { \
|
||||
constexpr size_t HEAD_DIM = 128; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 192: { \
|
||||
constexpr size_t HEAD_DIM = 192; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
PD_THROW("not support the head_dim: ", head_dim); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
switch (head_dim) { \
|
||||
case 128: { \
|
||||
constexpr size_t HEAD_DIM = 128; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 192: { \
|
||||
constexpr size_t HEAD_DIM = 192; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 512: { \
|
||||
constexpr size_t HEAD_DIM = 512; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 576: { \
|
||||
constexpr size_t HEAD_DIM = 576; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
PD_THROW("not support the head_dim: ", head_dim); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
|
||||
if (num_stage == 2) { \
|
||||
constexpr size_t NUM_STAGE = 2; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the num_stage: ", num_stage); \
|
||||
}
|
||||
|
||||
#define DISPATCH_CACHE_TYPE(cache_type, cache_type_now, cache_bytes, ...) \
|
||||
@@ -328,10 +375,13 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
constexpr CacheType cache_type_now = CacheType::CacheInt4CwZp; \
|
||||
constexpr size_t cache_bytes = 4; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the cache_type: ", cache_type); \
|
||||
}
|
||||
|
||||
|
||||
#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \
|
||||
if (deal_each_time == 32) { \
constexpr size_t DEAL_EACH_TIME = 32; \
|
||||
__VA_ARGS__ \
|
||||
} else if (deal_each_time == 64) { \
|
||||
@@ -387,6 +437,20 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
PD_THROW("not support the group_size", group_size); \
|
||||
}
|
||||
|
||||
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 16) { \
|
||||
constexpr size_t GROUP_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 128) { \
|
||||
constexpr size_t GROUP_SIZE = 128; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the group_size: ", group_size); \
|
||||
}
|
||||
|
||||
#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \
|
||||
if (block_shape_q <= 16) { \
|
||||
constexpr size_t BLOCK_SHAPE_Q = 16; \
|
||||
|
||||
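For reference, a minimal sketch of how a dispatch macro of this shape is typically consumed: the runtime head_dim is mapped onto a compile-time constant that then selects a template instantiation. The macro below mirrors DISPATCH_MLA_HEAD_DIM from the hunk above, trimmed to two cases; run_mla and the PD_THROW stand-in are hypothetical names used only for the illustration.

// Illustration only: turning a runtime head_dim into a compile-time constant
// via the dispatch-macro pattern above. Not part of this commit.
#include <cstddef>
#include <cstdio>
#include <stdexcept>

#define PD_THROW(...) throw std::runtime_error("unsupported value")  // stand-in

#define DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, ...)  \
  switch (head_dim) {                                   \
    case 512: {                                         \
      constexpr size_t HEAD_DIM = 512;                  \
      __VA_ARGS__                                       \
      break;                                            \
    }                                                   \
    case 576: {                                         \
      constexpr size_t HEAD_DIM = 576;                  \
      __VA_ARGS__                                       \
      break;                                            \
    }                                                   \
    default: {                                          \
      PD_THROW("not support the head_dim: ", head_dim); \
    }                                                   \
  }

template <size_t HEAD_DIM>
void run_mla() { std::printf("instantiated with HEAD_DIM = %zu\n", HEAD_DIM); }

int main() {
  int head_dim = 576;  // runtime value, e.g. taken from meta_data.head_dims
  DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, { run_mla<HEAD_DIM>(); })
  return 0;
}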
@@ -316,6 +316,96 @@ void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
|
||||
|
||||
paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
|
||||
int64_t num_experts);
|
||||
void GetPositionIdsAndMaskEncoderBatch(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& position_ids,
|
||||
const paddle::Tensor& mask_encoder_batch);
|
||||
|
||||
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder);
|
||||
|
||||
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len);
|
||||
|
||||
|
||||
void FusedRotaryPositionEncoding(
|
||||
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
|
||||
// [num_tokens, num_heads * head_size]
|
||||
paddle::Tensor& key,
|
||||
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
|
||||
// head_size]
|
||||
const paddle::Tensor& position_ids, // [num_tokens]
|
||||
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
int head_size,
|
||||
bool is_neox);
|
||||
|
||||
std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& query_bias,
|
||||
const paddle::optional<paddle::Tensor>& query_out_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int nope_size,
|
||||
const int max_input_length,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder);
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M);
|
||||
@@ -370,6 +460,14 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out,
|
||||
paddle::Tensor const &input,
|
||||
paddle::Tensor &scales, float scale_ub);
|
||||
|
||||
std::vector<paddle::Tensor> NoauxTc(
|
||||
paddle::Tensor& scores,
|
||||
paddle::Tensor& scores_with_bias,
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
float routed_scaling_factor);
|
||||
|
||||
PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
|
||||
@@ -627,6 +725,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("use_atomic_add"),
|
||||
py::arg("use_fp32_reduce"),
|
||||
py::arg("is_zp_float"));
|
||||
m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
|
||||
"get_position_ids_and_mask_encoder_batch function");
|
||||
|
||||
|
||||
/**
|
||||
@@ -653,4 +753,13 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
|
||||
"dynamic_per_token_scaled_fp8_quant function",
|
||||
py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
|
||||
m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function");
|
||||
|
||||
m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function");
|
||||
|
||||
m.def("fused_rotary_position_encoding", &FusedRotaryPositionEncoding, "fused_rotary_position_encoding function");
|
||||
|
||||
m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function");
|
||||
|
||||
m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
|
||||
}
|
||||
|
||||
64
custom_ops/gpu_ops/env.h
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once

#include <cstdint>
#include <cstdlib>
#include <string>

inline uint32_t get_decoder_block_shape_q() {
|
||||
static const char* decoder_block_shape_q_env = std::getenv("FLAGS_dec_block_shape_q");
|
||||
static const uint32_t decoder_block_shape_q =
|
||||
decoder_block_shape_q_env == nullptr ? 16 : std::stoi(std::string(decoder_block_shape_q_env));
|
||||
return decoder_block_shape_q;
|
||||
}
|
||||
|
||||
inline uint32_t get_encoder_block_shape_q() {
|
||||
static const char* encoder_block_shape_q_env = std::getenv("FLAGS_enc_block_shape_q");
|
||||
static const uint32_t encoder_block_shape_q =
|
||||
encoder_block_shape_q_env == nullptr ? 64 : std::stoi(std::string(encoder_block_shape_q_env));
|
||||
return encoder_block_shape_q;
|
||||
}
|
||||
|
||||
inline uint32_t get_max_partition_size(int bsz) {
|
||||
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
|
||||
static const uint32_t max_partition_size =
|
||||
max_partition_size_env == nullptr ? 32768 : std::stoul(std::string(max_partition_size_env));
|
||||
return max_partition_size;
|
||||
}
|
||||
|
||||
inline uint32_t get_cascade_attention_deal_each_time() {
|
||||
static const char* cascade_attention_deal_each_time_env = std::getenv("FLAGS_cascade_attention_deal_each_time");
|
||||
static const uint32_t cascade_attention_deal_each_time =
|
||||
cascade_attention_deal_each_time_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_deal_each_time_env));
|
||||
return (cascade_attention_deal_each_time != 0 ? cascade_attention_deal_each_time : 32);
|
||||
}
|
||||
|
||||
inline uint32_t get_cascade_attention_num_stages() {
|
||||
static const char* cascade_attention_num_stages_env = std::getenv("FLAGS_cascade_attention_num_stages");
|
||||
static const uint32_t cascade_attention_num_stages =
|
||||
cascade_attention_num_stages_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_stages_env));
|
||||
return cascade_attention_num_stages != 0 ? cascade_attention_num_stages : 2;
|
||||
}
|
||||
|
||||
inline uint32_t get_cascade_attention_num_threads() {
|
||||
static const char* cascade_attention_num_threads_env = std::getenv("FLAGS_cascade_attention_num_threads");
|
||||
static const uint32_t cascade_attention_num_threads =
|
||||
cascade_attention_num_threads_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_threads_env));
|
||||
return cascade_attention_num_threads != 0 ? cascade_attention_num_threads : 128;
|
||||
}
|
||||
|
||||
inline bool get_mla_use_tensorcore() {
|
||||
static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
|
||||
static const uint32_t mla_use_tensorcore =
|
||||
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
|
||||
return mla_use_tensorcore != 0 ? true : false;
|
||||
}
|
||||
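As a usage note on the helpers above: each FLAGS_* getter reads its environment variable once, caches the parsed value in a function-local static, and otherwise falls back to a built-in default. A minimal host-side sketch; the _demo function name and the printed output are illustrative, not part of this commit.

// Illustration of the override behaviour of the FLAGS_* getters above.
// get_max_partition_size_demo is a hypothetical stand-in for the real helper.
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>

inline uint32_t get_max_partition_size_demo() {
  static const char* env = std::getenv("FLAGS_cascade_attention_max_partition_size");
  static const uint32_t value =
      env == nullptr ? 32768 : std::stoul(std::string(env));  // default when unset
  return value;
}

int main() {
  // Prints 32768 unless FLAGS_cascade_attention_max_partition_size is exported.
  std::printf("max_partition_size = %u\n", get_max_partition_size_demo());
  return 0;
}

Because the parsed value is cached in a static, each flag is read once per process; changing the environment after the first call has no effect.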
146
custom_ops/gpu_ops/fused_rotary_position_encoding.cu
Normal file
@@ -0,0 +1,146 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <typename T, bool IS_NEOX>
|
||||
inline __device__ void apply_token_rotary_embedding_kernel(
|
||||
T* __restrict__ arr,
|
||||
const T* __restrict__ cos_ptr,
|
||||
const T* __restrict__ sin_ptr,
|
||||
int rot_offset,
|
||||
int embed_dim) {
|
||||
int x_index, y_index;
|
||||
T cos, sin;
|
||||
if (IS_NEOX) {
|
||||
x_index = rot_offset;
|
||||
y_index = embed_dim + rot_offset;
|
||||
cos = cos_ptr[x_index];
|
||||
sin = sin_ptr[x_index];
|
||||
} else {
|
||||
x_index = 2 * rot_offset;
|
||||
y_index = 2 * rot_offset + 1;
|
||||
cos = cos_ptr[x_index / 2];
|
||||
sin = sin_ptr[x_index / 2];
|
||||
}
|
||||
|
||||
const T x = arr[x_index];
|
||||
const T y = arr[y_index];
|
||||
arr[x_index] = x * cos - y * sin;
|
||||
arr[y_index] = y * cos + x * sin;
|
||||
}
|
||||
|
||||
|
||||
template <typename T, bool IS_NEOX>
|
||||
__global__ void apply_rotary_embedding_kernel(
|
||||
T* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||
T* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
|
||||
const int* __restrict__ position_ids, // [num_tokens]
|
||||
const T* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int pos = position_ids[token_idx];
|
||||
const T* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const T* cos_ptr = cache_ptr;
|
||||
const T* sin_ptr = cache_ptr + embed_dim;
|
||||
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
|
||||
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
|
||||
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void FusedRotaryPositionEncoding(
|
||||
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
|
||||
// [num_tokens, num_heads * head_size]
|
||||
paddle::Tensor& key,
|
||||
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
|
||||
// head_size]
|
||||
const paddle::Tensor& position_ids, // [num_tokens]
|
||||
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
int head_size,
|
||||
bool is_neox) {
|
||||
int64_t num_tokens = query.dims()[0];
|
||||
int num_heads = query.numel() / num_tokens / head_size;
|
||||
int num_kv_heads = key.numel() / num_tokens / head_size;
|
||||
int rot_dim = cos_sin_cache.dims()[1];
|
||||
int64_t query_stride = num_heads * head_size;
|
||||
int64_t key_stride = num_kv_heads * head_size;
|
||||
|
||||
if (num_tokens > 65535) {
|
||||
PD_THROW(
|
||||
"apply_rotary_embedding_kernel launch failed when num_tokens > 65535.");
|
||||
}
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
|
||||
query.dtype(), "apply_rotary_embedding_kernel", [&] {
|
||||
if (is_neox) {
|
||||
apply_rotary_embedding_kernel<data_t, true>
|
||||
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
|
||||
key.data<data_t>(),
|
||||
position_ids.data<int>(),
|
||||
cos_sin_cache.data<data_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
} else {
|
||||
apply_rotary_embedding_kernel<data_t, false>
|
||||
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
|
||||
key.data<data_t>(),
|
||||
position_ids.data<int>(),
|
||||
cos_sin_cache.data<data_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
PD_BUILD_OP(fused_rotary_position_encoding)
|
||||
.Inputs({"query", "key", "position_ids", "cos_sin_cache"})
|
||||
.Outputs({"query_out", "key_out"})
|
||||
.Attrs({"head_size: int", "is_neox: bool"})
|
||||
.SetInplaceMap({{"query", "query_out"}, {"key", "key_out"}})
|
||||
.SetKernelFn(PD_KERNEL(FusedRotaryPositionEncoding));
|
||||
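The device helper above applies a plain 2-D rotation to each (x, y) pair; only the pairing of indices differs between the NeoX and interleaved layouts. A CPU-only sketch of that indexing follows; all values and names here are illustrative, not part of this commit.

// CPU sketch of the indexing in apply_token_rotary_embedding_kernel above.
// The real kernel operates in-place on device memory.
#include <cstdio>
#include <vector>

void rotate_pair(float& x, float& y, float c, float s) {
  const float x0 = x, y0 = y;
  x = x0 * c - y0 * s;
  y = y0 * c + x0 * s;
}

int main() {
  const int embed_dim = 4;  // rot_dim / 2
  std::vector<float> head(2 * embed_dim, 1.0f);
  std::vector<float> cos_v(embed_dim, 0.8f), sin_v(embed_dim, 0.6f);

  const bool is_neox = true;
  for (int r = 0; r < embed_dim; ++r) {
    if (is_neox) {
      // NeoX layout: element r pairs with element embed_dim + r
      rotate_pair(head[r], head[embed_dim + r], cos_v[r], sin_v[r]);
    } else {
      // Interleaved layout: adjacent elements form a pair
      rotate_pair(head[2 * r], head[2 * r + 1], cos_v[r], sin_v[r]);
    }
  }
  for (float v : head) std::printf("%.2f ", v);
  std::printf("\n");
  return 0;
}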
@@ -0,0 +1,86 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
__global__ void GetPositionIdsAndMaskEncoderBatchKernel(
    const int* seq_lens_encoder,   // [bsz] encoder length of each batch entry
    const int* seq_lens_decoder,   // [bsz] decoder length of each batch entry
    const int* seq_lens_this_time,
    int* position_ids,             // flattened output position_ids
    int* mask_encoder_batch,
    const int bsz) {               // batch size
  // One thread per batch entry.
  int tid = threadIdx.x;
  if (tid >= bsz) return;

  // Compute this batch entry's offset into the flattened output.
  int offset = 0;
  for (int i = 0; i < tid; i++) {
    offset += seq_lens_encoder[i];
    if (seq_lens_decoder[i] > 0) {
      offset += seq_lens_this_time[i];
    }
  }

  // Encoder and decoder lengths of the current batch entry.
  int encoder_len = seq_lens_encoder[tid];
  int decoder_len = seq_lens_decoder[tid];
  int seq_len_this_time = seq_lens_this_time[tid];

  // Write position_ids for the encoder (prefill) tokens.
  for (int i = 0; i < encoder_len; i++) {
    position_ids[offset + i] = i;
    mask_encoder_batch[offset + i] = 1;
  }
  offset += encoder_len;

  // Write position_ids for the decoder tokens, continuing from the cached length.
  if (decoder_len > 0) {
    for (int i = 0; i < seq_len_this_time; i++) {
      position_ids[offset + i] = decoder_len + i;
      mask_encoder_batch[offset + i] = 0;
    }
  }
}

void GetPositionIdsAndMaskEncoderBatch(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& position_ids,
|
||||
const paddle::Tensor& mask_encoder_batch) {
|
||||
const int bsz = seq_lens_this_time.shape()[0];
|
||||
|
||||
GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>(
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
const_cast<int*>(position_ids.data<int>()),
|
||||
const_cast<int*>(mask_encoder_batch.data<int>()),
|
||||
bsz);
|
||||
}
|
||||
|
||||
PD_BUILD_OP(get_position_ids_and_mask_encoder_batch)
|
||||
.Inputs({"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"position_ids",
|
||||
"mask_encoder_batch"})
|
||||
.Outputs({"position_ids_out", "mask_encoder_batch_out"})
|
||||
.SetInplaceMap({{"position_ids", "position_ids_out"},
|
||||
{"mask_encoder_batch", "mask_encoder_batch_out"}})
|
||||
.SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch));
|
||||
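To make the layout produced by get_position_ids_and_mask_encoder_batch concrete, a host-side restatement with made-up lengths (illustrative only): a prefill request contributes positions 0..len-1 with mask 1, and a decode request contributes positions starting at its cached decoder length with mask 0.

// Host-side restatement of the kernel's output layout, for illustration only.
#include <cstdio>
#include <vector>

int main() {
  // Batch of 2: request 0 is prefilling 3 tokens, request 1 is decoding
  // (5 tokens already cached, 1 new token this step).
  std::vector<int> enc = {3, 0}, dec = {0, 5}, now = {3, 1};
  std::vector<int> position_ids, mask;
  for (size_t b = 0; b < enc.size(); ++b) {
    for (int i = 0; i < enc[b]; ++i) { position_ids.push_back(i); mask.push_back(1); }
    if (dec[b] > 0)
      for (int i = 0; i < now[b]; ++i) { position_ids.push_back(dec[b] + i); mask.push_back(0); }
  }
  // Expected: position_ids = 0 1 2 5, mask_encoder_batch = 1 1 1 0
  for (int p : position_ids) std::printf("%d ", p);
  std::printf("\n");
  return 0;
}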
@@ -39,10 +39,12 @@ namespace cub = hipcub;
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "env.h"
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/allocator.h"
|
||||
#include "paddle/phi/core/cuda_stream.h"
|
||||
#include "paddle/phi/core/dense_tensor.h"
|
||||
#include "paddle/phi/backends/gpu/gpu_info.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
@@ -513,3 +515,10 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
return max_shared_mem_per_block_opt_in;
|
||||
}
|
||||
|
||||
inline int GetSMVersion() {
|
||||
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
|
||||
phi::backends::gpu::GetCurrentDeviceId());
|
||||
return sm_version;
|
||||
|
||||
}
|
||||
|
||||
255
custom_ops/gpu_ops/mla_attn/attention_updater.cuh
Normal file
@@ -0,0 +1,255 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/detail/helper_macros.hpp>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T>
|
||||
struct MaxOp {
|
||||
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? x : y; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MaxOp<float> {
|
||||
// This is slightly faster
|
||||
__device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct SumOp {
|
||||
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
|
||||
};
|
||||
|
||||
template <int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template <typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator& op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Allreduce<2> {
|
||||
template <typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator& op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Operator>
|
||||
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& summary, Operator& op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0>& dst,
|
||||
Tensor<Engine1, Layout1>& src, Operator& op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++) {
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Operator>
|
||||
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& summary, Operator& op) {
|
||||
thread_reduce_<init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& max) {
|
||||
MaxOp<float> max_op;
|
||||
reduce_<init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template <bool init, bool warp_reduce = true, typename Engine0, typename Layout0, typename Engine1,
|
||||
typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& sum) {
|
||||
SumOp<float> sum_op;
|
||||
thread_reduce_<init>(tensor, sum, sum_op);
|
||||
if constexpr (warp_reduce) {
|
||||
quad_allreduce_(sum, sum, sum_op);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void apply_exp2(Tensor<Engine0, Layout0>& tensor,
|
||||
Tensor<Engine1, Layout1> const& max) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
auto row_max = max(mi);
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = __expf(tensor(mi, ni) - row_max);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0>& tensor,
|
||||
Tensor<Engine1, Layout1> const& max,
|
||||
const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
auto row_max = max(mi);
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// row_max * scale is a constant for each row, so we can use fma here
|
||||
tensor(mi, ni) = __expf(tensor(mi, ni) * scale - row_max * scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int NUM_ROWS_PER_THREAD, bool WITH_SCALE>
|
||||
struct OnlineSoftmax {
|
||||
constexpr static float fill_value = -5e4;
|
||||
using TensorT = decltype(make_tensor<float>(Shape<Int<NUM_ROWS_PER_THREAD>>{}));
|
||||
TensorT row_max, row_sum, scores_scale;
|
||||
float sm_scale_log2;
|
||||
|
||||
CUTLASS_DEVICE OnlineSoftmax(float sm_scale_log2) : sm_scale_log2(sm_scale_log2) {
|
||||
clear(scores_scale);
|
||||
};
|
||||
|
||||
__forceinline__ __device__ TensorT get_lse() const { return row_sum; }
|
||||
|
||||
template <bool init, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT update(Tensor0& acc_s) {
|
||||
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
|
||||
static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
|
||||
if constexpr (init) {
|
||||
reduce_max</*init=*/true>(scores, row_max);
|
||||
if constexpr (WITH_SCALE) {
|
||||
scale_apply_exp2(scores, row_max, sm_scale_log2);
|
||||
} else {
|
||||
apply_exp2(scores, row_max);
|
||||
}
|
||||
reduce_sum</*init=*/true, /*warp_reduce=*/false>(scores, row_sum);
|
||||
} else {
|
||||
// update row_max
|
||||
Tensor scores_max_prev = make_fragment_like(row_max);
|
||||
cute::copy(row_max, scores_max_prev);
|
||||
reduce_max</*init=*/false>(scores, row_max);
|
||||
// update scores_scale and scale row_sum
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float scores_max_cur = row_max(mi);
|
||||
if constexpr (WITH_SCALE) {
|
||||
scores_scale(mi) = __expf((scores_max_prev(mi) - scores_max_cur) * sm_scale_log2);
|
||||
} else {
|
||||
scores_scale(mi) = __expf(scores_max_prev(mi) - scores_max_cur);
|
||||
}
|
||||
row_sum(mi) *= scores_scale(mi);
|
||||
}
|
||||
// perform exp2 on scores
|
||||
if constexpr (WITH_SCALE) {
|
||||
scale_apply_exp2(scores, row_max, sm_scale_log2);
|
||||
} else {
|
||||
apply_exp2(scores, row_max);
|
||||
}
|
||||
// update row_sum
|
||||
reduce_sum</*init=*/false, /*warp_reduce=*/false>(scores, row_sum);
|
||||
    }
    return scores_scale;
  };
|
||||
|
||||
template <typename Tensor0>
|
||||
__forceinline__ __device__ TensorT finalize(Tensor0& acc_s) {
|
||||
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
|
||||
SumOp<float> sum_op;
|
||||
quad_allreduce_(row_sum, row_sum, sum_op);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float sum = row_sum(mi);
|
||||
float inv_sum = 1.f / sum;
|
||||
scores_scale(mi) = inv_sum;
|
||||
row_max(mi) *= sm_scale_log2;
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
template <typename Tensor1>
|
||||
__forceinline__ __device__ void rescale_o(Tensor1& acc_o) {
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
|
||||
acc_o_rowcol(mi, ni) *= scores_scale(mi);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tensor1, typename Tensor2>
|
||||
__forceinline__ __device__ void rescale_o(Tensor1& acc_o, Tensor2& scores_scale_input) {
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
|
||||
acc_o_rowcol(mi, ni) *= scores_scale_input(mi);
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
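OnlineSoftmax::update above follows the usual streaming-softmax recurrence: when a new block of scores raises the running max, the accumulated sum and the partial output are rescaled by exp(old_max - new_max) so the final result equals one softmax over all scores. A scalar sketch of that invariant, with a constant stand-in for the V contribution; illustrative only, not the kernel code.

// Scalar illustration of the rescaling used by OnlineSoftmax::update above.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<std::vector<float>> blocks = {{1.0f, 3.0f}, {2.5f, 0.5f}, {4.0f}};
  float m = -INFINITY, d = 0.0f, o = 0.0f;  // running max, sum, weighted output
  const float value = 1.0f;                 // stand-in for the V contribution
  for (const auto& blk : blocks) {
    float m_new = m;
    for (float s : blk) m_new = std::fmax(m_new, s);
    const float scale = std::exp(m - m_new);  // rescale old accumulators
    d *= scale;
    o *= scale;
    for (float s : blk) {
      const float p = std::exp(s - m_new);
      d += p;
      o += p * value;
    }
    m = m_new;
  }
  std::printf("softmax-weighted sum = %f (expected %f)\n", o / d, value);
  return 0;
}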
235
custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu
Normal file
@@ -0,0 +1,235 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_device_runtime_api.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include "cute/tensor.hpp"
|
||||
#include "mla_hopper.cuh"
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
#include "batch_mla_with_paged_kv_cache.h"
|
||||
#include "env.h"
|
||||
|
||||
using namespace cute;
|
||||
using namespace mla_attn;
|
||||
using namespace std;
|
||||
|
||||
template <typename T>
|
||||
struct cascade_type_traits {
|
||||
using type = T;
|
||||
using cutlass_type = T;
|
||||
};
|
||||
template <>
|
||||
struct cascade_type_traits<phi::dtype::bfloat16> {
|
||||
using type = __nv_bfloat16;
|
||||
using cutlass_type = cutlass::bfloat16_t;
|
||||
};
|
||||
template <>
|
||||
struct cascade_type_traits<phi::dtype::float16> {
|
||||
using type = half;
|
||||
using cutlass_type = cutlass::half_t;
|
||||
};
|
||||
template <>
|
||||
struct cascade_type_traits<phi::dtype::float8_e4m3fn> {
|
||||
using type = __nv_fp8_e4m3;
|
||||
using cutlass_type = cutlass::float_e4m3_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void BatchMLAWithPagedKVCacheKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
|
||||
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const int draft_token_num,
|
||||
const bool causal,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out) {
|
||||
using NV_TYPE = typename cascade_type_traits<T>::type;
|
||||
using CUTLASS_TYPE = typename cascade_type_traits<T>::cutlass_type;
|
||||
const auto token_num = meta_data.token_nums;
|
||||
const auto block_size = meta_data.block_size;
|
||||
const auto bsz = meta_data.batch_size;
|
||||
const auto q_head_num = meta_data.q_num_heads;
|
||||
const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
|
||||
const auto max_block_num = bsz * max_block_num_per_seq;
|
||||
const uint32_t chunk_size = get_max_partition_size(bsz);
|
||||
|
||||
|
||||
int q_head_dim = meta_data.head_dims;
|
||||
int k_head_dim = meta_data.head_dims;
|
||||
int v_head_dim = meta_data.head_dims_v;
|
||||
// int num_chunks = max_dec_len / chunk_size;
|
||||
int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
|
||||
auto *allocator = paddle::GetAllocator(q.place());
|
||||
phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp;
|
||||
O_tmp = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num * v_head_dim));
|
||||
m_tmp = allocator->Allocate(
|
||||
sizeof(float) *
|
||||
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));
|
||||
d_tmp = allocator->Allocate(
|
||||
sizeof(float) *
|
||||
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));
|
||||
|
||||
Params<CUTLASS_TYPE, CUTLASS_TYPE, CUTLASS_TYPE, int> params = {};
|
||||
params.Q = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(q.data<T>()));
|
||||
params.KV = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(latent_cache.data<T>()));
|
||||
params.O = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(out->data<T>()));
|
||||
params.O_tmp = reinterpret_cast<CUTLASS_TYPE*>(O_tmp->ptr());
|
||||
params.m = reinterpret_cast<float*>(m_tmp->ptr());
|
||||
params.d = reinterpret_cast<float*>(d_tmp->ptr());
|
||||
params.block_tables = const_cast<int*>(block_tables.data<int>());
|
||||
params.seq_lens_this_time = const_cast<int*>(seq_lens_this_time.data<int>());
|
||||
params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
|
||||
params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
|
||||
params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
|
||||
params.padding_offsets = const_cast<int*>(padding_offsets.data<int>());
|
||||
params.batch_ids = const_cast<int*>(batch_ids.data<int>());
|
||||
params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
|
||||
params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
|
||||
params.num_blocks_x_int = num_blocks_x;
|
||||
params.q_stride_bsz = q_head_num * q_head_dim;
|
||||
params.q_stride_head_num = q_head_dim;
|
||||
params.kv_stride_block_num = block_size * k_head_dim;
|
||||
params.kv_stride_block_size = k_head_dim;
|
||||
params.o_stride_bsz = q_head_num * v_head_dim;
|
||||
params.o_stride_head_num = v_head_dim;
|
||||
params.bsz = bsz;
|
||||
params.token_num = token_num;
|
||||
params.max_seq_len = max_seq_len;
|
||||
params.max_block_num = max_block_num;
|
||||
params.max_block_num_per_seq = max_block_num_per_seq;
|
||||
params.q_num_head = q_head_num;
|
||||
params.qk_head_dim = q_head_dim;
|
||||
params.vo_head_dim = v_head_dim;
|
||||
params.block_size = block_size;
|
||||
params.max_draft_token_num = draft_token_num;
|
||||
params.sm_scale = softmax_scale;
|
||||
params.chunk_size = chunk_size;
|
||||
params.chunk_num = num_chunks;
|
||||
|
||||
if (q_head_dim == 576) {
|
||||
BatchMLAWithPagedKVCacheDispatched<576, 512, NV_TYPE>(
|
||||
params, stream
|
||||
);
|
||||
} else {
|
||||
PD_THROW("error!!! q_head_dim must be 576 !!!\n");
|
||||
}
|
||||
}
|
||||
|
||||
template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
|
||||
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const int draft_token_num,
|
||||
const bool causal,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
|
||||
template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
|
||||
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const int draft_token_num,
|
||||
const bool causal,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
69
custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2023 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/dense_tensor.h"
|
||||
#include "paddle/phi/core/allocator.h"
|
||||
#include "append_attn/utils.cuh"
|
||||
|
||||
template <typename T>
|
||||
void BatchMLAWithPagedKVCacheKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
|
||||
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const int draft_token_num,
|
||||
const bool causal,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
175
custom_ops/gpu_ops/mla_attn/epilogue.cuh
Normal file
@@ -0,0 +1,175 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
|
||||
#ifndef ATTENTION_HOPPER_EPILOGUE_CUH_
|
||||
#define ATTENTION_HOPPER_EPILOGUE_CUH_
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "named_barrier.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
#ifdef DEBUG_MLA
|
||||
#undef DEBUG_MLA
|
||||
#endif
|
||||
// #define DEBUG_MLA
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename Ktraits>
|
||||
struct CollectiveEpilogue {
|
||||
using DTypeO = typename Ktraits::DTypeO;
|
||||
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
|
||||
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
|
||||
static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
|
||||
using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;
|
||||
|
||||
static constexpr int NUM_WARPS = Ktraits::NUM_WARPS;
|
||||
static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp;
|
||||
|
||||
static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup;
|
||||
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
|
||||
|
||||
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
|
||||
decltype(cute::get<1>(TileShape_PDV{}))>());
|
||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
|
||||
|
||||
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeO>;
|
||||
using SharedStorage = cute::array_aligned<DTypeO, cute::cosize_v<SmemLayoutO>>;
|
||||
|
||||
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
|
||||
using StrideT = cute::Shape<int32_t, _1, int32_t>;
|
||||
using LayoutT = cute::Layout<ShapeT, StrideT>;
|
||||
|
||||
using ShapeTmpT = cute::Shape<int32_t, int32_t, int32_t, int32_t>;
|
||||
using StrideTmpT = cute::Shape<int32_t, _1, int32_t, int32_t>;
|
||||
using LayoutTmpT = cute::Layout<ShapeTmpT, StrideTmpT>;
|
||||
|
||||
using ShapeNTMAT = cute::Shape<int32_t, int32_t>;
|
||||
using StrideNTMAT = cute::Shape<int32_t, _1>;
|
||||
using LayoutNTMAT = cute::Layout<ShapeNTMAT, StrideNTMAT>;
|
||||
|
||||
using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
|
||||
using TMA_O = decltype(make_tma_copy(
|
||||
GmemTiledCopyOTMA{},
|
||||
make_tensor(make_gmem_ptr(static_cast<DTypeO*>(nullptr)), ShapeT{}, StrideT{}), SmemLayoutO{},
|
||||
select<0, 1>(TileShape_PDV{}), _1{})); // no mcast for O
|
||||
|
||||
static constexpr int VEC_SIZE = cute::ceil_div(128, sizeof_bits_v<DTypeO>); // 8
|
||||
static_assert(HEAD_DIM_VO % VEC_SIZE == 0);
|
||||
static constexpr int NUM_THREADS_PER_ROW = HEAD_DIM_VO / VEC_SIZE; // 64
|
||||
static_assert(NUM_MMA_THREADS % NUM_THREADS_PER_ROW == 0);
|
||||
static constexpr int NUM_ROWS = NUM_MMA_THREADS / NUM_THREADS_PER_ROW;
|
||||
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, DTypeO>;
|
||||
using TiledCopyOThrLayout = decltype(cute::make_layout(
|
||||
cute::make_shape(Int<NUM_ROWS>{}, Int<NUM_THREADS_PER_ROW>{}), LayoutRight{}));
|
||||
using TiledCopyOValLayout =
|
||||
decltype(cute::make_layout(cute::make_shape(_1{}, Int<VEC_SIZE>{}), LayoutRight{}));
|
||||
using TiledCopyO =
|
||||
decltype(make_tiled_copy(TiledCopyOAtom{}, TiledCopyOThrLayout{}, // Thr layout
|
||||
TiledCopyOValLayout{} // Val layout
|
||||
));
|
||||
struct Arguments {
|
||||
DTypeO* O_ptr;
|
||||
LayoutNTMAT const layout_O;
|
||||
DTypeO* O_ptr_tmp;
|
||||
LayoutNTMAT const layout_O_tmp;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
DTypeO* O_ptr;
|
||||
LayoutNTMAT const layout_O;
|
||||
DTypeO* O_ptr_tmp;
|
||||
LayoutNTMAT const layout_O_tmp;
|
||||
};
|
||||
|
||||
static Params to_underlying_arguments_ntma(Arguments const& args) {
|
||||
return {args.O_ptr, args.layout_O, args.O_ptr_tmp, args.layout_O_tmp};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& epilogue_params) {}
|
||||
|
||||
template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE,
|
||||
typename TiledMma>
|
||||
CUTLASS_DEVICE void store(Params const& epilogue_params,
|
||||
FrgTensorO const& tOrO,
|
||||
FrgTensorLSE const& lse,
|
||||
SharedStorage& shared_storage,
|
||||
TiledMma tiled_mma,
|
||||
const int thread_idx,
|
||||
const int bid,
|
||||
const int bsz,
|
||||
const int seq_len_now,
|
||||
const int start_token_idx,
|
||||
const int tile_idx,
|
||||
const int kv_len,
|
||||
const int chunk_size,
|
||||
const int max_draft_token_num,
|
||||
const int o_stride_bsz) {
|
||||
const int num_chunks = cute::ceil_div(kv_len, chunk_size);
|
||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
|
||||
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tOrO_out = convert_type<DTypeO>(tOrO);
|
||||
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
|
||||
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
// make sure gemm done
|
||||
cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
|
||||
// r2s
|
||||
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
||||
// make sure r2s done
|
||||
cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
|
||||
TiledCopyO gmem_tiled_copy_O;
|
||||
auto O_ptr = num_chunks == 1 ? epilogue_params.O_ptr + start_token_idx * o_stride_bsz : epilogue_params.O_ptr_tmp + (tile_idx * bsz + bid) * max_draft_token_num * o_stride_bsz;
|
||||
Tensor mO = make_tensor(make_gmem_ptr(O_ptr), epilogue_params.layout_O);
|
||||
Tensor gO = local_tile(mO, select<0, 1>(TileShape_PDV{}), make_coord(_, _0{}))(_, _, _0{});
|
||||
Tensor cO = make_identity_tensor(gO.shape()); // (O, D) -> (o_idx, d_idx)
|
||||
ThrCopy thr_copy_O = gmem_tiled_copy_O.get_slice(thread_idx);
|
||||
Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY, CPY_O, CPY_D)
|
||||
Tensor tOsO = thr_copy_O.partition_S(sO); // (CPY, CPY_O, CPY_D)
|
||||
Tensor tOcO = thr_copy_O.partition_D(cO); // (CPY, CPY_O, CPY_D)
|
||||
Tensor tOgOGroup = flatten_1(tOgO); // (CPY, (CPY_O, CPY_D))
|
||||
Tensor tOsOGroup = flatten_1(tOsO); // (CPY, (CPY_O, CPY_D))
|
||||
Tensor tOcOGroup = flatten_1(tOcO); // (CPY, (CPY_O, CPY_D))
|
||||
|
||||
// copy if not out of bound
|
||||
auto predicate_fn = [&](auto coords) {
|
||||
auto s_coords = tOcOGroup(_0{}, coords);
|
||||
return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, seq_len_now);
|
||||
};
|
||||
copy_if(gmem_tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_EPILOGUE_CUH_
|
||||
163
custom_ops/gpu_ops/mla_attn/kernel_traits.cuh
Normal file
@@ -0,0 +1,163 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
#ifndef ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
|
||||
#define ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename MainloopPipeline, typename MainloopPipelineQ, class DTypeQ, class DTypeKV, class DTypeQKAccum, class DTypeOut, class IdType,
|
||||
int BLOCK_SHAPE_KV, class SmemLayoutQ, class SmemLayoutK, class SmemLayoutP, class SmemLayoutRow, class SmemLayoutO>
|
||||
struct alignas(16) SharedStorageQKVO {
|
||||
alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutQ>> smem_q;
|
||||
alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutP>> smem_p;
|
||||
alignas(16) cute::array_aligned<DTypeQKAccum, cute::cosize_v<SmemLayoutRow>> smem_scale;
|
||||
union {
|
||||
alignas(16) cute::array_aligned<DTypeKV, cute::cosize_v<SmemLayoutK>> smem_kv;
|
||||
alignas(16) cute::array_aligned<DTypeOut, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
};
|
||||
struct {
|
||||
alignas(16) typename MainloopPipelineQ::SharedStorage pipeline_q;
|
||||
alignas(16) typename MainloopPipeline::SharedStorage pipeline_kv;
|
||||
};
|
||||
};
|
||||
|
||||
template <bool USE_TMA_LOAD_KV_, int HEAD_DIM_QK_, int HEAD_DIM_VO_, int GROUP_SIZE_, int BLOCK_SHAPE_Q_, int BLOCK_SHAPE_KV_,
|
||||
int NUM_STAGES_, typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_, typename NV_TYPE_>
|
||||
struct AttentionKernelTraits {
|
||||
|
||||
using DTypeQ = DTypeQ_;
|
||||
using DTypeKV = DTypeKV_;
|
||||
using DTypeO = DTypeO_;
|
||||
using IdType = IdType_;
|
||||
using DTypeQKAccum = float;
|
||||
using DTypePVAccum = float;
|
||||
using NV_TYPE = NV_TYPE_;
|
||||
|
||||
|
||||
static constexpr bool USE_TMA_LOAD_KV = USE_TMA_LOAD_KV_;
|
||||
static constexpr int GROUP_SIZE = GROUP_SIZE_;
|
||||
static constexpr int BLOCK_SHAPE_Q = BLOCK_SHAPE_Q_;
|
||||
static_assert(BLOCK_SHAPE_Q % 64 == 0);
|
||||
static constexpr int BLOCK_SHAPE_KV = BLOCK_SHAPE_KV_;
|
||||
static constexpr int HEAD_DIM_QK = HEAD_DIM_QK_;
|
||||
static constexpr int HEAD_DIM_VO = HEAD_DIM_VO_;
|
||||
static constexpr int NUM_PER_STAGE = BLOCK_SHAPE_KV * HEAD_DIM_QK;
|
||||
static_assert(HEAD_DIM_QK % 32 == 0);
|
||||
static_assert(HEAD_DIM_VO % 32 == 0);
|
||||
|
||||
static constexpr int NUM_WARPS = 12;
|
||||
static constexpr int NUM_THREADS = 384;
|
||||
static constexpr int NUM_PRODUCER_THREADS = 128;
|
||||
|
||||
using TileShape_QKD = Shape<Int<BLOCK_SHAPE_Q>, Int<BLOCK_SHAPE_KV>, Int<HEAD_DIM_QK>>;
|
||||
using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;
|
||||
|
||||
static constexpr int NUM_STAGES = NUM_STAGES_;
|
||||
|
||||
using AtomLayoutQKD = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _1, _1>>;
|
||||
using AtomLayoutPV = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _2, _1>>;
|
||||
using TiledMmaQK = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<DTypeQ, DTypeKV, DTypeQKAccum, TileShape_QKD>(), AtomLayoutQKD{}));
|
||||
using TiledMmaPV = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
|
||||
GMMA::Major::K, GMMA::Major::MN>(),
|
||||
AtomLayoutPV{}));
|
||||
using TiledMmaPVSS = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
|
||||
GMMA::Major::K, GMMA::Major::MN>(),
|
||||
AtomLayoutPV{}));
|
||||
|
||||
static constexpr int NUM_MMA_THREADS = size(TiledMmaPV{});
|
||||
|
||||
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
|
||||
decltype(cute::get<2>(TileShape_QKD{}))>());
|
||||
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_QKD{})));
|
||||
|
||||
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})),
|
||||
decltype(cute::get<2>(TileShape_QKD{}))>());
|
||||
using SmemLayoutK = decltype(tile_to_shape(
|
||||
SmemLayoutAtomK{},
|
||||
make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int<NUM_STAGES>{})));
|
||||
using SmemLayoutVt = decltype(composition(
|
||||
SmemLayoutK{}, make_ordered_layout(make_shape(get<2>(TileShape_QKD{}),
|
||||
get<1>(TileShape_QKD{}), Int<NUM_STAGES>{}),
|
||||
Step<_2, _1, _3>{})));
|
||||
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeKV, decltype(cute::get<2>(TileShape_PDV{})),
|
||||
decltype(cute::get<1>(TileShape_PDV{}))>());
|
||||
using SmemLayoutV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomV{},
|
||||
make_shape(get<2>(TileShape_PDV{}), get<1>(TileShape_PDV{}), Int<1>{})));
|
||||
|
||||
// Note this is the transpose in terms of the view, not in terms of memory.
|
||||
using SmemLayoutVtOneStage = decltype(composition(
|
||||
SmemLayoutV{}, make_ordered_layout(make_shape(get<1>(TileShape_PDV{}),
|
||||
get<2>(TileShape_PDV{}), Int<1>{}),
|
||||
Step<_2, _1, _3>{})));
|
||||
|
||||
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
|
||||
decltype(cute::get<1>(TileShape_PDV{}))>());
|
||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
|
||||
|
||||
using SmemCopyAtom = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeQ>;
|
||||
|
||||
static constexpr bool IS_CTA_32 = (BLOCK_SHAPE_KV == 32);
|
||||
using SmemLayoutRowOneStage = Layout<Shape<_2, Int<128>>, Stride<_1, _2>>;
|
||||
using SmemLayoutRowTwoStage = Layout<Shape<_2, Int<128>, _2>, Stride<_1, _2, _256>>;
|
||||
using SmemLayoutRow = std::conditional_t<IS_CTA_32, SmemLayoutRowTwoStage, SmemLayoutRowOneStage>;
|
||||
|
||||
using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
|
||||
decltype(cute::get<1>(TileShape_QKD{}))>());
|
||||
using SmemLayoutPSSOneStage = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_QKD{})));
|
||||
using SmemLayoutPSSTwoStage = decltype(tile_to_shape(SmemLayoutAtomP{}, make_shape(Int<BLOCK_SHAPE_Q>{}, Int<BLOCK_SHAPE_KV>{}, Int<2>{})));
|
||||
using SmemLayoutP = std::conditional_t<IS_CTA_32, SmemLayoutPSSTwoStage, SmemLayoutPSSOneStage>;
|
||||
|
||||
using MainloopPipelineQ = typename cutlass::PipelineAsync<1>;
|
||||
using PipelineStateQ = typename cutlass::PipelineState<1>;
|
||||
using MainloopPipeline =
|
||||
std::conditional_t<USE_TMA_LOAD_KV, typename cutlass::PipelineTmaAsync<NUM_STAGES>,
|
||||
typename cutlass::PipelineAsync<NUM_STAGES>>;
|
||||
using PipelineState = typename cutlass::PipelineState<NUM_STAGES>;
|
||||
|
||||
using SharedStorage = SharedStorageQKVO<MainloopPipeline, MainloopPipelineQ, DTypeQ, DTypeKV, DTypeQKAccum, DTypeO, IdType, BLOCK_SHAPE_KV,
|
||||
SmemLayoutQ, SmemLayoutK, SmemLayoutP, SmemLayoutRow, SmemLayoutO>;
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif
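
The SharedStorageQKVO layout above unions the KV staging buffer with the output tile, so the per-CTA shared-memory footprint is roughly max(KV, O) plus the always-resident Q, P and scale buffers. The host-side estimate below is a rough sketch under assumed MLA-like sizes; every constant (bf16 element size, BLOCK_SHAPE_Q = 64, BLOCK_SHAPE_KV = 64, HEAD_DIM_QK = 576, HEAD_DIM_VO = 512, two KV stages) is an illustrative assumption, not a value read from the traits.

#include <cstdio>

// Rough shared-memory budget mirroring the union in SharedStorageQKVO:
// KV staging and the O tile reuse the same bytes.
constexpr int kElem  = 2;                    // sizeof(bf16), assumed
constexpr int kQ     = 64 * 576 * kElem;     // smem_q
constexpr int kP     = 64 * 64 * kElem;      // smem_p
constexpr int kKV    = 2 * 64 * 576 * kElem; // smem_kv, 2 stages assumed
constexpr int kO     = 64 * 512 * kElem;     // smem_o
constexpr int kScale = 2 * 128 * 4;          // smem_scale (float rows)

int main() {
  const int unioned = kKV > kO ? kKV : kO;   // KV and O share storage
  std::printf("approx. dynamic smem: %d bytes\n", kQ + kP + kScale + unioned);
  return 0;
}
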
348
custom_ops/gpu_ops/mla_attn/mainloop_load.cuh
Normal file
@@ -0,0 +1,348 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
|
||||
#define ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
|
||||
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "named_barrier.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
#ifdef DEBUG_MLA
|
||||
#undef DEBUG_MLA
|
||||
#endif
|
||||
// #define DEBUG_MLA
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename Ktraits, bool CAUSAL>
|
||||
struct CollectiveMainloop {
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeMD = float;
|
||||
using IdType = typename Ktraits::IdType;
|
||||
using TileShape_QKD = typename Ktraits::TileShape_QKD;
|
||||
using TileShape_PDV = typename Ktraits::TileShape_PDV;
|
||||
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
|
||||
static constexpr int NUM_STAGES = Ktraits::NUM_STAGES;
|
||||
static constexpr int HEAD_DIM_QK = Ktraits::HEAD_DIM_QK;
|
||||
static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
|
||||
|
||||
using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(DTypeQ); // 8
|
||||
static_assert(HEAD_DIM_QK % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); // e.g. HEAD_DIM_QK = 576, HEAD_DIM_VO = 512 for MLA
|
||||
static constexpr int kGmemThreadsPerRow = 64 / kGmemElemsPerLoad; // 8
|
||||
using AlignmentTypeQ = cute::uint_byte_t<static_cast<int>(sizeof(DTypeQ)) * kGmemElemsPerLoad>;
|
||||
using GmemCopyAtomQ = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<AlignmentTypeQ>, DTypeQ>;
|
||||
static constexpr int kNThreadsLoad = Ktraits::NUM_PRODUCER_THREADS;
|
||||
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
using GmemLayoutAtom = Layout<
|
||||
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // (16, 8) with 128 producer threads
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
using GmemTiledCopy = decltype(make_tiled_copy(
|
||||
GmemCopyAtomQ{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
|
||||
using GmemLayoutAtomQ = Layout<
|
||||
Shape<Int<Ktraits::NUM_PRODUCER_THREADS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // (16, 8) with 128 producer threads
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
using GmemTiledCopyQ = decltype(make_tiled_copy(
|
||||
GmemCopyAtomQ{},
|
||||
GmemLayoutAtomQ{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
|
||||
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
|
||||
using SmemLayoutAtomQ = typename Ktraits::SmemLayoutAtomQ;
|
||||
|
||||
using SmemLayoutK = typename Ktraits::SmemLayoutK;
|
||||
using SmemLayoutV = typename Ktraits::SmemLayoutV;
|
||||
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
|
||||
|
||||
using ShapeQT = cute::Shape<int32_t, int32_t>;
|
||||
using StrideQT = cute::Shape<int32_t, _1>;
|
||||
using LayoutQT = cute::Layout<ShapeQT, StrideQT>;
|
||||
|
||||
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
|
||||
using StrideT = cute::Shape<int32_t, _1, int32_t>;
|
||||
using LayoutT = cute::Layout<ShapeT, StrideT>;
|
||||
|
||||
using ShapeMDT = cute::Shape<int32_t, int32_t>;
|
||||
using StrideMDT = cute::Shape<int32_t, _1>;
|
||||
using LayoutMDT = cute::Layout<ShapeMDT, StrideMDT>;
|
||||
|
||||
using TMA_KV = decltype(make_tma_copy(
|
||||
GmemTiledCopyKV{},
|
||||
make_tensor(
|
||||
make_gmem_ptr(static_cast<DTypeKV const*>(nullptr)),
|
||||
repeat_like(StrideT{}, int32_t(0)), StrideT{}
|
||||
),
|
||||
take<0, 2>(SmemLayoutK{}),
|
||||
select<1, 2>(TileShape_QKD{}),
|
||||
_1{})); // no mcast for KV
|
||||
|
||||
static constexpr bool USE_TMA_LOAD_KV = Ktraits::USE_TMA_LOAD_KV;
|
||||
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||
|
||||
using MainloopPipelineQ = typename Ktraits::MainloopPipelineQ;
|
||||
using PipelineParamsQ = typename MainloopPipelineQ::Params;
|
||||
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
|
||||
|
||||
static constexpr uint32_t TmaTransactionBytesQ =
|
||||
static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<DTypeQ> / 8);
|
||||
static constexpr uint32_t TmaTransactionBytesKV =
|
||||
static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<DTypeKV> / 8);
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
LayoutQT layout_Q;
|
||||
LayoutT layout_KV;
|
||||
LayoutMDT layout_MD;
|
||||
DTypeQ const* Q_ptr;
|
||||
DTypeKV const* KV_ptr;
|
||||
DTypeMD const* m_ptr;
|
||||
DTypeMD const* d_ptr;
|
||||
IdType const* kv_block_tables;
|
||||
IdType const* seq_lens_this_time;
|
||||
IdType const* seq_lens_encoder;
|
||||
IdType const* seq_lens_decoder;
|
||||
IdType const* cumsum_q_seqlens;
|
||||
IdType const* batch_ids;
|
||||
IdType const* tile_ids_per_batch;
|
||||
IdType const* num_blocks_x;
|
||||
float sm_scale;
|
||||
int bsz;
|
||||
int max_block_num;
|
||||
int max_block_num_per_seq;
|
||||
int q_stride_bsz;
|
||||
int q_stride_head_num;
|
||||
int kv_stride_block_num;
|
||||
int kv_stride_block_size;
|
||||
int o_stride_bsz;
|
||||
int o_stride_head_num;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int max_draft_token_num;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
LayoutQT layout_Q;
|
||||
LayoutT layout_KV;
|
||||
LayoutMDT layout_MD;
|
||||
DTypeQ *Q_ptr;
|
||||
DTypeKV* KV_ptr;
|
||||
DTypeMD* m_ptr;
|
||||
DTypeMD* d_ptr;
|
||||
IdType* kv_block_tables;
|
||||
IdType* seq_lens_this_time;
|
||||
IdType* seq_lens_encoder;
|
||||
IdType* seq_lens_decoder;
|
||||
IdType* cumsum_q_seqlens;
|
||||
IdType* batch_ids;
|
||||
IdType* tile_ids_per_batch;
|
||||
IdType* num_blocks_x;
|
||||
float sm_scale;
|
||||
int bsz;
|
||||
int max_block_num;
|
||||
int max_block_num_per_seq;
|
||||
int q_stride_bsz;
|
||||
int q_stride_head_num;
|
||||
int kv_stride_block_num;
|
||||
int kv_stride_block_size;
|
||||
int o_stride_bsz;
|
||||
int o_stride_head_num;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int max_draft_token_num;
|
||||
TMA_KV tma_load_KV;
|
||||
};
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args) {
|
||||
TMA_KV tma_load_KV;
|
||||
if constexpr (USE_TMA_LOAD_KV) {
|
||||
Tensor mKV = make_tensor(make_gmem_ptr(args.KV_ptr), args.layout_KV);
|
||||
tma_load_KV =
|
||||
make_tma_copy(GmemTiledCopyKV{}, mKV, SmemLayoutK{}(_, _, _0{}), select<1, 2>(TileShape_QKD{}), _1{});
|
||||
}
|
||||
return {args.layout_Q,
|
||||
args.layout_KV,
|
||||
args.layout_MD,
|
||||
const_cast<DTypeQ*>(args.Q_ptr),
|
||||
const_cast<DTypeKV*>(args.KV_ptr),
|
||||
const_cast<DTypeMD*>(args.m_ptr),
|
||||
const_cast<DTypeMD*>(args.d_ptr),
|
||||
const_cast<IdType*>(args.kv_block_tables),
|
||||
const_cast<IdType*>(args.seq_lens_this_time),
|
||||
const_cast<IdType*>(args.seq_lens_encoder),
|
||||
const_cast<IdType*>(args.seq_lens_decoder),
|
||||
const_cast<IdType*>(args.cumsum_q_seqlens),
|
||||
const_cast<IdType*>(args.batch_ids),
|
||||
const_cast<IdType*>(args.tile_ids_per_batch),
|
||||
const_cast<IdType*>(args.num_blocks_x),
|
||||
args.sm_scale,
|
||||
args.bsz,
|
||||
args.max_block_num,
|
||||
args.max_block_num_per_seq,
|
||||
args.q_stride_bsz,
|
||||
args.q_stride_head_num,
|
||||
args.kv_stride_block_num,
|
||||
args.kv_stride_block_size,
|
||||
args.o_stride_bsz,
|
||||
args.o_stride_head_num,
|
||||
args.chunk_size,
|
||||
args.chunk_num,
|
||||
args.max_draft_token_num,
|
||||
tma_load_KV
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
||||
if constexpr (USE_TMA_LOAD_KV) {
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_KV.get_tma_descriptor());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SharedStorage>
|
||||
CUTLASS_DEVICE void load_q(Params const& mainloop_params,
|
||||
MainloopPipelineQ pipeline_q,
|
||||
PipelineStateQ& smem_pipe_write_q,
|
||||
SharedStorage& shared_storage,
|
||||
const int thread_idx,
|
||||
const int bid) {
|
||||
int start_q_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
int offset_Q = mainloop_params.q_stride_bsz * start_q_token_idx;
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(mainloop_params.Q_ptr + offset_Q), mainloop_params.layout_Q);
|
||||
Tensor gQ =
|
||||
local_tile(mQ, select<0, 2>(TileShape_QKD{}), make_coord(_, _0{}))(_, _, _0{});
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor cQ = cute::make_identity_tensor(gQ.shape());
|
||||
|
||||
GmemTiledCopyQ gmem_tiled_copy_q;
|
||||
auto gmem_thr_copy_q = gmem_tiled_copy_q.get_slice(thread_idx);
|
||||
Tensor tQgQ = gmem_thr_copy_q.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_q.partition_D(sQ);
|
||||
Tensor tQcQ = gmem_thr_copy_q.partition_D(cQ);
|
||||
Tensor tQcQGroup = flatten_1(tQcQ);
|
||||
|
||||
int valid_q_size = mainloop_params.seq_lens_this_time[bid];
|
||||
auto q_predicate_fn = [&](auto coords) {
|
||||
auto s_coords = tQcQGroup(_0{}, coords);
|
||||
return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, valid_q_size);
|
||||
};
|
||||
Tensor tQgQiGroup = flatten_1(tQgQ);
|
||||
Tensor tQsQiGroup = flatten_1(tQsQ);
|
||||
|
||||
pipeline_q.producer_acquire(smem_pipe_write_q);
|
||||
copy_if(gmem_tiled_copy_q, q_predicate_fn, tQgQiGroup, tQsQiGroup);
|
||||
pipeline_q.producer_commit(smem_pipe_write_q, cutlass::arch::cpasync_barrier_arrive);
|
||||
++smem_pipe_write_q;
|
||||
}
|
||||
|
||||
template <typename SharedStorage>
|
||||
CUTLASS_DEVICE void load_kv(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_kv,
|
||||
PipelineState& smem_pipe_write_kv,
|
||||
SharedStorage& shared_storage,
|
||||
const int bid,
|
||||
const int kv_len,
|
||||
const int tile_idx) {
|
||||
int thread_idx = threadIdx.x;
|
||||
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0);
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
|
||||
Tensor mKV = make_tensor(make_gmem_ptr(mainloop_params.KV_ptr), mainloop_params.layout_KV);
|
||||
Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
|
||||
GmemTiledCopy gmem_tiled_copy_kv;
|
||||
auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
|
||||
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
|
||||
|
||||
Tensor tKgK = gmem_thr_copy_kv.partition_S(gKV);
|
||||
Tensor tKsK = gmem_thr_copy_kv.partition_S(sK);
|
||||
|
||||
for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
|
||||
const int block_idx = kv_block_tables(bid, kv_tile_idx);
|
||||
pipeline_kv.producer_acquire(smem_pipe_write_kv);
|
||||
Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, block_idx));
|
||||
Tensor tKsKiGroup =
|
||||
flatten_1(tKsK(_, _, _, smem_pipe_write_kv.index()));
|
||||
copy(gmem_tiled_copy_kv, tKgKiGroup, tKsKiGroup);
|
||||
pipeline_kv.producer_commit(smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive);
|
||||
++smem_pipe_write_kv;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SharedStorage>
|
||||
CUTLASS_DEVICE void load_kv_tma(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_kv,
|
||||
PipelineState& smem_pipe_write_kv,
|
||||
SharedStorage& shared_storage,
|
||||
const int bid,
|
||||
const int kv_len,
|
||||
const int tile_idx) {
|
||||
int thread_idx = threadIdx.x;
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
|
||||
|
||||
Tensor mKV = mainloop_params.tma_load_KV.get_tma_tensor(mainloop_params.layout_KV.shape());
|
||||
|
||||
// Prepare the TMA loads
|
||||
Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
|
||||
auto [tKgK, tKsK] =
|
||||
tma_partition(mainloop_params.tma_load_KV, _0{}, Layout<_1>{},
|
||||
group_modes<0, 2>(sK), group_modes<0, 2>(gKV));
|
||||
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
|
||||
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
|
||||
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
#pragma unroll 2
|
||||
for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
|
||||
const int block_idx = kv_block_tables(bid, kv_tile_idx);
|
||||
pipeline_kv.producer_acquire(smem_pipe_write_kv);
|
||||
copy(mainloop_params.tma_load_KV.with(*pipeline_kv.producer_get_barrier(smem_pipe_write_kv), /*mcast_mask=*/0),
|
||||
tKgK(_, block_idx), tKsK(_, smem_pipe_write_kv.index()));
|
||||
++smem_pipe_write_kv;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif  // ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
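
load_kv and load_kv_tma walk the paged KV cache by turning the chunk assigned to a CTA into a range of BLOCK_SHAPE_KV-sized tiles and resolving each tile index to a physical block through the per-batch block table. A host-side sketch of that index arithmetic; the chunk size, tile size, kv_len and the block table contents below are made-up values for illustration.

#include <algorithm>
#include <cstdio>
#include <vector>

constexpr int kBlockShapeKV = 64;   // assumed KV tile size
constexpr int kChunkSize    = 256;  // assumed chunk size

int main() {
  const int kv_len   = 700;  // decoded tokens so far (illustrative)
  const int tile_idx = 1;    // which chunk this CTA handles
  const std::vector<int> block_table = {7, 3, 9, 11, 2, 5, 8, 1, 4, 6, 10, 0};

  const int start_len      = tile_idx * kChunkSize;
  const int start_tile_idx = start_len / kBlockShapeKV;
  const int end_tile_idx =
      (std::min(start_len + kChunkSize, kv_len) + kBlockShapeKV - 1) / kBlockShapeKV - 1;

  // The kernel iterates the tiles in reverse order, fetching each physical block.
  for (int t = end_tile_idx; t >= start_tile_idx; --t) {
    std::printf("kv tile %d -> physical block %d\n", t, block_table[t]);
  }
  return 0;
}
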
500
custom_ops/gpu_ops/mla_attn/mainloop_mma.cuh
Normal file
@@ -0,0 +1,500 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
|
||||
#define ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
|
||||
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include "named_barrier.cuh"
|
||||
|
||||
// #define DEBUG_MLA
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
|
||||
typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
|
||||
CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
|
||||
MainloopPipelineQ pipeline_q,
|
||||
PipelineStateQ& smem_pipe_read_q,
|
||||
MainloopPipeline pipeline_kv,
|
||||
PipelineState& smem_pipe_read_kv,
|
||||
FrgTensorO& tOrO,
|
||||
AttentionUpdater& attention_updater,
|
||||
const int thread_idx,
|
||||
const int bid,
|
||||
const int kv_len,
|
||||
const int qo_len,
|
||||
const int tile_idx,
|
||||
SharedStorage& shared_storage) {
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeMD = typename Ktraits::DTypeO;
|
||||
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
|
||||
using IdType = typename Ktraits::IdType;
|
||||
using TileShape_QKD = typename Ktraits::TileShape_QKD;
|
||||
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
|
||||
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
|
||||
using SmemLayoutK = typename Ktraits::SmemLayoutK;
|
||||
using SmemLayoutV = typename Ktraits::SmemLayoutV;
|
||||
using SmemLayoutP = typename Ktraits::SmemLayoutP;
|
||||
using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
|
||||
using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
|
||||
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
|
||||
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
|
||||
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
|
||||
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
|
||||
Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
|
||||
Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
|
||||
Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
|
||||
Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
|
||||
Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _); // (bsz * draft_token_num * num_head)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);
|
||||
|
||||
typename Ktraits::TiledMmaQK tiled_mma_qk;
|
||||
auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
|
||||
auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
|
||||
auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
|
||||
Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
|
||||
Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup);
|
||||
|
||||
typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
|
||||
auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
|
||||
Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
|
||||
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
|
||||
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
|
||||
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
int kv_tile_idx = end_tile_idx;
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
};
|
||||
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
if (warp_group_idx == 1) {
|
||||
// consumer 0, compute qk
|
||||
Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
|
||||
Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
|
||||
|
||||
constexpr int n_masking_steps = !CAUSAL ? 1 : cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) + 1;
|
||||
auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
|
||||
bool is_first_step = true;
|
||||
// wait q
|
||||
consumer_wait(pipeline_q, smem_pipe_read_q);
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
|
||||
#pragma unroll 1
|
||||
for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
|
||||
// wait kv
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
// gemm qk
|
||||
gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
|
||||
tSrS);
|
||||
// mask
|
||||
if (masking_step > 0) {
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
|
||||
Tensor tScS = threadMmaQK.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
|
||||
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
|
||||
if constexpr (!CAUSAL) { // Just masking based on col
|
||||
if (kv_idx >= kv_len) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
} else {
|
||||
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// update s (exp(s - m))
|
||||
Tensor scale_o = is_first_step ? attention_updater.update</*init=*/true>(tSrS) : attention_updater.update</*init=*/false>(tSrS);
|
||||
is_first_step = false;
|
||||
|
||||
Tensor convert_tSrS = convert_type<DTypeKV>(tSrS);
|
||||
Tensor tPrP = smem_thr_copy_P.retile_S(convert_tSrS);
|
||||
|
||||
// gather qk gemm res
|
||||
cute::copy(smem_tiled_copy_P, tPrP, tPsP);
|
||||
cute::copy(scale_o, tScalesScale);
|
||||
// r2s fence wgmma
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
// make sure r2s all done
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
|
||||
// pv gemm
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
// sync WG1 WG2
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
|
||||
}
|
||||
// release q
|
||||
pipeline_q.consumer_release(smem_pipe_read_q);
|
||||
++smem_pipe_read_q;
|
||||
|
||||
// normalize
|
||||
Tensor scale_o = attention_updater.finalize(tSrS); // warp reduce row sum
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cute::copy(scale_o, tScalesScale);
|
||||
|
||||
cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
}
|
||||
|
||||
// WG1 write m,d back to gmem
|
||||
if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // 16 rows per warp, e.g. t0 -> rows 0 and 8, t4 -> rows 1 and 9
|
||||
const int warp_idx = thread_idx / 32;
|
||||
#pragma unroll
|
||||
for (int w_i = 0; w_i < 2; ++w_i) {
|
||||
const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
|
||||
const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;
|
||||
|
||||
if (token_idx < qo_len) {
|
||||
const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
|
||||
const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
|
||||
const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
|
||||
mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
|
||||
mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (warp_group_idx == 2) {
|
||||
// consumer 1, compute pv
|
||||
Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
|
||||
for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
|
||||
// wait kv
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
|
||||
// A: tPsP
|
||||
cute::copy(tScalesScale, scale_o);
|
||||
|
||||
// rescale
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
// sync WG1 WG2
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
|
||||
}
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
|
||||
cute::copy(tScalesScale, scale_o);
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
|
||||
typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
|
||||
CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
|
||||
MainloopPipelineQ pipeline_q,
|
||||
PipelineStateQ& smem_pipe_read_q,
|
||||
MainloopPipeline pipeline_kv,
|
||||
PipelineState& smem_pipe_read_kv,
|
||||
FrgTensorO& tOrO,
|
||||
AttentionUpdater& attention_updater,
|
||||
const int thread_idx,
|
||||
const int bid,
|
||||
const int kv_len,
|
||||
const int qo_len,
|
||||
const int tile_idx,
|
||||
SharedStorage& shared_storage) {
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeMD = typename Ktraits::DTypeO; // note: m/d are written back in the output dtype (bf16)
|
||||
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
|
||||
using IdType = typename Ktraits::IdType;
|
||||
using TileShape_QKD = typename Ktraits::TileShape_QKD;
|
||||
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
|
||||
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
|
||||
using SmemLayoutK = typename Ktraits::SmemLayoutK;
|
||||
using SmemLayoutV = typename Ktraits::SmemLayoutV;
|
||||
using SmemLayoutP = typename Ktraits::SmemLayoutP;
|
||||
using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
|
||||
using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
|
||||
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
|
||||
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
|
||||
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
|
||||
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
|
||||
Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
|
||||
Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
|
||||
Tensor sVt_s3 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 2 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
|
||||
Tensor sVt_s4 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 3 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
|
||||
Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
|
||||
Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _);
|
||||
Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);
|
||||
|
||||
Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
|
||||
|
||||
typename Ktraits::TiledMmaQK tiled_mma_qk;
|
||||
auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
|
||||
auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
|
||||
auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
|
||||
Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
|
||||
Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup, _);
|
||||
|
||||
typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
|
||||
auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
|
||||
Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
|
||||
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
|
||||
Tensor tOrV3 = threadMmaPVSS.partition_fragment_B(sVt_s3);
|
||||
Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4);
|
||||
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
|
||||
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
int kv_tile_idx = end_tile_idx;
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
};
|
||||
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
if (warp_group_idx == 1) {
|
||||
// consumer 0, compute qk
|
||||
Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
|
||||
Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
|
||||
auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
|
||||
// wait q
|
||||
consumer_wait(pipeline_q, smem_pipe_read_q);
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
|
||||
// wait k
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
// first qk gemm
|
||||
gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
|
||||
tSrS);
|
||||
// mask
|
||||
{
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
|
||||
Tensor tScS = threadMmaQK.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
|
||||
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
|
||||
if constexpr (!CAUSAL) { // Just masking based on col
|
||||
if (kv_idx >= kv_len) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
} else {
|
||||
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Tensor scale_o = attention_updater.update</*init=*/true>(tSrS);
|
||||
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
|
||||
// gather qk gemm res
|
||||
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
|
||||
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
|
||||
// r2s fence wgmma
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
|
||||
constexpr int n_masking_steps = CAUSAL ? cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) : 0;
|
||||
--kv_tile_idx;
|
||||
for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
|
||||
PipelineState smem_pipe_read_kv_cur = smem_pipe_read_kv;
|
||||
++smem_pipe_read_kv;
|
||||
// wait next kv
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
|
||||
// gemm next qk
|
||||
gemm</*init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
|
||||
tSrS);
|
||||
attention_updater.rescale_o(tOrO);
|
||||
// last pv gemm
|
||||
if (smem_pipe_read_kv_cur.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv_cur.index() == 1) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv_cur.index() == 2) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV3(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV4(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
// wait cur qk gemm
|
||||
warpgroup_wait<1>();
|
||||
// mask p
|
||||
if (masking_step > 0) {
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
|
||||
Tensor tScS = threadMmaQK.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
|
||||
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
|
||||
if constexpr (!CAUSAL) { // Just masking based on col
|
||||
if (kv_idx >= kv_len) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
} else {
|
||||
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// update s (exp(s - m))
|
||||
Tensor scale_o = attention_updater.update</*init=*/false>(tSrS);
|
||||
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
|
||||
|
||||
// gather qk gemm res
|
||||
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
|
||||
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
|
||||
// r2s fence wgmma
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
// make sure tSrS r2s done
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
// wait last pv gemm
|
||||
warpgroup_wait<0>();
|
||||
// release last kv
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv_cur);
|
||||
}
|
||||
// release q
|
||||
pipeline_q.consumer_release(smem_pipe_read_q);
|
||||
++smem_pipe_read_q;
|
||||
// compute last pv
|
||||
attention_updater.rescale_o(tOrO);
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 1) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 2) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV3(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV4(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
scale_o = attention_updater.finalize(tSrS);
|
||||
warpgroup_wait<0>();
|
||||
// release last kv
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
|
||||
|
||||
cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
|
||||
attention_updater.rescale_o(tOrO);
|
||||
}
|
||||
// WG1 write m,d back to gmem
|
||||
if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // 16 rows per warp, e.g. t0 -> rows 0 and 8, t4 -> rows 1 and 9
|
||||
const int warp_idx = thread_idx / 32;
|
||||
#pragma unroll
|
||||
for (int w_i = 0; w_i < 2; ++w_i) {
|
||||
const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
|
||||
const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;
|
||||
|
||||
if (token_idx < qo_len) {
|
||||
const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
|
||||
const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
|
||||
const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
|
||||
mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
|
||||
mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (warp_group_idx == 2) {
|
||||
// consumer 1, compute pv
|
||||
Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
|
||||
for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
// A: tPsP
|
||||
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
|
||||
// rescale
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 1) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 2) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV3(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV4(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
}
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
|
||||
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
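
When a sequence spans more than one chunk, each chunk leaves its partial output un-normalized and writes back its row max m and row sum d; a later pass can then combine chunks with the same rescale-by-exp(m_i - m_new) rule the online softmax applies between KV tiles. The scalar C++ sketch below illustrates that merge rule only; all values are made up, and the actual cross-chunk reduction is handled elsewhere in the pipeline.

#include <algorithm>
#include <cmath>
#include <cstdio>

struct Partial {
  float o;  // un-normalized output accumulator for one row
  float m;  // row max of the scores seen by this chunk
  float d;  // row sum of exp(score - m) for this chunk
};

// Merge two partial softmax states: rescale each side to the shared max,
// then add the accumulators and the denominators.
Partial merge(Partial a, Partial b) {
  const float m = std::max(a.m, b.m);
  const float sa = std::exp(a.m - m), sb = std::exp(b.m - m);
  return {a.o * sa + b.o * sb, m, a.d * sa + b.d * sb};
}

int main() {
  Partial c0{4.0f, 1.5f, 2.0f};  // chunk 0 (illustrative)
  Partial c1{9.0f, 2.5f, 3.0f};  // chunk 1 (illustrative)
  Partial merged = merge(c0, c1);
  std::printf("normalized output: %f\n", merged.o / merged.d);  // final divide by d
  return 0;
}
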
575
custom_ops/gpu_ops/mla_attn/mla_hopper.cuh
Normal file
@@ -0,0 +1,575 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
#ifndef ATTENTION_HOPPER_PREFILL_SM90_CUH_
|
||||
#define ATTENTION_HOPPER_PREFILL_SM90_CUH_
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_device_runtime_api.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "attention_updater.cuh"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "epilogue.cuh"
|
||||
#include "helper.h"
|
||||
#include "kernel_traits.cuh"
|
||||
#include "mainloop_mma.cuh"
|
||||
#include "mainloop_load.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
#ifdef DEBUG_MLA
|
||||
#undef DEBUG_MLA
|
||||
#endif
|
||||
// #define DEBUG_MLA
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
|
||||
struct Params {
|
||||
using DTypeQ = DTypeQ_;
|
||||
using DTypeKV = DTypeKV_;
|
||||
using DTypeO = DTypeO_;
|
||||
using IdType = IdType_;
|
||||
|
||||
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
|
||||
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head]
|
||||
alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num]
|
||||
|
||||
alignas(16) IdType *block_tables;
|
||||
alignas(16) IdType *seq_lens_this_time;
|
||||
alignas(16) IdType *seq_lens_encoder;
|
||||
alignas(16) IdType *seq_lens_decoder;
|
||||
alignas(16) IdType *cumsum_q_seqlens;
|
||||
alignas(16) IdType *padding_offsets;
|
||||
|
||||
alignas(16) IdType *batch_ids;
|
||||
alignas(16) IdType *tile_ids_per_batch;
|
||||
alignas(16) IdType *num_blocks_x;
|
||||
|
||||
|
||||
uint32_t q_stride_bsz;
|
||||
uint32_t q_stride_head_num;
|
||||
|
||||
uint32_t kv_stride_block_num;
|
||||
uint32_t kv_stride_block_size;
|
||||
|
||||
uint32_t o_stride_bsz;
|
||||
uint32_t o_stride_head_num;
|
||||
|
||||
int bsz;
|
||||
int token_num;
|
||||
int max_seq_len;
|
||||
int max_block_num;
|
||||
int max_block_num_per_seq;
|
||||
int q_num_head;
|
||||
int qk_head_dim;
|
||||
int vo_head_dim;
|
||||
int block_size;
|
||||
int max_draft_token_num;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int num_blocks_x_int;
|
||||
|
||||
float sm_scale;
|
||||
};
|
||||
|
||||
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 16) { \
|
||||
constexpr size_t GROUP_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 64) { \
|
||||
constexpr size_t GROUP_SIZE = 64; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the group_size: ", group_size); \
|
||||
return cudaErrorNotSupported; \
|
||||
}
|
||||
|
||||
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
|
||||
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
typename CollectiveMainloop::Params const mainloop_params,
|
||||
CUTE_GRID_CONSTANT
|
||||
typename CollectiveEpilogue::Params const epilogue_params) {
|
||||
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeO = typename Ktraits::DTypeO;
|
||||
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
|
||||
using TileShape_QKD = typename Ktraits::TileShape_QKD;
|
||||
using TileShape_PDV = typename Ktraits::TileShape_PDV;
|
||||
|
||||
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
|
||||
static constexpr int NUM_COPY_THREADS = Ktraits::NUM_PRODUCER_THREADS;
|
||||
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
|
||||
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
|
||||
const int num_blocks_x = mainloop_params.num_blocks_x[0];
|
||||
|
||||
static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
|
||||
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||
|
||||
using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
|
||||
using PipelineParamsQ = typename MainloopPipelineQ::Params;
|
||||
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
|
||||
|
||||
extern __shared__ char shared_memory[];
|
||||
auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
|
||||
|
||||
int const lane_predicate = cute::elect_one_sync();
|
||||
int const warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
|
||||
if (warp_idx == 0 && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
|
||||
}
|
||||
|
||||
// Obtain warp index
|
||||
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineParams pipeline_params;
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
|
||||
: MainloopPipeline::ThreadCategory::Consumer;
|
||||
if constexpr (use_tma_load_kv) {
|
||||
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
pipeline_params.num_consumers = NUM_MMA_THREADS;
|
||||
} else {
|
||||
pipeline_params.producer_arv_count = NUM_COPY_THREADS;
|
||||
pipeline_params.consumer_arv_count = NUM_MMA_THREADS;
|
||||
}
|
||||
|
||||
PipelineParamsQ pipeline_params_q;
|
||||
pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
|
||||
: MainloopPipelineQ::ThreadCategory::Consumer;
|
||||
pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
|
||||
pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup; // only the QK warpgroup consumes Q
|
||||
|
||||
|
||||
MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
|
||||
MainloopPipeline pipeline_kv = [&] {
|
||||
if constexpr (use_tma_load_kv) {
|
||||
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
|
||||
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
|
||||
/*cluster_shape=*/Shape<_1, _1, _1>{});
|
||||
} else {
|
||||
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
|
||||
}
|
||||
}();
|
||||
__syncthreads();
|
||||
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue;
|
||||
|
||||
if (warp_group_idx == 0) {
|
||||
// producer
|
||||
if constexpr(USE_REG_EALLOC) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<72>();
|
||||
}
|
||||
const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);
|
||||
|
||||
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
|
||||
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
// load Q
|
||||
collective_mainloop.load_q(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_write_q,
|
||||
shared_storage,
|
||||
threadIdx.x,
|
||||
bid);
|
||||
|
||||
if constexpr (!use_tma_load_kv) {
|
||||
// load kv
|
||||
collective_mainloop.load_kv(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
} else {
|
||||
if (warp_idx_in_warpgroup == 0) {
|
||||
// load kv tma
|
||||
collective_mainloop.load_kv_tma(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
// load Q
|
||||
collective_mainloop.load_q(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_write_q,
|
||||
shared_storage,
|
||||
threadIdx.x,
|
||||
bid);
|
||||
|
||||
if constexpr (!use_tma_load_kv) {
|
||||
// load kv
|
||||
collective_mainloop.load_kv(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
} else {
|
||||
if (warp_idx_in_warpgroup == 0) {
|
||||
// load kv tma
|
||||
collective_mainloop.load_kv_tma(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// consumer
|
||||
if constexpr(USE_REG_EALLOC) {
|
||||
cutlass::arch::warpgroup_reg_alloc<216>();
|
||||
}
|
||||
PipelineStateQ smem_pipe_read_q;
|
||||
PipelineState smem_pipe_read_kv;
|
||||
|
||||
typename Ktraits::TiledMmaPVSS tiled_mma_pv;
|
||||
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
|
||||
|
||||
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
if constexpr (BLOCK_SHAPE_KV == 64) {
|
||||
mma_f16<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
} else if (BLOCK_SHAPE_KV == 32) {
|
||||
mma_f16_two_stages<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
}
|
||||
|
||||
collective_epilogue.store(
|
||||
epilogue_params,
|
||||
tOrO,
|
||||
attention_updater.get_lse(),
|
||||
shared_storage,
|
||||
tiled_mma_pv,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
mainloop_params.bsz,
|
||||
seq_len_now,
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
mainloop_params.chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
if constexpr (BLOCK_SHAPE_KV == 64) {
|
||||
mma_f16<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
} else if (BLOCK_SHAPE_KV == 32) {
|
||||
mma_f16_two_stages<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
}
|
||||
|
||||
collective_epilogue.store(
|
||||
epilogue_params,
|
||||
tOrO,
|
||||
attention_updater.get_lse(),
|
||||
shared_storage,
|
||||
tiled_mma_pv,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
mainloop_params.bsz,
|
||||
seq_len_now,
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
mainloop_params.chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
                                                           cudaStream_t stream) {
  using DTypeQ = typename KernelTraits::DTypeQ;
  using DTypeKV = typename KernelTraits::DTypeKV;
  using DTypeO = typename KernelTraits::DTypeO;
  using IdType = typename KernelTraits::IdType;
  using NV_TYPE = typename KernelTraits::NV_TYPE;

  using CollectiveMainloop = CollectiveMainloop<KernelTraits, CAUSAL>;
  using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;

  typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({
      make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})),  // layout q
      make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)),
      make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})),
      params.Q,
      params.KV,
      params.m,
      params.d,
      params.block_tables,
      params.seq_lens_this_time,
      params.seq_lens_encoder,
      params.seq_lens_decoder,
      params.cumsum_q_seqlens,
      params.batch_ids,
      params.tile_ids_per_batch,
      params.num_blocks_x,
      params.sm_scale,
      params.bsz,
      params.max_block_num,
      params.max_block_num_per_seq,
      params.q_stride_bsz,
      params.q_stride_head_num,
      params.kv_stride_block_num,
      params.kv_stride_block_size,
      params.o_stride_bsz,
      params.o_stride_head_num,
      params.chunk_size,
      params.chunk_num,
      params.max_draft_token_num
  });
  typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({
      params.O,
      make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})),  // layout O
      params.O_tmp,
      make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{}))   // layout O_tmp
  });

  // Get the ptr to kernel function.
  auto kernel =
      MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL, 132>;
  int smem_size = sizeof(typename KernelTraits::SharedStorage);
  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  int device;
  cudaGetDevice(&device);
  int multiprocessor_count;
  cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
  int act_blocks_per_sm;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);

  int gridx;
  if constexpr(USE_FIXED_BLOCK) {
    gridx = multiprocessor_count;
  } else {
    gridx = params.num_blocks_x_int;
  }
  dim3 grid_dims = {gridx, 1, 1};
  static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
  dim3 block_dims(ctaSize, 1, 1);
  kernel<<<grid_dims, block_dims, smem_size, stream>>>(
      mainloop_params, epilogue_params
  );
  if (params.chunk_num > 1) {
    constexpr int vec_size = 16 / sizeof(DTypeO);
    constexpr int merge_block_size = 256;
    constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
    constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
    dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head);  // 128k is too large
    dim3 blocks_merge(blockx, blocky);
    merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
        reinterpret_cast<NV_TYPE*>(params.O_tmp),
        params.m,
        params.d,
        params.seq_lens_this_time,
        params.seq_lens_decoder,
        params.seq_lens_encoder,
        params.padding_offsets,
        reinterpret_cast<NV_TYPE*>(params.O),
        params.max_seq_len,
        params.chunk_num,
        params.q_num_head,
        params.chunk_size,
        params.vo_head_dim,
        params.token_num,
        params.bsz,
        params.max_draft_token_num
    );
  }
  return cudaSuccess;
}
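// Editorial sketch (illustrative only; `PickGrid` is a hypothetical helper, not part of this
// file). The launch above sizes its grid in one of two ways: with USE_FIXED_BLOCK the kernel is
// persistent -- one CTA per SM, each CTA striding over work items with `i += SM_COUNT` -- and
// otherwise one CTA is launched per (batch, tile) work item via params.num_blocks_x_int.
inline dim3 PickGrid(bool use_fixed_block, int multiprocessor_count, int num_blocks_x) {
  // Persistent kernel: bounded by the SM count; one-shot kernel: one CTA per work item.
  return dim3(use_fixed_block ? multiprocessor_count : num_blocks_x, 1, 1);
}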
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
  constexpr bool CAUSAL = true;
  if constexpr (HEAD_DIM_QK == 576) {
    DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
      BatchMLAWithPagedKVCacheKernelTraitsDispatched<
          AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
                                HEAD_DIM_QK,
                                HEAD_DIM_VO,
                                GROUP_SIZE,
                                /*BLOCK_SHAPE_Q_=*/64,
                                /*BLOCK_SHAPE_KV_=*/64,
                                /*NUM_STAGES_=*/2,
                                typename Params::DTypeQ,
                                typename Params::DTypeKV,
                                typename Params::DTypeO,
                                typename Params::IdType,
                                NV_TYPE>,
          CAUSAL,
          Params,
          USE_REG_EALLOC,
          USE_FIXED_BLOCK>(params, stream);)
  } else {
    return cudaErrorNotSupported;
  }
  return cudaSuccess;
};

} // namespace mla_attn

#endif // ATTENTION_HOPPER_PREFILL_SM90_CUH_
47
custom_ops/gpu_ops/mla_attn/named_barrier.cuh
Normal file
@@ -0,0 +1,47 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
 * Dao. Licensed under the BSD 3-Clause.
 *
 * Modified by the FlashInfer team.
 */

#ifndef ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
#define ATTENTION_HOPPER_NAMED_BARRIERS_CUH_

#include <cuda_runtime.h>

#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"

namespace mla_attn {

enum class NamedBarriers {
  kQueryEmpty = 0,
  kValueEmpty = 1,
  kWarpSchedulerWG1 = 2,
  kWarpSchedulerWG2 = 3,
  kWarpSchedulerWG3 = 4,
  kPrefetchIndices = 5,
  kOdone = 6,
  kWG1WG2Sync = 7,
  kWG0WG1WG2Sync = 8,
  kWG1WG2LastSync = 9,
};
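// Editorial note: these ids are passed to cutlass::arch::NamedBarrier so that distinct
// producer/consumer groups can rendezvous independently of each other. For example, the MLA
// kernel aligns all three warp groups at the start of each work item with:
//   cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
//                                     /*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));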
} // namespace mla_attn

#endif // ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
351
custom_ops/gpu_ops/mla_attn/utils.cuh
Normal file
@@ -0,0 +1,351 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef ATTENTION_HOPPER_UTILS_CUH_
|
||||
#define ATTENTION_HOPPER_UTILS_CUH_
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include <assert.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/tensor.hpp>
|
||||
#include "cutlass/fast_math.h"
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename TensorT>
|
||||
CUTLASS_HOST_DEVICE auto flatten_1(TensorT tensor) {
|
||||
Tensor tensor_flatten = cute::flatten(tensor);
|
||||
return cute::group_modes<1, rank(tensor_flatten)>(tensor_flatten);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE auto get_gmem_layout(int nnz, int num_heads, int head_dim, int64_t n_stride,
|
||||
int64_t h_stride) {
|
||||
return make_layout(make_shape(nnz, head_dim, num_heads),
|
||||
make_stride(n_stride, cute::_1{}, h_stride));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(int nnz, int num_heads) {
|
||||
return make_layout(make_shape(num_heads, nnz), make_stride(cute::_1{}, int64_t(num_heads)));
|
||||
}
|
||||
|
||||
template <typename MTensor, typename Shape>
|
||||
CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
|
||||
int head_idx, int offset, int seq_len) {
|
||||
auto g_offset = local_tile(m_tensor(_, _, head_idx), cute::make_shape(1, get<1>(tile_shape)),
|
||||
make_coord(offset, _0{}));
|
||||
auto g_sequence =
|
||||
make_tensor(g_offset.data(),
|
||||
make_layout(cute::make_shape(seq_len, get<1>(tile_shape)), g_offset.stride()));
|
||||
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
|
||||
return g_tensor;
|
||||
}
|
||||
|
||||
template <typename MTensor, typename Shape>
|
||||
CUTLASS_DEVICE auto get_lse_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
|
||||
int head_idx, int offset, int seq_len) {
|
||||
auto g_offset = local_tile(m_tensor(head_idx, _), cute::make_shape(_1{}), make_coord(offset));
|
||||
|
||||
auto g_sequence = make_tensor(g_offset.data(), make_layout(cute::make_shape(seq_len),
|
||||
cute::make_shape(shape<0>(m_tensor))));
|
||||
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
|
||||
return g_tensor;
|
||||
}
|
||||
|
||||
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V,
|
||||
// MMA_N))
|
||||
template <typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = acc_layout;
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
|
||||
make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
|
||||
};
|
||||
|
||||
// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16,
|
||||
// MMA_N))
|
||||
template <typename MMA_traits, typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
|
||||
using X = Underscore;
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
|
||||
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{}); // (2, 2, (2, N / 16)))
|
||||
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout),
|
||||
make_layout(get<2, 1>(l), get<2>(acc_layout)));
|
||||
};
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const& tensor) {
  using From_type = typename Engine::value_type;
  constexpr int numel = decltype(size(tensor))::value;
  cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
  auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data()));
  return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

template <bool init = false, int wg_wait = 0, typename TensorA, typename TensorB, typename TensorC,
          typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma& tiled_mma, TensorA const& tCrA, TensorB const& tCrB,
                                     TensorC& tCrC) {
  constexpr bool Is_RS =
      !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
  // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
  if constexpr (Is_RS) {
    warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
  }
  warpgroup_fence_operand(tCrC);
  warpgroup_arrive();
  if constexpr (init) {
    tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    // Unroll the K mode manually to set scale D to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
  } else {
    // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
    // Unroll the K mode manually to set scale D to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
  }
  warpgroup_commit_batch();
  if constexpr (wg_wait >= 0) {
    warpgroup_wait<wg_wait>();
  }
  warpgroup_fence_operand(tCrC);
  if constexpr (Is_RS) {
    warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
  }
}
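// Editorial sketch (illustrative only; the fragment names are placeholders, not identifiers from
// this file). A typical call site issues the first K tile with init=true so the accumulator
// starts from zero, then reuses the helper with init=false for the remaining tiles:
//   gemm</*init=*/true,  /*wg_wait=*/0>(tiled_mma_pv, tCrP, tCrV, tOrO);   // first tile
//   gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv, tCrP, tCrV, tOrO);   // later tiles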
#define HOSTDEVICE __host__ __device__

template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
  T val[Size];

  HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
  HOSTDEVICE inline T& operator[](int i) { return val[i]; }
};

template <typename T, int Size>
HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
  const AlignedVector<T, Size>* addr_vec =
      reinterpret_cast<const AlignedVector<T, Size>*>(addr);
  *vec = *addr_vec;
}

template <typename T, int Size>
HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
  AlignedVector<T, Size>* addr_vec =
      reinterpret_cast<AlignedVector<T, Size>*>(addr);
  *addr_vec = vec;
}
template <size_t vec_size, typename T>
struct prefill_softmax_state_t {
  AlignedVector<T, vec_size> o;
  float m;
  float d;

  __device__ __forceinline__ void init() {
    if constexpr (std::is_same<T, half>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((half2*)(&o) + i) = make_half2(0, 0);
      }
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
      }
    }
    d = 1.f;
    if constexpr (std::is_same<T, half>::value) {
      m = -5e4f;
    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
      m = -3.38953e38f;
    }
  }

  __device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
                                        const float other_m,
                                        const float other_d) {
    float m_prev = m, d_prev = d;
    m = max(m_prev, other_m);
    const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m);
    const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
    d = d_prev * scale1 + other_d * scale2;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      o[i] = o[i] * scale1_T + other_o[i] * scale2_T;
    }
  }

  __device__ __forceinline__ void normalize() {
    const T d_t = static_cast<T>(d);
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      o[i] /= d_t;
    }
  }
};
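// Editorial sketch (illustrative only, not part of the original file): the scalar maths behind
// merge() above. Because the (m, d, o) update is associative, the chunked-decode path and the
// block-wide reduction in the kernel below can apply the same rule in any order.
struct PartialSoftmax {
  float m;  // running max of the logits
  float d;  // running sum of exp(logit - m)
  float o;  // one output lane, carried with the same scaling as d
};
inline PartialSoftmax MergePartialSoftmax(const PartialSoftmax& a, const PartialSoftmax& b) {
  PartialSoftmax r;
  r.m = fmaxf(a.m, b.m);
  const float s1 = expf(a.m - r.m), s2 = expf(b.m - r.m);  // rescale both sides to the new max
  r.d = a.d * s1 + b.d * s2;
  r.o = a.o * s1 + b.o * s2;
  return r;
}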
|
||||
|
||||
template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
|
||||
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
|
||||
const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
|
||||
const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
|
||||
const int * __restrict__ seq_lens_this_time,
|
||||
const int * __restrict__ seq_lens_decoder,
|
||||
const int * __restrict__ seq_lens_encoder,
|
||||
const int * __restrict__ padding_offsets,
|
||||
T * __restrict__ out, // [token_num, num_heads, head_dim]
|
||||
const int max_seq_len,
|
||||
const int num_chunks,
|
||||
const int num_heads,
|
||||
const int chunk_size,
|
||||
const int head_dim,
|
||||
const int token_num,
|
||||
const int bsz,
|
||||
const int max_draft_token_num) {
|
||||
const int vid = threadIdx.x, ty = threadIdx.y;
|
||||
const int hid = blockIdx.y;
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
|
||||
const uint32_t ori_token_id = qid + padding_offsets[qid];
|
||||
const uint32_t bid = ori_token_id / max_seq_len;
|
||||
const int seq_len_q = seq_lens_this_time[bid];
|
||||
if (seq_len_q == 0) continue;
|
||||
const uint32_t local_seq_id = ori_token_id % max_seq_len;
|
||||
int seq_len_kv = seq_lens_decoder[bid];
|
||||
if (seq_len_kv == 0) continue;
|
||||
seq_len_kv += seq_len_q;
|
||||
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
|
||||
if (num_chunks_this_seq <= 1) {
|
||||
// not need merge
|
||||
continue;
|
||||
}
|
||||
|
||||
using LoadT = AlignedVector<T, vec_size>;
|
||||
LoadT load_vec;
|
||||
LoadT res_vec;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&res_vec) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
float m;
|
||||
float d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = -5e4f;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
m = -3.0e+30f;
|
||||
}
|
||||
|
||||
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
|
||||
uint32_t offset;
|
||||
offset = ((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid;
|
||||
float m_prev = m;
|
||||
float d_prev = d;
|
||||
const float m_now = multi_m[offset];
|
||||
const float d_now = multi_d[offset];
|
||||
m = max(m_prev, m_now);
|
||||
offset = (((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid) * head_dim + vid * vec_size;
|
||||
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
||||
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
|
||||
const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
|
||||
d = d * scale1 + d_now * scale2;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < vec_size; j++) {
|
||||
res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T;
|
||||
}
|
||||
}
|
||||
// store ty res
|
||||
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
|
||||
md_smem[2 * ty] = m;
|
||||
md_smem[2 * ty + 1] = d;
|
||||
__syncthreads();
|
||||
if (ty == 0) {
|
||||
// merge bdy
|
||||
prefill_softmax_state_t<vec_size, T> st;
|
||||
st.init();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < bdy; i++) {
|
||||
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
|
||||
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||
st.merge(load_vec, m_tmp, d_tmp);
|
||||
}
|
||||
st.normalize();
|
||||
Store<T, vec_size>(st.o, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_UTILS_CUH_
@@ -1255,8 +1255,6 @@ __global__ void Marlin(
    if constexpr (has_zp && !is_zp_float) {
      if (is_new_zp) {
        if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
        FragB frag_zp_0;
        FragB frag_zp_1;
        int zp_quant_0, zp_quant_1;

        if constexpr (w_type.size_bits() == 4) {
469
custom_ops/gpu_ops/multi_head_latent_attention.cu
Normal file
@@ -0,0 +1,469 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "append_attn/multi_head_latent_attention_kernel.h"
|
||||
#include "mla_attn/batch_mla_with_paged_kv_cache.h"
|
||||
|
||||
template <paddle::DataType D>
|
||||
std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& query_bias,
|
||||
const paddle::optional<paddle::Tensor>& query_out_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_input_length,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
|
||||
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
|
||||
int max_len_kv_data = max_len_kv.data<int>()[0];
|
||||
|
||||
const bool mla_use_tensorcore = get_mla_use_tensorcore();
|
||||
auto sm_version = GetSMVersion();
|
||||
if ((speculate_decoder || mla_use_tensorcore) && sm_version < 90) {
|
||||
PD_THROW("Please use speculate_decoder=0 and FLAGS_mla_use_tensorcore=0 when sm < 90.");
|
||||
}
|
||||
|
||||
auto main_stream = query.stream();
|
||||
|
||||
paddle::Tensor fmha_out = paddle::full(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
|
||||
0,
|
||||
D,
|
||||
query.place());
|
||||
|
||||
if (max_dec_len_this_time_data > 0) {
|
||||
if (mla_use_tensorcore) {
|
||||
BatchMLAWithPagedKVCacheKernel<data_t>(meta_data,
|
||||
query,
|
||||
key_cache,
|
||||
attn_mask,
|
||||
cache_k_dequant_scales,
|
||||
cache_v_dequant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
cache_quant_type_str,
|
||||
decoder_num_blocks_data,
|
||||
max_input_length,
|
||||
max_len_kv_data,
|
||||
softmax_scale,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
out_linear_in_scale,
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
main_stream,
|
||||
&fmha_out);
|
||||
} else {
|
||||
DecodeMLAAttentionKernel<data_t>(
|
||||
meta_data,
|
||||
query, // [token_num, num_heads, head_dim]
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_mask,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
seq_lens_this_time, // q_seq_len is 1
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_input_length,
|
||||
max_len_kv_data,
|
||||
softmax_scale,
|
||||
out_linear_in_scale,
|
||||
causal,
|
||||
main_stream,
|
||||
&fmha_out);
|
||||
}
|
||||
}
|
||||
return {fmha_out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& query_bias,
|
||||
const paddle::optional<paddle::Tensor>& query_out_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int nope_size,
|
||||
const int max_input_length,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
AppendAttnMetaData meta_data;
|
||||
|
||||
const auto& query_dims = query.dims();
|
||||
const auto& key_cache_dims = key_cache.dims();
|
||||
const int q_hidden_size = query_dims[query_dims.size() - 1];
|
||||
meta_data.token_nums = query_dims[0];
|
||||
meta_data.kv_num_heads = key_cache_dims[1];
|
||||
meta_data.head_dims = key_cache_dims[3];
|
||||
meta_data.head_dims_v = nope_size;
|
||||
meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = key_cache.dims()[2];
|
||||
meta_data.batch_size = cum_offsets.dims()[0];
|
||||
|
||||
switch (query.dtype()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return MultiHeadLatentAttentionKernel<paddle::DataType::BFLOAT16>(
|
||||
meta_data,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
decoder_num_blocks_cpu,
|
||||
max_enc_len_this_time,
|
||||
max_dec_len_this_time,
|
||||
max_len_kv,
|
||||
attn_mask,
|
||||
query_bias,
|
||||
query_out_scales,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_dequant_scales,
|
||||
cache_v_dequant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
cache_quant_type_str,
|
||||
max_input_length,
|
||||
softmax_scale,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
out_linear_in_scale,
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
speculate_decoder);
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return MultiHeadLatentAttentionKernel<paddle::DataType::FLOAT16>(
|
||||
meta_data,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
decoder_num_blocks_cpu,
|
||||
max_enc_len_this_time,
|
||||
max_dec_len_this_time,
|
||||
max_len_kv,
|
||||
attn_mask,
|
||||
query_bias,
|
||||
query_out_scales,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_dequant_scales,
|
||||
cache_v_dequant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
cache_quant_type_str,
|
||||
max_input_length,
|
||||
softmax_scale,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
out_linear_in_scale,
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
speculate_decoder);
|
||||
}
|
||||
default: {
|
||||
PD_THROW(
|
||||
"NOT supported data type. "
|
||||
"Only float16 and bfloat16 are supported. ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
|
||||
const std::vector<int64_t>& query_shape,
|
||||
const std::vector<int64_t>& key_cache_shape,
|
||||
const std::vector<int64_t>& value_cache_shape,
|
||||
const std::vector<int64_t>& seq_lens_encoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_decoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_this_time_shape,
|
||||
const std::vector<int64_t>& cu_seqlens_q_shape,
|
||||
const std::vector<int64_t>& padding_offsets_shape,
|
||||
const std::vector<int64_t>& cum_offsets_shape,
|
||||
const std::vector<int64_t>& block_tables_shape,
|
||||
const std::vector<int64_t>& encoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& encoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& kv_batch_ids_shape,
|
||||
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& kv_num_blocks_shape,
|
||||
const std::vector<int64_t>& decoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& decoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& decoder_num_blocks_cpu_shape,
|
||||
const std::vector<int64_t>& max_enc_len_this_time_shape,
|
||||
const std::vector<int64_t>& max_dec_len_this_time_shape,
|
||||
const std::vector<int64_t>& max_len_kv_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& query_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& query_out_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int nope_size,
|
||||
const int max_input_length,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
const int token_num = query_shape[0];
|
||||
const int kv_num_heads = key_cache_shape[1];
|
||||
const int head_dim_qk = key_cache_shape[3];
|
||||
const int head_dim_v = nope_size;
|
||||
const int q_hidden_size = query_shape[query_shape.size() - 1];
|
||||
const int num_heads = q_hidden_size / head_dim_qk;
|
||||
return {{token_num, num_heads * head_dim_v}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
|
||||
const paddle::DataType& query_dtype,
|
||||
const paddle::DataType& key_cache_dtype,
|
||||
const paddle::DataType& value_cache_dtype,
|
||||
const paddle::DataType& seq_lens_encoder_dtype,
|
||||
const paddle::DataType& seq_lens_decoder_dtype,
|
||||
const paddle::DataType& seq_lens_this_time_dtype,
|
||||
const paddle::DataType& cu_seqlens_q_dtype,
|
||||
const paddle::DataType& padding_offsets_dtype,
|
||||
const paddle::DataType& cum_offsets_dtype,
|
||||
const paddle::DataType& block_tables_dtype,
|
||||
const paddle::DataType& encoder_batch_ids_dtype,
|
||||
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& encoder_num_blocks_dtype,
|
||||
const paddle::DataType& kv_batch_ids_dtype,
|
||||
const paddle::DataType& kv_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& kv_num_blocks_dtype,
|
||||
const paddle::DataType& decoder_batch_ids_dtype,
|
||||
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& decoder_num_blocks_dtype,
|
||||
const paddle::DataType& decoder_num_blocks_cpu_dtype,
|
||||
const paddle::DataType& max_enc_len_this_time_dtype,
|
||||
const paddle::DataType& max_dec_len_this_time_dtype,
|
||||
const paddle::DataType& max_len_kv_dtype,
|
||||
const paddle::optional<paddle::DataType>& attn_mask_dtype,
|
||||
const paddle::optional<paddle::DataType>& query_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& query_out_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int nope_size,
|
||||
const int max_input_length,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
if (compute_dtype == "bf16") {
|
||||
return {paddle::DataType::BFLOAT16};
|
||||
} else if (compute_dtype == "fp16") {
|
||||
return {paddle::DataType::FLOAT16};
|
||||
} else {
|
||||
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_OP(multi_head_latent_attention)
|
||||
.Inputs({"query",
|
||||
"key_cache",
|
||||
"value_cache",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"cu_seqlens_q",
|
||||
"padding_offsets",
|
||||
"cum_offsets",
|
||||
"block_tables",
|
||||
"encoder_batch_ids",
|
||||
"encoder_tile_ids_per_batch",
|
||||
"encoder_num_blocks",
|
||||
"kv_batch_ids",
|
||||
"kv_tile_ids_per_batch",
|
||||
"kv_num_blocks",
|
||||
"decoder_batch_ids",
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks",
|
||||
"decoder_num_blocks_cpu",
|
||||
"max_enc_len_this_time",
|
||||
"max_dec_len_this_time",
|
||||
"max_len_kv",
|
||||
paddle::Optional("attn_mask"),
|
||||
paddle::Optional("query_bias"),
|
||||
paddle::Optional("query_out_scales"),
|
||||
paddle::Optional("cache_k_quant_scales"),
|
||||
paddle::Optional("cache_v_quant_scales"),
|
||||
paddle::Optional("cache_k_dequant_scales"),
|
||||
paddle::Optional("cache_v_dequant_scales"),
|
||||
paddle::Optional("cache_k_zp"),
|
||||
paddle::Optional("cache_v_zp"),
|
||||
paddle::Optional("out_linear_shifts"),
|
||||
paddle::Optional("out_linear_smooths")})
|
||||
.Outputs({"fmha_out"})
|
||||
.Attrs({"compute_type: std::string",
|
||||
"cache_quant_type: std::string",
|
||||
"nope_size: int",
|
||||
"max_input_length: int",
|
||||
"softmax_scale: float",
|
||||
"quant_max_bound: float",
|
||||
"quant_min_bound: float",
|
||||
"out_linear_in_scale: float",
|
||||
"speculate_max_draft_token_num: int",
|
||||
"causal: bool",
|
||||
"speculate_decoder: bool"})
|
||||
.SetKernelFn(PD_KERNEL(MultiHeadLatentAttention))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MultiHeadLatentAttentionInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MultiHeadLatentAttentionInferDtype));
73
custom_ops/gpu_ops/noaux_tc.cu
Normal file
@@ -0,0 +1,73 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <optional>

#include "helper.h"
#include "noauxtc_kernel.h"

std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
                                    paddle::Tensor& scores_with_bias,
                                    int n_group,
                                    int topk_group,
                                    int topk,
                                    float routed_scaling_factor) {
  auto input_shape = scores_with_bias.shape();
  int64_t num_tokens = input_shape[0];
  int64_t num_experts = input_shape[1];
  auto input_type = scores_with_bias.dtype();
  auto place = scores_with_bias.place();
  auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
  auto stream = scores_with_bias.stream();

  invokeNoAuxTc<float>(reinterpret_cast<float*>(scores.data<float>()),
                       reinterpret_cast<float*>(group_scores.data<float>()),
                       reinterpret_cast<float*>(scores_with_bias.data<float>()),
                       num_tokens,
                       num_experts,
                       n_group,
                       topk_group,
                       topk,
                       routed_scaling_factor,
                       stream);

  return {scores};
}

std::vector<paddle::DataType> NoauxTcInferDtype(
    const paddle::DataType& scores_dtype,
    const paddle::DataType& scores_with_bias_dtype) {
  return {scores_dtype};
}

std::vector<std::vector<int64_t>> NoauxTcInferShape(
    const std::vector<int64_t>& scores_shape,
    const std::vector<int64_t>& gating_output_shape) {
  return {scores_shape};
}

PD_BUILD_OP(noaux_tc)
    .Inputs({"scores", "scores_with_bias"})
    .Outputs({"output_tensor"})
    .Attrs({"n_group: int",
            "topk_group: int",
            "topk:int",
            "routed_scaling_factor: float"})
    .SetKernelFn(PD_KERNEL(NoauxTc))
    .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcInferDtype));
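// Editorial note: as registered above, `noaux_tc` takes `scores` and `scores_with_bias`
// (shaped [num_tokens, num_experts]) plus the attrs n_group, topk_group, topk and
// routed_scaling_factor, allocates a [num_tokens, n_group] group_scores workspace, and returns
// the `scores` tensor, which invokeNoAuxTc appears to update in place. The group-limited top-k
// selection itself lives in noauxtc_kernel.h below.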
551
custom_ops/gpu_ops/noauxtc_kernel.h
Normal file
@@ -0,0 +1,551 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This code is partially inspired by and references the implementation found
// in NVIDIA TRTLLM.
#pragma once
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t WARP_SIZE = 32;
constexpr int32_t BLOCK_SIZE = 512;
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

namespace warp_topk {

template <int size, typename T>
__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
  if (len == 0) {
    return 0;
  }
  return ((len - 1) / size + 1) * size;
}

template <typename T>
constexpr __host__ __device__ bool isPowerOf2(T v) {
  return (v && !(v & (v - 1)));
}

template <bool greater, typename T>
__device__ bool is_better_than(T val, T baseline) {
  return (val > baseline && greater) || (val < baseline && !greater);
}

template <typename T, typename idxT>
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
  int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
  int64_t n = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
  return max(cache_topk,
             round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
}
|
||||
|
||||
template <int size, bool ascending, typename T, typename idxT>
|
||||
struct BitonicMerge {
|
||||
// input should be a bitonic sequence, and sort it to be a monotonic sequence
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
static_assert(isPowerOf2(size));
|
||||
static_assert(size >= 2 * WARP_SIZE);
|
||||
constexpr int arr_len = size / WARP_SIZE;
|
||||
|
||||
constexpr int stride = arr_len / 2;
|
||||
for (int i = 0; i < stride; ++i) {
|
||||
int const other_i = i + stride;
|
||||
T& val = val_arr[i];
|
||||
T& other_val = val_arr[other_i];
|
||||
if ((val > other_val && ascending) || (val < other_val && !ascending)) {
|
||||
T tmp = val;
|
||||
val = other_val;
|
||||
other_val = tmp;
|
||||
|
||||
idxT tmp2 = idx_arr[i];
|
||||
idx_arr[i] = idx_arr[other_i];
|
||||
idx_arr[other_i] = tmp2;
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr + arr_len / 2,
|
||||
idx_arr + arr_len / 2);
|
||||
}
|
||||
};
|
||||
|
||||
template <int size, bool ascending, typename T, typename idxT>
|
||||
struct BitonicSort {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
static_assert(isPowerOf2(size));
|
||||
static_assert(size >= 2 * WARP_SIZE);
|
||||
constexpr int arr_len = size / WARP_SIZE;
|
||||
|
||||
BitonicSort<size / 2, true, T, idxT>::sort(val_arr, idx_arr);
|
||||
BitonicSort<size / 2, false, T, idxT>::sort(val_arr + arr_len / 2,
|
||||
idx_arr + arr_len / 2);
|
||||
BitonicMerge<size, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, typename T, typename idxT>
|
||||
struct BitonicSort<32, ascending, T, idxT> {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
|
||||
// ascending doesn't matter before merging since all we need is a bitonic
|
||||
// sequence
|
||||
for (int stage = 0; stage < 4; ++stage) {
|
||||
for (int stride = (1 << stage); stride > 0; stride /= 2) {
|
||||
bool reverse = (lane >> stage) & 2;
|
||||
bool is_second = lane & stride;
|
||||
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
|
||||
if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) {
|
||||
*val_arr = other;
|
||||
*idx_arr = other_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, typename T, typename idxT>
|
||||
struct BitonicMerge<32, ascending, T, idxT> {
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
|
||||
bool is_second = lane & stride;
|
||||
T& val = *val_arr;
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
|
||||
idxT& idx = *idx_arr;
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
|
||||
if (val != other && ((val > other) == (ascending != is_second))) {
|
||||
val = other;
|
||||
idx = other_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int capacity, bool greater, typename T, typename idxT>
class WarpSort {
 public:
  __device__ WarpSort(idxT k, T dummy)
      : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
    static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));

    for (int i = 0; i < max_arr_len_; ++i) {
      val_arr_[i] = dummy_;
    }
  }

  // load and merge k sorted values
  __device__ void load_sorted(T const* __restrict__ in,
                              idxT const* __restrict__ in_idx,
                              idxT start) {
    idxT idx = start + WARP_SIZE - 1 - lane_;
    for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
      if (idx < start + k_) {
        T t = in[idx];
        if (is_better_than<greater>(t, val_arr_[i])) {
          val_arr_[i] = t;
          idx_arr_[i] = in_idx[idx];
        }
      }
    }

    BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
  }

  __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
    for (int i = 0; i < max_arr_len_; ++i) {
      idxT out_i = i * WARP_SIZE + lane_;
      if (out_i < k_) {
        out[out_i] = val_arr_[i];
        out_idx[out_i] = idx_arr_[i];
      }
    }
  }

  __device__ void dumpIdx(idxT* __restrict__ out_idx) const {
    for (int i = 0; i < max_arr_len_; ++i) {
      idxT out_i = i * WARP_SIZE + lane_;
      if (out_i < k_) {
        out_idx[out_i] = idx_arr_[i];
      }
    }
  }

 protected:
  static constexpr int max_arr_len_ = capacity / WARP_SIZE;

  T val_arr_[max_arr_len_];
  idxT idx_arr_[max_arr_len_];

  int const lane_;
  idxT const k_;
  T const dummy_;
};  // end class WarpSort

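// WarpSelect extends WarpSort with streaming insertion: each warp keeps its
// current top-k candidates in the inherited val_arr_/idx_arr_ registers and
// stages new elements that beat the running k-th value (k_th_) in a per-warp
// shared-memory buffer; once the buffer holds a full warp's worth of entries,
// it is bitonic-sorted and merged back into the registers.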
template <int capacity, bool greater, typename T, typename idxT>
class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
 public:
  __device__ WarpSelect(idxT k, T dummy)
      : WarpSort<capacity, greater, T, idxT>(k, dummy),
        k_th_(dummy),
        k_th_lane_((k - 1) % WARP_SIZE) {
    extern __shared__ char smem_buf[];  // extern __shared__ T smem_buf[];

    int const num_of_warp = blockDim.x / WARP_SIZE;
    int const warp_id = threadIdx.x / WARP_SIZE;
    val_smem_ = reinterpret_cast<T*>(smem_buf);
    val_smem_ += warp_id * WARP_SIZE;
    idx_smem_ = reinterpret_cast<idxT*>(
        smem_buf +
        round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
    idx_smem_ += warp_id * WARP_SIZE;
  }

  __device__ void add(T const* in, idxT start, idxT end) {
    idxT const end_for_fullwarp =
        round_up_to_multiple_of<WARP_SIZE>(end - start) + start;
    for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) {
      T val = (i < end) ? in[i] : dummy_;
      add(val, i);
    }
  }

  __device__ void add(T val, idxT idx) {
    bool do_add = is_better_than<greater>(val, k_th_);
    uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
    if (mask == 0) {
      return;
    }

    int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
    if (do_add && pos < WARP_SIZE) {
      val_smem_[pos] = val;
      idx_smem_[pos] = idx;
      do_add = false;
    }
    smem_buf_len_ += __popc(mask);
    if (smem_buf_len_ >= WARP_SIZE) {
      __syncwarp();
      merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
      smem_buf_len_ -= WARP_SIZE;
    }
    if (do_add) {
      pos -= WARP_SIZE;
      val_smem_[pos] = val;
      idx_smem_[pos] = idx;
    }
    __syncwarp();
  }

  __device__ void done() {
    if (smem_buf_len_) {
      T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
      idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
      merge_buf_(val, idx);
    }

    // after done(), smem is used for merging results among warps
    __syncthreads();
  }

 private:
  __device__ void set_k_th_() {
    k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
  }

  __device__ void merge_buf_(T val, idxT idx) {
    BitonicSort<WARP_SIZE, greater, T, idxT>::sort(&val, &idx);

    T& old = val_arr_[max_arr_len_ - 1];
    if (is_better_than<greater>(val, old)) {
      old = val;
      idx_arr_[max_arr_len_ - 1] = idx;
    }

    BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);

    set_k_th_();
  }

  using WarpSort<capacity, greater, T, idxT>::max_arr_len_;
  using WarpSort<capacity, greater, T, idxT>::val_arr_;
  using WarpSort<capacity, greater, T, idxT>::idx_arr_;
  using WarpSort<capacity, greater, T, idxT>::lane_;
  using WarpSort<capacity, greater, T, idxT>::k_;
  using WarpSort<capacity, greater, T, idxT>::dummy_;

  T* val_smem_;
  idxT* idx_smem_;
  int smem_buf_len_ = 0;

  T k_th_;
  int const k_th_lane_;
};  // end class WarpSelect
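
// Illustrative per-warp usage (a sketch; `candidates`, `start`, `end` and
// `topk_idx` are placeholder names — see group_idx_and_topk_idx_kernel below
// for the real call sites):
//   warp_topk::WarpSelect<WARP_SIZE, /*greater*/ true, float, int32_t>
//       queue(k, cuda::std::numeric_limits<float>::min());
//   queue.add(candidates, start, end);  // stream candidates through the warp
//   queue.done();                       // flush the shared-memory buffer
//   queue.dumpIdx(topk_idx);            // write the selected indices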
}  // namespace warp_topk

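// topk_with_k2: one warp reduces a group's expert scores to the sum of its
// two largest values; invokeNoAuxTc below uses this sum as the group's score
// for the group-level top-k selection.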
template <typename T>
__device__ void topk_with_k2(T* output,
                             T const* input,
                             cg::thread_block_tile<32> const& tile,
                             int32_t const lane_id,
                             int const num_experts_per_group) {
  // Get the top2 per thread
  T largest = cuda::std::numeric_limits<T>::min();
  T second_largest = cuda::std::numeric_limits<T>::min();

  if (num_experts_per_group > WARP_SIZE) {
    for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
      T value = input[i];
      if (value > largest) {
        second_largest = largest;
        largest = value;
      } else if (value > second_largest) {
        second_largest = value;
      }
    }
  } else {
    for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
      largest = input[i];
    }
  }

  __syncwarp();  // Ensure all threads have valid data before reduction
  // Get the top2 warpwise
  T max1 = cg::reduce(tile, largest, cg::greater<T>());

  T max2 = max1;
  bool equal_to_max1 = (max1 == largest);
  int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1));

  if (count_max1 == 1) {
    largest = (largest == max1) ? second_largest : largest;
    max2 = cg::reduce(tile, largest, cg::greater<T>());
  }

  if (lane_id == 0) {
    *output = max1 + max2;
  }
}

template <typename T>
__global__ void topk_with_k2_kernel(T* output,
                                    T* input,
                                    int64_t const num_tokens,
                                    int64_t const num_cases,
                                    int64_t const n_group,
                                    int64_t const num_experts_per_group) {
  int32_t warp_id = threadIdx.x / WARP_SIZE;
  int32_t lane_id = threadIdx.x % WARP_SIZE;

  int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
  if (case_id < num_cases) {
    input += case_id * num_experts_per_group;
    output += case_id;

    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

    topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
  }
}

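// group_idx_and_topk_idx_kernel: one warp per token. It first picks the
// topk_group best groups by their group score, then runs a warp-level top-k
// over the biased scores of the experts in those groups, and finally rewrites
// `scores` so that only the selected experts keep a value (normalized by the
// top-k sum and scaled by routed_scaling_factor).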
template <typename T>
__global__ void group_idx_and_topk_idx_kernel(
    T* scores,
    T const* group_scores,
    T* scores_with_bias,
    int64_t const num_tokens,
    int64_t const n_group,
    int64_t const topk_group,
    int64_t const topk,
    int64_t const num_experts,
    int64_t const num_experts_per_group,
    double routed_scaling_factor) {
  int32_t warp_id = threadIdx.x / WARP_SIZE;
  int32_t lane_id = threadIdx.x % WARP_SIZE;
  int32_t case_id =
      blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;  // one per token
  scores_with_bias += case_id * num_experts;
  scores += case_id * num_experts;
  group_scores += case_id * n_group;
  int32_t align_num_experts_per_group =
      warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);

  cg::thread_block block = cg::this_thread_block();
  cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

  extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to
                                      // store the target topk idx
  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
  T* s_topk_value =
      reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
      warp_id * topk;

  T value = cuda::std::numeric_limits<T>::min();
  T topk_group_value = cuda::std::numeric_limits<T>::min();
  int32_t num_equalto_topkth_group;

  if ((n_group > topk_group) && (case_id < num_tokens)) {
    // calculate group_idx
    int32_t target_num_min = WARP_SIZE - n_group + topk_group;
    if (lane_id < n_group) {
      value = group_scores[lane_id];
    }

    int count_equal_to_top_value = WARP_SIZE - n_group;
    int pre_count_equal_to_top_value = 0;
    // Use a loop to find the largest top_group
    while (count_equal_to_top_value < target_num_min) {
      __syncwarp();  // Ensure all threads have valid data before reduction
      topk_group_value = cg::reduce(tile, value, cg::greater<T>());
      if (value == topk_group_value) {
        value = cuda::std::numeric_limits<T>::min();
      }
      pre_count_equal_to_top_value = count_equal_to_top_value;
      count_equal_to_top_value = __popc(__ballot_sync(
          FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
    }
    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
  }
  __syncthreads();

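  // At this point topk_group_value holds the score of the topk_group-th best
  // group and num_equalto_topkth_group counts how many groups tying at that
  // score may still be admitted, which is used below to break ties.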
  warp_topk::WarpSelect</*capacity*/ WARP_SIZE, /*greater*/ true, T, int32_t>
      queue((int32_t)topk, cuda::std::numeric_limits<T>::min());

  int count_equalto_topkth_group = 0;
  if (case_id < num_tokens) {
    for (int i_group = 0; i_group < n_group; i_group++) {
      if ((group_scores[i_group] > topk_group_value) ||
          ((group_scores[i_group] == topk_group_value) &&
           (count_equalto_topkth_group < num_equalto_topkth_group))) {
        int32_t offset = i_group * num_experts_per_group;
        for (int32_t i = lane_id; i < align_num_experts_per_group;
             i += WARP_SIZE) {
          T candidates = i < num_experts_per_group
                             ? scores_with_bias[offset + i]
                             : cuda::std::numeric_limits<T>::min();
          queue.add(candidates, offset + i);
        }
        if (group_scores[i_group] == topk_group_value) {
          count_equalto_topkth_group++;
        }
      }
    }
    queue.done();
    __syncwarp();
    // Get the topk_idx
    queue.dumpIdx(s_topk_idx);
    __syncwarp();
  }

  // Load the valid score value
  // Calculate the summation
  float topk_sum = 1e-20;
  if (case_id < num_tokens) {
    for (int i = lane_id;
         i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
         i += WARP_SIZE) {
      T value = i < topk ? scores[s_topk_idx[i]]
                         : 0.0f;  // Load the valid value of expert
      if (i < topk) {
        s_topk_value[i] = value;
      }
      topk_sum += reduce(tile, value, cg::plus<float>());
    }
  }

  __syncthreads();
  if (case_id < num_tokens) {
    for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
      scores[i] = 0;
    }
  }
  __threadfence();
  __syncthreads();

  if (case_id < num_tokens) {
    for (int i = lane_id; i < topk; i += WARP_SIZE) {
      float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
      scores[s_topk_idx[i]] = value;
    }
  }
}

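// invokeNoAuxTc: two-step grouped top-k routing. The first kernel writes one
// score per (token, group) into group_scores (the sum of that group's top-2
// biased scores); the second selects the topk_group best groups per token,
// picks the final topk experts from them, and rewrites `scores` with the
// normalized, scaled routing weights.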
template <typename T>
void invokeNoAuxTc(T* scores,
                   T* group_scores,
                   T* scores_with_bias,
                   int64_t const num_tokens,
                   int64_t const num_experts,
                   int64_t const n_group,
                   int64_t const topk_group,
                   int64_t const topk,
                   double const routed_scaling_factor,
                   cudaStream_t const stream) {
  int64_t num_cases = num_tokens * n_group;
  int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
  topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
      group_scores,
      scores_with_bias,
      num_tokens,
      num_cases,
      n_group,
      num_experts / n_group);

  int64_t topk_with_k_group_num_blocks =
      (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
  size_t dynamic_smem_in_bytes =
      warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
                                                           topk);

  group_idx_and_topk_idx_kernel<T><<<topk_with_k_group_num_blocks,
                                     BLOCK_SIZE,
                                     dynamic_smem_in_bytes,
                                     stream>>>(scores,
                                               group_scores,
                                               scores_with_bias,
                                               num_tokens,
                                               n_group,
                                               topk_group,
                                               topk,
                                               num_experts,
                                               num_experts / n_group,
                                               routed_scaling_factor);
}

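// Host-side call sketch (argument names are placeholders; the buffers are
// device pointers and `stream` is the CUDA stream to launch on):
//   invokeNoAuxTc<float>(scores, group_scores, scores_with_bias,
//                        num_tokens, num_experts, n_group,
//                        topk_group, topk, routed_scaling_factor, stream);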
#define INSTANTIATE_NOAUX_TC(T)                                      \
  template void invokeNoAuxTc<T>(T * scores,                         \
                                 T * group_scores,                   \
                                 T * scores_with_bias,               \
                                 int64_t const num_tokens,           \
                                 int64_t const num_experts,          \
                                 int64_t const n_group,              \
                                 int64_t const topk_group,           \
                                 int64_t const topk,                 \
                                 double const routed_scaling_factor, \
                                 cudaStream_t const stream);

INSTANTIATE_NOAUX_TC(float);

@@ -50,11 +50,13 @@ __global__ void quant_per_token_per_block(const T *input,
      max_value_thread = max(abs(load_vec_float[vid]), max_value_thread);
    }
    // get max value per warp
    max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 16), max_value_thread);
    max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 8), max_value_thread);
    max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 4), max_value_thread);
    max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 2), max_value_thread);
    max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 1), max_value_thread);
    max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16), max_value_thread);
    max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8), max_value_thread);
    max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4), max_value_thread);
    max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2), max_value_thread);
    max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1), max_value_thread);
    // broadcast max_value
    max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
    max_value_thread = max(max_value_thread, epsilon);
    float scale_to_store = max_value_thread / MAX_VALUE;
    // quant