mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
format code (#4720)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
This commit is contained in:
@@ -142,7 +142,6 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps(
|
||||
const uint32_t tx_offset = tx / 8;
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||
|
||||
const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset;
|
||||
#pragma unroll
|
||||
const int j = ty;
|
||||
@@ -151,8 +150,7 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps(
|
||||
const uint32_t h_offset = offset_now % group_size;
|
||||
T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
|
||||
#pragma unroll
|
||||
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
|
||||
++fyo) {
|
||||
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
|
||||
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
q_smem_offset_w, q_ptr, n_offset < qo_upper_bound);
|
||||
q_smem_offset_w =
|
||||
@@ -171,7 +169,7 @@ template <uint32_t group_size,
|
||||
uint32_t HEAD_DIM,
|
||||
typename T>
|
||||
__device__ __forceinline__ void load_q_global_smem(
|
||||
T* q_ptr_base,
|
||||
const T* q_ptr_base,
|
||||
smem_t* q_smem,
|
||||
uint32_t q_idx_base,
|
||||
const uint32_t qo_upper_bound,
|
||||
@@ -194,10 +192,10 @@ __device__ __forceinline__ void load_q_global_smem(
|
||||
const uint32_t offset_now = base_offset + j * 4;
|
||||
const uint32_t n_offset = offset_now / group_size;
|
||||
const uint32_t h_offset = offset_now % group_size;
|
||||
T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
|
||||
const T* q_ptr =
|
||||
q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
|
||||
#pragma unroll
|
||||
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
|
||||
++fyo) {
|
||||
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
|
||||
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
q_smem_offset_w, q_ptr, n_offset < qo_upper_bound);
|
||||
q_smem_offset_w =
|
||||
@@ -223,8 +221,7 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps(
|
||||
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024;
|
||||
++i) {
|
||||
for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; ++i) {
|
||||
const int offset = i * 1024 + ty * 256 + tx * 8;
|
||||
Load<T, vec_size>(reinterpret_cast<T*>(q_smem->base) + offset, &tmp_vec);
|
||||
#pragma unroll
|
||||
@@ -289,11 +286,9 @@ __device__ __forceinline__ void produce_kv_blockwise(
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps;
|
||||
++i) {
|
||||
for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; ++i) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < num_frags_y / 4;
|
||||
++j) {
|
||||
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
|
||||
smem.load_128b_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
|
||||
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
|
||||
*gptr += 8 * num_elems_per_128b<T>();
|
||||
@@ -332,9 +327,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
|
||||
block_size / num_elems_per_128b<CacheT>(); // 8
|
||||
constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q;
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
uint32_t kv_idx =
|
||||
kv_idx_base +
|
||||
tx % 4 * num_elems_per_128b<CacheT>();
|
||||
uint32_t kv_idx = kv_idx_base + tx % 4 * num_elems_per_128b<CacheT>();
|
||||
if constexpr (NUM_WARP_Q == 4) {
|
||||
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
@@ -343,8 +336,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
|
||||
for (uint32_t i = 0; i < num_frags_y * 2 / num_warps;
|
||||
++i) { // m (num_frags_y * 16 / (num_warps * 8))
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < num_frags_z / 4;
|
||||
++j) {
|
||||
for (uint32_t j = 0; j < num_frags_z / 4; ++j) {
|
||||
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
|
||||
*smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>(
|
||||
*smem_offset, j);
|
||||
@@ -369,8 +361,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
|
||||
for (uint32_t i = 0; i < num_frags_y * 2 / num_warps;
|
||||
++i) { // m (num_frags_y * 16 / (num_warps * 8))
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2 * num_frags_z / 4;
|
||||
++j) {
|
||||
for (uint32_t j = 0; j < 2 * num_frags_z / 4; ++j) {
|
||||
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
|
||||
*smem_offset =
|
||||
smem.advance_offset_by_column<4, num_vecs_per_blocksize>(
|
||||
@@ -392,27 +383,28 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
|
||||
}
|
||||
}
|
||||
|
||||
template<uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
template <uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
__device__ __forceinline__ void produce_k_dynamic_scale(
|
||||
T* k_smem_scale,
|
||||
T* cache_k_reg,
|
||||
const int* block_table_now,
|
||||
const T* cache_k_scale,
|
||||
const uint32_t kv_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t chunk_end
|
||||
) {
|
||||
T* k_smem_scale,
|
||||
T* cache_k_reg,
|
||||
const int* block_table_now,
|
||||
const T* cache_k_scale,
|
||||
const uint32_t kv_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t chunk_end) {
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
if constexpr (NUM_WARP_Q == 4) {
|
||||
// 4 warps shared block_size
|
||||
const uint32_t tid = ty * 32 + tx;
|
||||
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
const T* cache_k_scale_now = cache_k_scale +
|
||||
block_id * kv_num_heads * block_size +
|
||||
kv_head_idx * block_size;
|
||||
if (tid < block_size) {
|
||||
k_smem_scale[tid] = cache_k_scale_now[tid];
|
||||
}
|
||||
@@ -427,10 +419,12 @@ __device__ __forceinline__ void produce_k_dynamic_scale(
|
||||
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
|
||||
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
const T* cache_k_scale_now = cache_k_scale +
|
||||
block_id * kv_num_heads * block_size +
|
||||
kv_head_idx * block_size;
|
||||
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
|
||||
if (kv_idx_this_thread < chunk_end) {
|
||||
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
|
||||
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
|
||||
} else {
|
||||
k_smem_scale[ty * 32 + tx] = 0;
|
||||
}
|
||||
@@ -443,20 +437,19 @@ __device__ __forceinline__ void produce_k_dynamic_scale(
|
||||
}
|
||||
}
|
||||
|
||||
template<uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
template <uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
__device__ __forceinline__ void produce_v_dynamic_scale(
|
||||
T* v_smem_scale,
|
||||
T* cache_v_reg,
|
||||
const int* block_table_now,
|
||||
const T* cache_v_scale,
|
||||
const uint32_t kv_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t chunk_end
|
||||
) {
|
||||
T* v_smem_scale,
|
||||
T* cache_v_reg,
|
||||
const int* block_table_now,
|
||||
const T* cache_v_scale,
|
||||
const uint32_t kv_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t chunk_end) {
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
|
||||
if constexpr (NUM_WARP_Q == 4) {
|
||||
@@ -464,7 +457,9 @@ __device__ __forceinline__ void produce_v_dynamic_scale(
|
||||
const uint32_t tid = ty * 32 + tx;
|
||||
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
const T* cache_v_scale_now = cache_v_scale +
|
||||
block_id * kv_num_heads * block_size +
|
||||
kv_head_idx * block_size;
|
||||
if (tid < block_size) {
|
||||
v_smem_scale[tid] = cache_v_scale_now[tid];
|
||||
}
|
||||
@@ -481,10 +476,12 @@ __device__ __forceinline__ void produce_v_dynamic_scale(
|
||||
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
|
||||
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
const T* cache_v_scale_now = cache_v_scale +
|
||||
block_id * kv_num_heads * block_size +
|
||||
kv_head_idx * block_size;
|
||||
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
|
||||
if (kv_idx_this_thread < chunk_end) {
|
||||
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
|
||||
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
|
||||
} else {
|
||||
v_smem_scale[ty * 32 + tx] = 0;
|
||||
}
|
||||
@@ -560,8 +557,7 @@ __device__ __forceinline__ void produce_k_blockwise_c8(
|
||||
for (uint32_t i = 0; i < 2 * num_frags_z * 4 / num_warps;
|
||||
++i) { // m num_frags_z * 16 / (num_warps * 4)
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < num_frags_y / 8;
|
||||
++j) {
|
||||
for (uint32_t j = 0; j < num_frags_y / 8; ++j) {
|
||||
smem.load_128b_async<fill_mode>(*smem_offset, cache_k_now, true);
|
||||
*smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>(
|
||||
*smem_offset, j);
|
||||
@@ -614,8 +610,7 @@ __device__ __forceinline__ void produce_v_blockwise_c4(
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < num_frags_y / num_warps; ++i) { // m
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < num_frags_z / 4;
|
||||
++j) {
|
||||
for (uint32_t j = 0; j < num_frags_z / 4; ++j) {
|
||||
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
|
||||
*smem_offset = smem.advance_offset_by_column<2, num_vecs_per_blocksize>(
|
||||
*smem_offset, j);
|
||||
@@ -671,8 +666,7 @@ __device__ __forceinline__ void produce_k_blockwise_c4(
|
||||
for (uint32_t i = 0; i < num_frags_z * 2 / num_warps;
|
||||
++i) { // m num_frags_z * 16 / (num_warps * 8)
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < num_frags_y / 8;
|
||||
++j) {
|
||||
for (uint32_t j = 0; j < num_frags_y / 8; ++j) {
|
||||
smem.load_128b_async<fill_mode>(*smem_offset, cache_k_now, true);
|
||||
*smem_offset = smem.advance_offset_by_column<4, num_vecs_per_head>(
|
||||
*smem_offset, j);
|
||||
@@ -937,7 +931,7 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
|
||||
uint32_t* q_smem_offset_r,
|
||||
smem_t* k_smem,
|
||||
uint32_t* k_smem_offset_r,
|
||||
const T *cache_k_scale,
|
||||
const T* cache_k_scale,
|
||||
float (*s_frag)[num_frags_z][8]) {
|
||||
constexpr uint32_t head_dim = num_frags_y * 16;
|
||||
constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b<T>();
|
||||
@@ -973,8 +967,8 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < 2; ++fy) {
|
||||
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
|
||||
convert_c8<T, IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
|
||||
convert_c8<T, IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
|
||||
// scale zp
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
@@ -1036,7 +1030,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
const uint32_t chunk_end,
|
||||
const uint32_t attn_mask_len,
|
||||
float (*s_frag)[num_frags_z][8],
|
||||
const int *mask_offset = nullptr,
|
||||
const int* mask_offset = nullptr,
|
||||
const int sliding_window = 0) {
|
||||
const uint32_t tx = threadIdx.x;
|
||||
#pragma unroll
|
||||
@@ -1053,24 +1047,25 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
8 * (reg_id / 4) + reg_id % 2;
|
||||
bool out_of_boundary;
|
||||
if (mask_offset) {
|
||||
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
|
||||
}
|
||||
else if (sliding_window > 0)
|
||||
{
|
||||
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - (int)qo_len - sliding_window;
|
||||
out_of_boundary =
|
||||
(causal
|
||||
? (kv_idx > kv_len + q_idx - qo_len || out_of_window || (kv_idx >= chunk_end))
|
||||
: kv_idx >= chunk_end);
|
||||
}
|
||||
else
|
||||
{
|
||||
out_of_boundary =
|
||||
(causal
|
||||
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
|
||||
: kv_idx >= chunk_end);
|
||||
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
|
||||
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
|
||||
out_of_boundary = q_idx < qo_len
|
||||
? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
|
||||
kv_idx < mask_offset[q_idx * 2])
|
||||
: true;
|
||||
} else if (sliding_window > 0) {
|
||||
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx -
|
||||
(int)qo_len -
|
||||
sliding_window;
|
||||
out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len ||
|
||||
out_of_window || (kv_idx >= chunk_end))
|
||||
: kv_idx >= chunk_end);
|
||||
} else {
|
||||
out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len ||
|
||||
(kv_idx >= chunk_end))
|
||||
: kv_idx >= chunk_end);
|
||||
if (attn_mask != nullptr && kv_idx > kv_len - qo_len &&
|
||||
kv_idx < chunk_end && q_idx < attn_mask_len) {
|
||||
const int32_t mask_idx =
|
||||
q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
|
||||
bool mask = attn_mask[mask_idx];
|
||||
out_of_boundary |= mask;
|
||||
}
|
||||
@@ -1236,7 +1231,7 @@ __device__ __forceinline__ void compute_sfm_v_c8(
|
||||
float (*s_frag)[num_frags_z][8],
|
||||
float (*o_frag)[num_frags_y][8],
|
||||
float (*d)[2],
|
||||
const T *cache_v_scale) {
|
||||
const T* cache_v_scale) {
|
||||
constexpr uint32_t num_vecs_per_blocksize =
|
||||
block_size / num_elems_per_128b<CacheT>();
|
||||
T s_frag_f16[num_frags_x][num_frags_z][8];
|
||||
@@ -1268,8 +1263,8 @@ __device__ __forceinline__ void compute_sfm_v_c8(
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < 2; ++fz) {
|
||||
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
||||
convert_c8<T, IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
||||
convert_c8<T, IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
||||
// scale zp
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
@@ -1300,7 +1295,6 @@ __device__ __forceinline__ void compute_sfm_v_c8(
|
||||
o_frag[fx][fy],
|
||||
(uint32_t*)(s_frag_f16[fx][kz * 2 + fz]),
|
||||
b_frag_dq);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1328,7 +1322,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
|
||||
float (*s_frag)[num_frags_z][8],
|
||||
float (*o_frag)[num_frags_y][8],
|
||||
float (*d)[2],
|
||||
T *cache_v_scale) {
|
||||
T* cache_v_scale) {
|
||||
constexpr uint32_t num_vecs_per_blocksize =
|
||||
block_size / num_elems_per_128b<CacheT>();
|
||||
|
||||
@@ -1362,8 +1356,8 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
|
||||
for (uint32_t fz = 0; fz < 2; ++fz) {
|
||||
// dequant b_frag -> b_frag_dq
|
||||
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
||||
convert_c8<T, IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
||||
convert_c8<T, IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
||||
// scale zp
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
@@ -1372,7 +1366,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||
}
|
||||
@@ -1431,8 +1425,7 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem,
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z;
|
||||
++fz) {
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
uint32_t b_frag[4];
|
||||
@@ -1611,10 +1604,9 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
uint32_t o_frag_f16[4];
|
||||
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
|
||||
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
|
||||
num_vecs_per_head>(
|
||||
fx * 16 + tx / 4,
|
||||
fy * 2);
|
||||
uint32_t o_smem_offset_w =
|
||||
smem_t::get_permuted_offset<num_vecs_per_head>(fx * 16 + tx / 4,
|
||||
fy * 2);
|
||||
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
|
||||
((uint32_t*)(o_smem->base + o_smem_offset_w +
|
||||
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
|
||||
@@ -1627,8 +1619,8 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
ty * 4 + tx / 8, tx % 8);
|
||||
uint32_t o_smem_offset_w =
|
||||
smem_t::get_permuted_offset<num_vecs_per_head>(ty * 4 + tx / 8, tx % 8);
|
||||
|
||||
o_idx_base += (tx / 8) / group_size;
|
||||
o_ptr_base += ((tx / 8) / group_size) * qo_n_stride +
|
||||
@@ -1642,8 +1634,7 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
|
||||
T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride +
|
||||
((fx * 16 + j * 4) % group_size) * qo_h_stride;
|
||||
#pragma unroll
|
||||
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
|
||||
++fyo) {
|
||||
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
|
||||
if (o_idx < qo_upper_bound) {
|
||||
// need write
|
||||
o_smem->store_128b(o_smem_offset_w, o_ptr);
|
||||
@@ -1658,7 +1649,6 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, int VEC_SIZE, typename OutT>
|
||||
struct StoreFunc {
|
||||
__device__ __forceinline__ void operator()(
|
||||
@@ -1717,7 +1707,6 @@ struct StoreFunc<T, VEC_SIZE, __nv_fp8_e4m3> {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename T, int VEC_SIZE>
|
||||
struct StoreFunc<T, VEC_SIZE, T> {
|
||||
__device__ __forceinline__ void operator()(
|
||||
@@ -1770,10 +1759,9 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
uint32_t o_frag_f16[4];
|
||||
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
|
||||
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
|
||||
num_vecs_per_head>(
|
||||
fx * 16 + tx / 4,
|
||||
fy * 2);
|
||||
uint32_t o_smem_offset_w =
|
||||
smem_t::get_permuted_offset<num_vecs_per_head>(fx * 16 + tx / 4,
|
||||
fy * 2);
|
||||
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
|
||||
((uint32_t*)(o_smem->base + o_smem_offset_w +
|
||||
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
|
||||
@@ -1786,8 +1774,8 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
ty * 4 + tx / 8, tx % 8);
|
||||
uint32_t o_smem_offset_w =
|
||||
smem_t::get_permuted_offset<num_vecs_per_head>(ty * 4 + tx / 8, tx % 8);
|
||||
|
||||
const uint32_t tx_offset = tx / 8;
|
||||
#pragma unroll
|
||||
@@ -1804,8 +1792,7 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
|
||||
uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim +
|
||||
tx % 8 * num_elems_per_128b<T>();
|
||||
#pragma unroll
|
||||
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
|
||||
++fyo) {
|
||||
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
|
||||
if (n_offset < qo_upper_bound) {
|
||||
if constexpr (!partition_kv) {
|
||||
Load<T, VEC_SIZE>(
|
||||
@@ -1881,10 +1868,8 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
uint32_t o_frag_f16[4];
|
||||
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
|
||||
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
|
||||
num_vecs_per_head>(
|
||||
(ty * num_frags_x + fx) * 16 + tx / 4,
|
||||
fy * 2);
|
||||
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
(ty * num_frags_x + fx) * 16 + tx / 4, fy * 2);
|
||||
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
|
||||
((uint32_t*)(o_smem->base + o_smem_offset_w +
|
||||
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
|
||||
@@ -1897,8 +1882,7 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
|
||||
__syncthreads();
|
||||
|
||||
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
ty * num_frags_x * 16 + tx / 8,
|
||||
tx % 8);
|
||||
ty * num_frags_x * 16 + tx / 8, tx % 8);
|
||||
|
||||
const uint32_t tx_offset = tx / 8;
|
||||
#pragma unroll
|
||||
@@ -1914,13 +1898,12 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
|
||||
uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim +
|
||||
tx % 8 * num_elems_per_128b<T>();
|
||||
#pragma unroll
|
||||
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
|
||||
++fyo) {
|
||||
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
|
||||
if (n_offset < qo_upper_bound) {
|
||||
if (!partition_kv) {
|
||||
Load<T, VEC_SIZE>(
|
||||
reinterpret_cast<T*>(o_smem->base + o_smem_offset_w),
|
||||
&ori_out_vec);
|
||||
reinterpret_cast<T*>(o_smem->base + o_smem_offset_w),
|
||||
&ori_out_vec);
|
||||
if (in_scale > 0.0) {
|
||||
if (shift_bias) {
|
||||
Load<T, VEC_SIZE>(shift_bias + shift_smooth_offset,
|
||||
@@ -1929,16 +1912,16 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
|
||||
&smooth_weight_vec);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC_SIZE; ++i) {
|
||||
StoreFunc<T, VEC_SIZE, OutT>()(ori_out_vec,
|
||||
shift_bias_vec,
|
||||
smooth_weight_vec,
|
||||
out_vec,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
i);
|
||||
shift_bias_vec,
|
||||
smooth_weight_vec,
|
||||
out_vec,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
i);
|
||||
}
|
||||
Store<OutT, VEC_SIZE>(out_vec, o_ptr);
|
||||
} else {
|
||||
@@ -1979,10 +1962,8 @@ __device__ __forceinline__ void write_o_reg_gmem(
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
uint32_t o_frag_f16[4];
|
||||
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
|
||||
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
|
||||
num_vecs_per_head>(
|
||||
(ty * num_frags_x + fx) * 16 + tx / 4,
|
||||
fy * 2);
|
||||
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
(ty * num_frags_x + fx) * 16 + tx / 4, fy * 2);
|
||||
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
|
||||
((uint32_t*)(o_smem->base + o_smem_offset_w +
|
||||
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
|
||||
@@ -1995,8 +1976,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
|
||||
__syncthreads();
|
||||
|
||||
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
ty * num_frags_x * 16 + tx / 8,
|
||||
tx % 8);
|
||||
ty * num_frags_x * 16 + tx / 8, tx % 8);
|
||||
|
||||
o_idx_base += (tx / 8) / group_size;
|
||||
o_ptr_base += ((tx / 8) / group_size) * qo_n_stride +
|
||||
@@ -2009,8 +1989,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
|
||||
T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride +
|
||||
((fx * 16 + j * 4) % group_size) * qo_h_stride;
|
||||
#pragma unroll
|
||||
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
|
||||
++fyo) {
|
||||
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
|
||||
if (o_idx < qo_upper_bound) {
|
||||
o_smem->store_128b(o_smem_offset_w, o_ptr);
|
||||
}
|
||||
@@ -2125,7 +2104,6 @@ __global__ void merge_multi_chunks_kernel(
|
||||
&out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
|
||||
}
|
||||
|
||||
|
||||
template <uint32_t num_frags_x, uint32_t num_frags_y, typename T>
|
||||
__device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8],
|
||||
float* md_smem,
|
||||
@@ -2307,18 +2285,18 @@ template <typename T,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true>
|
||||
__global__ void merge_multi_chunks_decoder_kernel(
|
||||
const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads,
|
||||
const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads,
|
||||
// head_dim]
|
||||
const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads]
|
||||
const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads]
|
||||
const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads]
|
||||
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
|
||||
const int* __restrict__ seq_lens_q,
|
||||
const int* __restrict__ seq_lens_kv,
|
||||
const int* __restrict__ seq_lens_encoder,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T* __restrict__ sinks, // [q_num_heads]
|
||||
OutT* __restrict__ out,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
@@ -2419,8 +2397,14 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size; ++i) {
|
||||
StoreFunc<T, vec_size, OutT>()(
|
||||
st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i);
|
||||
StoreFunc<T, vec_size, OutT>()(st.o,
|
||||
shift_bias_vec,
|
||||
smooth_weight_vec,
|
||||
out_vec,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
i);
|
||||
}
|
||||
Store<OutT, vec_size>(
|
||||
out_vec,
|
||||
@@ -2435,19 +2419,19 @@ template <typename T,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true>
|
||||
__global__ void merge_multi_chunks_v2_kernel(
|
||||
const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads,
|
||||
const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads,
|
||||
// head_dim]
|
||||
const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads]
|
||||
const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads]
|
||||
const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ batch_id_per_token,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads]
|
||||
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
|
||||
const int* __restrict__ seq_lens_q,
|
||||
const int* __restrict__ seq_lens_kv,
|
||||
const int* __restrict__ seq_lens_encoder,
|
||||
const int* __restrict__ batch_id_per_token,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T* __restrict__ sinks, // [q_num_heads]
|
||||
OutT* __restrict__ out,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
@@ -2464,7 +2448,7 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
|
||||
const uint32_t bid = batch_id_per_token[qid];
|
||||
if(bid == -1){
|
||||
if (bid == -1) {
|
||||
continue;
|
||||
}
|
||||
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
|
||||
@@ -2486,7 +2470,7 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
|
||||
if (num_chunks_this_seq <= 1) {
|
||||
continue;
|
||||
}else if (!ENABLE_PREFILL){
|
||||
} else if (!ENABLE_PREFILL) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -2496,12 +2480,12 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2 *)(&res_vec) + i) = make_half2(0, 0);
|
||||
*((half2*)(&res_vec) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0);
|
||||
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
float m;
|
||||
@@ -2581,10 +2565,17 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
Load<T, vec_size>(smooth_weight + shift_smooth_offset,
|
||||
&smooth_weight_vec);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size; ++i) {
|
||||
StoreFunc<T, vec_size, OutT>()(
|
||||
st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i);
|
||||
StoreFunc<T, vec_size, OutT>()(st.o,
|
||||
shift_bias_vec,
|
||||
smooth_weight_vec,
|
||||
out_vec,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
i);
|
||||
}
|
||||
Store<OutT, vec_size>(
|
||||
out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
|
||||
|
||||
Reference in New Issue
Block a user