FastDeploy/custom_ops/gpu_ops/append_attn/utils.cuh

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include "mem_util.cuh"
struct AppendAttnMetaData {
  int batch_size;
  int block_size;
  int q_num_heads;
  int kv_num_heads;
  int token_nums;
  int head_dims;
  int head_dims_v;
  int max_blocks_per_seq;
  const int *mask_offset = nullptr;
};

__forceinline__ __host__ __device__ int div_up(int a, int b) {
  return (a + b - 1) / b;
}
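
// Positional-encoding modes and KV-cache storage formats: CacheT keeps the
// cache in the compute dtype, while CacheInt8Hw and CacheInt4CwZp are the
// 8-bit and 4-bit (with zero point) quantized cache layouts.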
enum PosEncMode { kNonePos, kRoPE, kAliBi };
enum CacheType { CacheT, CacheInt8Hw, CacheInt4CwZp };
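
// Type-trait helpers: map Paddle dtypes to native CUDA types, map scalar types
// to their packed x2 forms, and map each cache type to its load vector width
// and storage element type.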
template <typename T>
struct cascade_attn_type_traits {
using type = T;
};
template <>
struct cascade_attn_type_traits<phi::dtype::bfloat16> {
using type = __nv_bfloat16;
};
template <>
struct cascade_attn_type_traits<phi::dtype::float16> {
using type = half;
};
template <>
struct cascade_attn_type_traits<phi::dtype::float8_e4m3fn> {
using type = __nv_fp8_e4m3;
};
template <typename T>
struct cascade_attn_nv_type2_traits {
using type = T;
};
template <>
struct cascade_attn_nv_type2_traits<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <>
struct cascade_attn_nv_type2_traits<half> {
using type = half2;
};
template <CacheType cache_type>
struct vec_traits {
using type = b128_t;
};
template <>
struct vec_traits<CacheType::CacheInt8Hw> {
using type = b64_t;
};
template <>
struct vec_traits<CacheType::CacheInt4CwZp> {
using type = b32_t;
};
template <typename T, CacheType cache_type>
struct cache_type_traits {
using type = T;
};
template <typename T>
struct cache_type_traits<T, CacheType::CacheInt8Hw> {
using type = uint8_t;
};
template <typename T>
struct cache_type_traits<T, CacheType::CacheInt4CwZp> {
using type = uint8_t;
};

__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x,
                                                           uint32_t y) {
  return (x > y) ? x - y : 0U;
}

/******************************FASTER CAST*********************************/
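// Unpacks four FP8 (E4M3) values packed in a uint32 into four bf16 values,
// going through half2 via the SM 8.9+ cvt.rn.f16x2.e4m3x2 PTX instruction.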
inline __device__ static void convert_fp8(__nv_bfloat16* result,
                                          const uint32_t& source) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
  uint32_t dest0;
  uint32_t dest1;
  asm volatile(
      "{\n"
      ".reg .b16 lo, hi;\n"
      "mov.b32 {lo, hi}, %2;\n"
      "cvt.rn.f16x2.e4m3x2 %0, lo;\n"
      "cvt.rn.f16x2.e4m3x2 %1, hi;\n"
      "}\n"
      : "=r"(dest0), "=r"(dest1)
      : "r"(source));
  ((nv_bfloat162*)(result))[0] =
      __float22bfloat162_rn(__half22float2(((half2*)(&dest0))[0]));
  ((nv_bfloat162*)(result))[1] =
      __float22bfloat162_rn(__half22float2(((half2*)(&dest1))[0]));
#else
  printf("FP8 is not supported on architectures below SM 8.9\n");
  asm("trap;");
#endif
}

inline __device__ static void convert_fp8(half* result, const uint32_t& source) {
  printf("fp8 -> half conversion is not implemented\n");
}
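
// Dequantizes four uint8 cache values (stored with a +128 zero point) to bf16
// using the float "magic bias" trick: each byte is placed into a float whose
// value becomes 2^23 + byte, the bias (2^23 + 128) is subtracted, and the top
// 16 bits of each float are kept as bf16.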
inline __device__ static void convert_int8(
    __nv_bfloat16* result, const uint32_t& source) {  // 4 int8 each time
  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(result);
  uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
  static constexpr uint32_t fp32_base = 0x4B000000;
  float fp32_intermediates[4];
  uint32_t* fp32_intermediates_casted =
      reinterpret_cast<uint32_t*>(fp32_intermediates);
  fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
  fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
  fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
  fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
#pragma unroll
  for (int ii = 0; ii < 4; ++ii) {
    fp32_intermediates[ii] -= 8388736.f;  // (8388608.f + 128.f)
  }
#pragma unroll
  for (int ii = 0; ii < 2; ++ii) {
    bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0],
                                      fp32_intermediates_casted[2 * ii + 1],
                                      0x7632);
  }
}
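
// Same as above for half output: prmt builds half values of (1024 + byte),
// then a packed f16x2 subtraction of (1024 + 128) removes the bias.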
inline __device__ static void convert_int8(
    half* result, const uint32_t& source) {  // 4 int8 each time
  uint32_t* fp16_result_ptr = reinterpret_cast<uint32_t*>(result);
  uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
  static constexpr uint32_t mask_for_elt_01 = 0x5150;
  static constexpr uint32_t mask_for_elt_23 = 0x5352;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
  asm volatile("prmt.b32 %0,%1,%2,%3;\n"
               : "=r"(fp16_result_ptr[0])
               : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
  asm volatile("prmt.b32 %0,%1,%2,%3;\n"
               : "=r"(fp16_result_ptr[1])
               : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));
  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(fp16_result_ptr[0])
               : "r"(fp16_result_ptr[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(fp16_result_ptr[1])
               : "r"(fp16_result_ptr[1]), "r"(I8s_TO_F16s_MAGIC_NUM));
}
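
// Expands eight packed uint4 values into eight bf16 values; each output holds
// (128 + nibble), so the caller is expected to remove the zero point and apply
// the per-channel scale afterwards.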
inline __device__ static void convert_int4(
    __nv_bfloat16* result, const uint32_t& source) {  // 8 int4 each time
  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(result);
  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
  static constexpr uint32_t MASK = 0x0f0f0f0f;  // 0xf -> 0b1111 select 0,4
  static constexpr uint32_t I4s_TO_FP32s_MAGIC_NUM = 0x43434343;
  static constexpr uint32_t mask_for_elt_01 = 0x5150;
  static constexpr uint32_t mask_for_elt_23 = 0x5352;
  uint32_t tmp1 = source & MASK;       // 0 1 2 3
  uint32_t tmp2 = source >> 4 & MASK;  // 4 5 6 7
  bf16_result_ptr[0] = __byte_perm(tmp1,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_01);  // 0 1
  bf16_result_ptr[1] = __byte_perm(tmp1,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_23);  // 2 3
  bf16_result_ptr[2] = __byte_perm(tmp2,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_01);  // 4 5
  bf16_result_ptr[3] = __byte_perm(tmp2,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_23);  // 6 7
}
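
// Half-precision variant of the int4 expansion: each output half holds
// (1024 + nibble); the bias removal is likewise left to the caller.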
inline __device__ static void convert_int4(
    half* result, const uint32_t& source) {  // 7 5 3 1 6 4 2 0
  uint32_t* fp16_result_ptr = reinterpret_cast<uint32_t*>(result);
  static constexpr uint32_t MASK =
      0x0f0f0f0f;  // 0xf -> 0b1111 select 0,1; 7 5 3 1 6 4 2 0
  static constexpr uint32_t I4s_TO_FP32s_MAGIC_NUM = 0x64646464;
  static constexpr uint32_t mask_for_elt_01 = 0x5150;
  static constexpr uint32_t mask_for_elt_23 = 0x5352;
  uint32_t tmp1 = source & MASK;       // 0 1 2 3
  uint32_t tmp2 = source >> 4 & MASK;  // 4 5 6 7
  fp16_result_ptr[0] = __byte_perm(tmp1,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_01);  // 0 1
  fp16_result_ptr[1] = __byte_perm(tmp1,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_23);  // 2 3
  fp16_result_ptr[2] = __byte_perm(tmp2,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_01);  // 4 5
  fp16_result_ptr[3] = __byte_perm(tmp2,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_23);  // 6 7
}

/******************* vec_t type cast *******************/
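// Element-wise vector cast: the generic template copies with implicit
// conversion, while the specializations below use packed 2-wide conversions
// between float and half / bfloat16.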
template <typename dst_t, typename src_t, size_t vec_size>
__forceinline__ __host__ __device__ void vec_cast(dst_t* dst,
                                                  const src_t* src) {
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    dst[i] = src[i];
  }
}

template <size_t vec_size>
__forceinline__ __host__ __device__ void vec_cast<float, half>(
    float* dst, const half* src) {
#pragma unroll
  for (size_t i = 0; i < vec_size / 2; ++i) {
    ((float2*)dst)[i] = __half22float2(((half2*)src)[i]);
  }
}

template <size_t vec_size>
__forceinline__ __host__ __device__ void vec_cast<half, float>(
    half* dst, const float* src) {
#pragma unroll
  for (size_t i = 0; i < vec_size / 2; ++i) {
    ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]);
  }
}

template <size_t vec_size>
__forceinline__ __host__ __device__ void vec_cast<float, nv_bfloat16>(
    float* dst, const nv_bfloat16* src) {
#pragma unroll
  for (size_t i = 0; i < vec_size / 2; ++i) {
    ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]);
  }
}

template <size_t vec_size>
__forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
    nv_bfloat16* dst, const float* src) {
#pragma unroll
  for (size_t i = 0; i < vec_size / 2; ++i) {
    ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]);
  }
}
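
// CHECK_CUDA_CALL reports and propagates CUDA runtime errors. The DISPATCH_*
// macros below turn a runtime value into a constexpr constant visible inside
// the code block passed as __VA_ARGS__, so kernels can be instantiated for
// each supported configuration; most dispatchers raise PD_THROW for
// unsupported values.
//
// A minimal usage sketch (LaunchKernel and params are illustrative, not part
// of this file):
//   DISPATCH_HEAD_DIM(meta.head_dims, HEAD_DIM,
//                     {LaunchKernel<T, HEAD_DIM>(params);})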
#define CHECK_CUDA_CALL(func, ...)                                         \
  {                                                                        \
    cudaError_t e = (func);                                                \
    if (e != cudaSuccess) {                                                \
      std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e    \
                << ") " << __FILE__ << ": line " << __LINE__               \
                << " at function " << STR(func) << std::endl;              \
      return e;                                                            \
    }                                                                      \
  }
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim"); \
} \
}
#define DISPATCH_GQA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \
break; \
} \
case 192: { \
constexpr size_t HEAD_DIM = 192; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim: ", head_dim); \
} \
}
#define DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \
break; \
} \
case 192: { \
constexpr size_t HEAD_DIM = 192; \
__VA_ARGS__ \
break; \
} \
case 512: { \
constexpr size_t HEAD_DIM = 512; \
__VA_ARGS__ \
break; \
} \
case 576: { \
constexpr size_t HEAD_DIM = 576; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim: ", head_dim); \
} \
}
#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
if (num_stage == 2) { \
constexpr size_t NUM_STAGE = 2; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the num_stage: ", num_stage); \
}
#define DISPATCH_CACHE_TYPE(cache_type, cache_type_now, cache_bytes, ...) \
if (cache_type == 0) { \
constexpr CacheType cache_type_now = CacheType::CacheT; \
constexpr size_t cache_bytes = 16; \
__VA_ARGS__ \
} else if (cache_type == 1) { \
constexpr CacheType cache_type_now = CacheType::CacheInt8Hw; \
constexpr size_t cache_bytes = 8; \
__VA_ARGS__ \
} else if (cache_type == 2) { \
constexpr CacheType cache_type_now = CacheType::CacheInt4CwZp; \
constexpr size_t cache_bytes = 4; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the cache_type: ", cache_type); \
}
#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \
if (deal_each_time == 32) { \
constexpr size_t DEAL_EACH_TIME = 32; \
__VA_ARGS__ \
} else if (deal_each_time == 64) { \
constexpr size_t DEAL_EACH_TIME = 64; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the deal_each_time", deal_each_time); \
}
#define DISPATCH_NUM_THREADS(num_threads, NUM_THREADS, ...) \
if (num_threads == 128) { \
constexpr size_t NUM_THREADS = 128; \
__VA_ARGS__ \
} else if (num_threads == 256) { \
constexpr size_t NUM_THREADS = 256; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the num_threads", num_threads); \
}
#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 1) { \
constexpr size_t GROUP_SIZE = 1; \
__VA_ARGS__ \
} else if (group_size == 2) { \
constexpr size_t GROUP_SIZE = 2; \
__VA_ARGS__ \
} else if (group_size == 3) { \
constexpr size_t GROUP_SIZE = 3; \
__VA_ARGS__ \
} else if (group_size == 4) { \
constexpr size_t GROUP_SIZE = 4; \
__VA_ARGS__ \
} else if (group_size == 5) { \
constexpr size_t GROUP_SIZE = 5; \
__VA_ARGS__ \
} else if (group_size == 6) { \
constexpr size_t GROUP_SIZE = 6; \
__VA_ARGS__ \
} else if (group_size == 7) { \
constexpr size_t GROUP_SIZE = 7; \
__VA_ARGS__ \
} else if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 12) { \
constexpr size_t GROUP_SIZE = 12; \
__VA_ARGS__ \
} else if (group_size == 14) { \
constexpr size_t GROUP_SIZE = 14; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size", group_size); \
}
#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \
if (is_dynamic_cfp8) { \
constexpr bool IsDynamicC8 = true; \
__VA_ARGS__ \
} else { \
constexpr bool IsDynamicC8 = false; \
__VA_ARGS__ \
}
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else if (group_size == 128) { \
constexpr size_t GROUP_SIZE = 128; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size: ", group_size); \
}
#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \
if (block_shape_q <= 16) { \
constexpr size_t BLOCK_SHAPE_Q = 16; \
constexpr size_t NUM_WARP_Q = 1; \
__VA_ARGS__ \
} else if (block_shape_q <= 32) { \
constexpr size_t BLOCK_SHAPE_Q = 32; \
constexpr size_t NUM_WARP_Q = 1; \
__VA_ARGS__ \
} else if (block_shape_q <= 64) { \
constexpr size_t BLOCK_SHAPE_Q = 64; \
constexpr size_t NUM_WARP_Q = 4; \
__VA_ARGS__ \
} else { \
constexpr size_t BLOCK_SHAPE_Q = 128; \
constexpr size_t NUM_WARP_Q = 4; \
__VA_ARGS__ \
}
#define DISPATCH_CAUSAL(causal, CAUSAL, ...) \
if (causal) { \
constexpr bool CAUSAL = true; \
__VA_ARGS__ \
} else { \
constexpr bool CAUSAL = false; \
__VA_ARGS__ \
}
#define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...) \
if (enable_prefill) { \
constexpr bool ENABLE_PREFILL = true; \
__VA_ARGS__ \
} else { \
constexpr bool ENABLE_PREFILL = false; \
__VA_ARGS__ \
}
#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \
if (block_size == 64) { \
constexpr size_t BLOCK_SIZE = 64; \
__VA_ARGS__ \
}
#define DISPATCH_BLOCKSHAPE_Q_SYSTEM( \
block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \
if (block_shape_q <= 16) { \
constexpr size_t BLOCK_SHAPE_Q = 16; \
constexpr size_t NUM_WARP_Q = 1; \
__VA_ARGS__ \
} else if (block_shape_q <= 32) { \
constexpr size_t BLOCK_SHAPE_Q = 32; \
constexpr size_t NUM_WARP_Q = 1; \
__VA_ARGS__ \
}
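
// Rounds to the nearest integer, breaking ties to the even value (used as the
// default rounding mode when quantizing cache values).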
template <typename T>
inline HOSTDEVICE T roundWithTiesToEven(T x) {
  T xLower = floor(x);
  T xUpper = ceil(x);
  // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to
  // even.
  T dLower = x - xLower;
  T dUpper = xUpper - x;
  return static_cast<T>(
      (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper)
          ? xLower
          : xUpper);
}
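
// Quantizes a value (optionally pre-multiplied by the scale) to one cache
// byte: either the bit pattern of an FP8 E4M3 number clamped to [-448, 448]
// (SM 8.9+ only), or an int8 value clamped to [-127, 127] and stored with a
// +128 offset. RoundType 0 uses round-half-to-even, any other value uses
// round(). Note that max_bound/min_bound are currently unused; the clamping
// bounds are hard-coded.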
template <typename T, bool is_need_kv_quant, bool IsFP8, int RoundType = 0>
__host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale,
                                                      const T value,
                                                      const float max_bound,
                                                      const float min_bound) {
  uint8_t eight_bits;
  float quant_value;
  if constexpr (is_need_kv_quant) {
    quant_value = static_cast<float>(scale * value);
  } else {
    quant_value = static_cast<float>(value);
  }
  if constexpr (RoundType == 0) {
    quant_value = roundWithTiesToEven(quant_value);
  } else {
    quant_value = round(quant_value);
  }
  if constexpr (IsFP8) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
    quant_value = quant_value > 448.0f ? 448.0f : quant_value;
    quant_value = quant_value < -448.0f ? -448.0f : quant_value;
    auto tmp = static_cast<__nv_fp8_e4m3>(quant_value);
    eight_bits = *(reinterpret_cast<uint8_t*>(&tmp));
#else
    printf("FP8 is not supported on architectures below SM 8.9\n");
    asm("trap;");
#endif
  } else {
    quant_value = quant_value > 127.0f ? 127.0f : quant_value;
    quant_value = quant_value < -127.0f ? -127.0f : quant_value;
    eight_bits = static_cast<uint8_t>(quant_value + 128.0f);
  }
  return eight_bits;
}
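
// Unpacks four quantized cache bytes into the compute type, selecting the FP8
// or int8 decode path at compile time.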
template <typename T, bool IsFP8>
inline __device__ static void convert_c8(T* result, const uint32_t& source) {
  if constexpr (IsFP8) {
    convert_fp8(result, source);
  } else {
    convert_int8(result, source);
  }
}
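
// Warp-level reduction helpers: WelfordCombine1 accumulates a partial m2 term,
// and WelfordWarpReduce / WelfordWarpAllReduce sum it across a thread group
// with butterfly shuffles (a pared-down Welford update that tracks only m2).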
constexpr int kWarpSize = 32;

template <typename T>
inline __device__ void WelfordCombine1(T b_m2, T* m2) {
  *m2 += b_m2;
}

template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
  *m2 = thread_m2;
  for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
    T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask);
    WelfordCombine1(b_m2, m2);
  }
}

template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) {
  WelfordWarpReduce<T, thread_group_width>(thread_m2, m2);
}
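
// Reciprocal square root, specialized for float and double.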
template <typename T>
__inline__ __device__ T Rsqrt(T x);
template <>
__inline__ __device__ float Rsqrt<float>(float x) {
return rsqrt(x);
}
template <>
__inline__ __device__ double Rsqrt<double>(double x) {
return rsqrt(x);
}