// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>

#include <iostream>

#include "mem_util.cuh"
struct AppendAttnMetaData {
  int batch_size;
  int block_size;  // number of tokens per KV-cache block
  int q_num_heads;
  int kv_num_heads;
  int token_nums;
  int head_dims;
  int head_dims_v;  // head dim of V, may differ from head_dims (e.g. MLA)
  int max_blocks_per_seq;
  const int *mask_offset = nullptr;
};
__forceinline__ __host__ __device__ int div_up(int a, int b) {
  return (a + b - 1) / b;
}

enum PosEncMode { kNonePos, kRoPE, kAliBi };

enum CacheType { CacheT, CacheInt8Hw, CacheInt4CwZp };
// Maps Paddle dtypes to the corresponding native CUDA types.
template <typename T>
struct cascade_attn_type_traits {
  using type = T;
};

template <>
struct cascade_attn_type_traits<phi::dtype::bfloat16> {
  using type = __nv_bfloat16;
};

template <>
struct cascade_attn_type_traits<phi::dtype::float16> {
  using type = half;
};

template <>
struct cascade_attn_type_traits<phi::dtype::float8_e4m3fn> {
  using type = __nv_fp8_e4m3;
};

// Maps a scalar CUDA type to its packed two-element counterpart.
template <typename T>
struct cascade_attn_nv_type2_traits {
  using type = T;
};

template <>
struct cascade_attn_nv_type2_traits<__nv_bfloat16> {
  using type = __nv_bfloat162;
};

template <>
struct cascade_attn_nv_type2_traits<half> {
  using type = half2;
};
// Vector width used when loading from the KV cache: 128-bit for the native
// dtype, 64-bit for the int8 cache, 32-bit for the int4 cache (the same
// number of elements in each case).
template <CacheType cache_type>
struct vec_traits {
  using type = b128_t;
};

template <>
struct vec_traits<CacheType::CacheInt8Hw> {
  using type = b64_t;
};

template <>
struct vec_traits<CacheType::CacheInt4CwZp> {
  using type = b32_t;
};

// Storage element type of the KV cache for each CacheType.
template <typename T, CacheType cache_type>
struct cache_type_traits {
  using type = T;
};

template <typename T>
struct cache_type_traits<T, CacheType::CacheInt8Hw> {
  using type = uint8_t;
};

template <typename T>
struct cache_type_traits<T, CacheType::CacheInt4CwZp> {
  using type = uint8_t;
};
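
// Usage sketch (illustrative): for an int8 KV cache holding quantized bf16
// values, the traits above compose as
//
//   using cache_elem_t = typename cache_type_traits<__nv_bfloat16,
//                            CacheType::CacheInt8Hw>::type;  // uint8_t
//   using cache_vec_t  = typename vec_traits<CacheType::CacheInt8Hw>::type;  // b64_t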

__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x,
                                                           uint32_t y) {
  return (x > y) ? x - y : 0U;
}

/****************************** FASTER CAST *********************************/

// Converts four packed fp8 (e4m3) values in `source` to four bf16 values,
// via two f16x2 conversions followed by half2 -> float2 -> bfloat162.
inline __device__ static void convert_fp8(__nv_bfloat16* result,
                                          const uint32_t& source) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
  uint32_t dest0;
  uint32_t dest1;
  asm volatile(
      "{\n"
      ".reg .b16 lo, hi;\n"
      "mov.b32 {lo, hi}, %2;\n"
      "cvt.rn.f16x2.e4m3x2 %0, lo;\n"
      "cvt.rn.f16x2.e4m3x2 %1, hi;\n"
      "}\n" : "=r"(dest0), "=r"(dest1) : "r"(source));

  ((nv_bfloat162*)(result))[0] =
      __float22bfloat162_rn(__half22float2(((half2*)(&dest0))[0]));
  ((nv_bfloat162*)(result))[1] =
      __float22bfloat162_rn(__half22float2(((half2*)(&dest1))[0]));
#else
  printf("fp8 is not supported on arch < 890\n");
  asm("trap;");
#endif
}
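
// Worked example (illustrative): with source = 0x48B84038, the four e4m3
// bytes are, from the lowest byte up, 0x38, 0x40, 0xB8, 0x48, so
// result[0..3] becomes {1.0, 2.0, -1.0, 4.0} in bf16.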

// fp8 -> half conversion is not implemented; the bf16 overload above is the
// only supported path.
inline __device__ static void convert_fp8(half* result, const uint32_t& source) {
  printf("fp8 -> half conversion is not implemented\n");
}

// Converts four packed uint8 cache values in `source` to four bf16 values.
// Each byte is placed into the low mantissa byte of a float whose upper
// bytes are 0x4B00 (i.e. 2^23 = 8388608.0f); subtracting 8388608 + 128
// removes that bias together with the +128 zero-point used when the cache
// was written.
inline __device__ static void convert_int8(
    __nv_bfloat16* result, const uint32_t& source) {  // 4 int8 each time
  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(result);
  uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

  static constexpr uint32_t fp32_base = 0x4B000000;
  float fp32_intermediates[4];

  uint32_t* fp32_intermediates_casted =
      reinterpret_cast<uint32_t*>(fp32_intermediates);
  fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
  fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
  fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
  fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);

#pragma unroll
  for (int ii = 0; ii < 4; ++ii) {
    fp32_intermediates[ii] -= 8388736.f;  // 8388608.f + 128.f
  }

  // Pack the high 16 bits (bf16) of the four floats into two 32-bit words.
#pragma unroll
  for (int ii = 0; ii < 2; ++ii) {
    bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0],
                                      fp32_intermediates_casted[2 * ii + 1],
                                      0x7632);
  }
}
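
// Worked example (illustrative): a cache byte of 179 becomes the float
// 8388608 + 179 = 8388787.0f after __byte_perm; subtracting 8388736 leaves
// 51.0f, i.e. the stored byte minus the 128 zero-point, which is then
// truncated to bf16.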

// Converts four packed uint8 cache values in `source` to four half values.
// prmt splices each byte b under the magic high byte 0x64, giving the fp16
// value 1024 + b; subtracting 0x6480 (= 1152 = 1024 + 128) leaves b - 128.
inline __device__ static void convert_int8(
    half* result, const uint32_t& source) {  // 4 int8 each time
  uint32_t* fp16_result_ptr = reinterpret_cast<uint32_t*>(result);
  uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
  static constexpr uint32_t mask_for_elt_01 = 0x5150;
  static constexpr uint32_t mask_for_elt_23 = 0x5352;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  asm volatile("prmt.b32 %0,%1,%2,%3;\n"
               : "=r"(fp16_result_ptr[0])
               : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
  asm volatile("prmt.b32 %0,%1,%2,%3;\n"
               : "=r"(fp16_result_ptr[1])
               : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(fp16_result_ptr[0])
               : "r"(fp16_result_ptr[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(fp16_result_ptr[1])
               : "r"(fp16_result_ptr[1]), "r"(I8s_TO_F16s_MAGIC_NUM));
}

// Unpacks eight packed uint4 cache values in `source` to eight bf16 values.
// Each nibble v is placed under the magic byte 0x43, so the result is the
// bf16 value 128 + v; the zero-point and scale are applied by the caller.
inline __device__ static void convert_int4(
    __nv_bfloat16* result, const uint32_t& source) {  // 8 int4 each time
  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(result);

  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;  // unused
  static constexpr uint32_t MASK = 0x0f0f0f0f;  // 0xf -> 0b1111 select 0,4
  static constexpr uint32_t I4s_TO_FP32s_MAGIC_NUM = 0x43434343;
  static constexpr uint32_t mask_for_elt_01 = 0x5150;
  static constexpr uint32_t mask_for_elt_23 = 0x5352;

  uint32_t tmp1 = source & MASK;       // 0 1 2 3
  uint32_t tmp2 = source >> 4 & MASK;  // 4 5 6 7

  bf16_result_ptr[0] = __byte_perm(tmp1,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_01);  // 0 1
  bf16_result_ptr[1] = __byte_perm(tmp1,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_23);  // 2 3
  bf16_result_ptr[2] = __byte_perm(tmp2,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_01);  // 4 5
  bf16_result_ptr[3] = __byte_perm(tmp2,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_23);  // 6 7
}

// Unpacks eight packed uint4 cache values in `source` to eight half values.
// Each nibble v is placed under the magic byte 0x64, so the result is the
// fp16 value 1024 + v; the zero-point and scale are applied by the caller.
inline __device__ static void convert_int4(
    half* result, const uint32_t& source) {  // 7 5 3 1 6 4 2 0
  uint32_t* fp16_result_ptr = reinterpret_cast<uint32_t*>(result);

  static constexpr uint32_t MASK =
      0x0f0f0f0f;  // 0xf -> 0b1111 select 0,1; 7 5 3 1 6 4 2 0
  static constexpr uint32_t I4s_TO_FP32s_MAGIC_NUM = 0x64646464;
  static constexpr uint32_t mask_for_elt_01 = 0x5150;
  static constexpr uint32_t mask_for_elt_23 = 0x5352;

  uint32_t tmp1 = source & MASK;       // 0 1 2 3
  uint32_t tmp2 = source >> 4 & MASK;  // 4 5 6 7
  fp16_result_ptr[0] = __byte_perm(tmp1,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_01);  // 0 1
  fp16_result_ptr[1] = __byte_perm(tmp1,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_23);  // 2 3
  fp16_result_ptr[2] = __byte_perm(tmp2,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_01);  // 4 5
  fp16_result_ptr[3] = __byte_perm(tmp2,
                                   I4s_TO_FP32s_MAGIC_NUM,
                                   mask_for_elt_23);  // 6 7
}

/******************* vec_t type cast *******************/

template <typename dst_t, typename src_t, size_t vec_size>
__forceinline__ __host__ __device__ void vec_cast(dst_t* dst,
                                                  const src_t* src) {
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    dst[i] = src[i];
  }
}

template <size_t vec_size>
__forceinline__ __host__ __device__ void vec_cast<float, half>(
    float* dst, const half* src) {
#pragma unroll
  for (size_t i = 0; i < vec_size / 2; ++i) {
    ((float2*)dst)[i] = __half22float2(((half2*)src)[i]);
  }
}

template <size_t vec_size>
__forceinline__ __host__ __device__ void vec_cast<half, float>(
    half* dst, const float* src) {
#pragma unroll
  for (size_t i = 0; i < vec_size / 2; ++i) {
    ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]);
  }
}

template <size_t vec_size>
__forceinline__ __host__ __device__ void vec_cast<float, nv_bfloat16>(
    float* dst, const nv_bfloat16* src) {
#pragma unroll
  for (size_t i = 0; i < vec_size / 2; ++i) {
    ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]);
  }
}

template <size_t vec_size>
__forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
    nv_bfloat16* dst, const float* src) {
#pragma unroll
  for (size_t i = 0; i < vec_size / 2; ++i) {
    ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]);
  }
}
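
// Usage sketch (illustrative): casting an 8-element fragment from half to
// float picks the packed specialization above.
//
//   half  frag[8];
//   float acc[8];
//   vec_cast<float, half, 8>(acc, frag);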

#define CHECK_CUDA_CALL(func, ...)                                        \
  {                                                                       \
    cudaError_t e = (func);                                               \
    if (e != cudaSuccess) {                                               \
      std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e   \
                << ") " << __FILE__ << ": line " << __LINE__              \
                << " at function " << STR(func) << std::endl;             \
      return e;                                                           \
    }                                                                     \
  }
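
// Usage sketch (illustrative): because the macro `return`s the error code on
// failure, it is meant for functions whose return type can hold a
// cudaError_t.
//
//   cudaError_t warmup(cudaStream_t stream) {
//     CHECK_CUDA_CALL(cudaStreamSynchronize(stream));
//     return cudaSuccess;
//   }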

#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...)        \
  switch (head_dim) {                                     \
    case 128: {                                           \
      constexpr size_t HEAD_DIM = 128;                    \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    default: {                                            \
      PD_THROW("not support the head_dim: ", head_dim);   \
    }                                                     \
  }
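
// Usage sketch (illustrative; `DecodeKernel`, `params`, `grid`, `block` and
// `stream` are hypothetical): the dispatch macros turn runtime configuration
// values into constexpr template arguments for one kernel instantiation.
//
//   DISPATCH_HEAD_DIM(meta.head_dims, HEAD_DIM, {
//     DISPATCH_BLOCK_SIZE(meta.block_size, BLOCK_SIZE, {
//       DecodeKernel<T, HEAD_DIM, BLOCK_SIZE>
//           <<<grid, block, 0, stream>>>(params);
//     })
//   })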

#define DISPATCH_GQA_HEAD_DIM(head_dim, HEAD_DIM, ...)    \
  switch (head_dim) {                                     \
    case 128: {                                           \
      constexpr size_t HEAD_DIM = 128;                    \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    case 192: {                                           \
      constexpr size_t HEAD_DIM = 192;                    \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    default: {                                            \
      PD_THROW("not support the head_dim: ", head_dim);   \
    }                                                     \
  }

#define DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, ...)    \
  switch (head_dim) {                                     \
    case 128: {                                           \
      constexpr size_t HEAD_DIM = 128;                    \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    case 192: {                                           \
      constexpr size_t HEAD_DIM = 192;                    \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    case 512: {                                           \
      constexpr size_t HEAD_DIM = 512;                    \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    case 576: {                                           \
      constexpr size_t HEAD_DIM = 576;                    \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    default: {                                            \
      PD_THROW("not support the head_dim: ", head_dim);   \
    }                                                     \
  }

#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...)       \
  if (num_stage == 2) {                                     \
    constexpr size_t NUM_STAGE = 2;                         \
    __VA_ARGS__                                             \
  } else {                                                  \
    PD_THROW("not support the num_stage: ", num_stage);     \
  }

// cache_bytes is the number of bytes loaded per cache fragment (matching
// vec_traits): 16 for the native dtype, 8 for int8, 4 for int4.
#define DISPATCH_CACHE_TYPE(cache_type, cache_type_now, cache_bytes, ...)  \
  if (cache_type == 0) {                                                   \
    constexpr CacheType cache_type_now = CacheType::CacheT;                \
    constexpr size_t cache_bytes = 16;                                     \
    __VA_ARGS__                                                            \
  } else if (cache_type == 1) {                                            \
    constexpr CacheType cache_type_now = CacheType::CacheInt8Hw;           \
    constexpr size_t cache_bytes = 8;                                      \
    __VA_ARGS__                                                            \
  } else if (cache_type == 2) {                                            \
    constexpr CacheType cache_type_now = CacheType::CacheInt4CwZp;         \
    constexpr size_t cache_bytes = 4;                                      \
    __VA_ARGS__                                                            \
  } else {                                                                 \
    PD_THROW("not support the cache_type: ", cache_type);                  \
  }

#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...)  \
  if (deal_each_time == 32) {                                         \
    constexpr size_t DEAL_EACH_TIME = 32;                             \
    __VA_ARGS__                                                       \
  } else if (deal_each_time == 64) {                                  \
    constexpr size_t DEAL_EACH_TIME = 64;                             \
    __VA_ARGS__                                                       \
  } else {                                                            \
    PD_THROW("not support the deal_each_time: ", deal_each_time);     \
  }

#define DISPATCH_NUM_THREADS(num_threads, NUM_THREADS, ...)   \
  if (num_threads == 128) {                                   \
    constexpr size_t NUM_THREADS = 128;                       \
    __VA_ARGS__                                               \
  } else if (num_threads == 256) {                            \
    constexpr size_t NUM_THREADS = 256;                       \
    __VA_ARGS__                                               \
  } else {                                                    \
    PD_THROW("not support the num_threads: ", num_threads);   \
  }

#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...)  \
  if (group_size == 1) {                                      \
    constexpr size_t GROUP_SIZE = 1;                          \
    __VA_ARGS__                                               \
  } else if (group_size == 2) {                               \
    constexpr size_t GROUP_SIZE = 2;                          \
    __VA_ARGS__                                               \
  } else if (group_size == 3) {                               \
    constexpr size_t GROUP_SIZE = 3;                          \
    __VA_ARGS__                                               \
  } else if (group_size == 4) {                               \
    constexpr size_t GROUP_SIZE = 4;                          \
    __VA_ARGS__                                               \
  } else if (group_size == 5) {                               \
    constexpr size_t GROUP_SIZE = 5;                          \
    __VA_ARGS__                                               \
  } else if (group_size == 6) {                               \
    constexpr size_t GROUP_SIZE = 6;                          \
    __VA_ARGS__                                               \
  } else if (group_size == 7) {                               \
    constexpr size_t GROUP_SIZE = 7;                          \
    __VA_ARGS__                                               \
  } else if (group_size == 8) {                               \
    constexpr size_t GROUP_SIZE = 8;                          \
    __VA_ARGS__                                               \
  } else if (group_size == 12) {                              \
    constexpr size_t GROUP_SIZE = 12;                         \
    __VA_ARGS__                                               \
  } else if (group_size == 14) {                              \
    constexpr size_t GROUP_SIZE = 14;                         \
    __VA_ARGS__                                               \
  } else if (group_size == 16) {                              \
    constexpr size_t GROUP_SIZE = 16;                         \
    __VA_ARGS__                                               \
  } else {                                                    \
    PD_THROW("not support the group_size: ", group_size);     \
  }

#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...)  \
  if (is_dynamic_cfp8) {                                    \
    constexpr bool IsDynamicC8 = true;                      \
    __VA_ARGS__                                             \
  } else {                                                  \
    constexpr bool IsDynamicC8 = false;                     \
    __VA_ARGS__                                             \
  }

#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...)  \
  if (group_size == 8) {                                      \
    constexpr size_t GROUP_SIZE = 8;                          \
    __VA_ARGS__                                               \
  } else if (group_size == 16) {                              \
    constexpr size_t GROUP_SIZE = 16;                         \
    __VA_ARGS__                                               \
  } else if (group_size == 128) {                             \
    constexpr size_t GROUP_SIZE = 128;                        \
    __VA_ARGS__                                               \
  } else {                                                    \
    PD_THROW("not support the group_size: ", group_size);     \
  }

#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...)  \
  if (block_shape_q <= 16) {                                                  \
    constexpr size_t BLOCK_SHAPE_Q = 16;                                      \
    constexpr size_t NUM_WARP_Q = 1;                                          \
    __VA_ARGS__                                                               \
  } else if (block_shape_q <= 32) {                                           \
    constexpr size_t BLOCK_SHAPE_Q = 32;                                      \
    constexpr size_t NUM_WARP_Q = 1;                                          \
    __VA_ARGS__                                                               \
  } else if (block_shape_q <= 64) {                                           \
    constexpr size_t BLOCK_SHAPE_Q = 64;                                      \
    constexpr size_t NUM_WARP_Q = 4;                                          \
    __VA_ARGS__                                                               \
  } else {                                                                    \
    constexpr size_t BLOCK_SHAPE_Q = 128;                                     \
    constexpr size_t NUM_WARP_Q = 4;                                          \
    __VA_ARGS__                                                               \
  }

#define DISPATCH_CAUSAL(causal, CAUSAL, ...)  \
  if (causal) {                               \
    constexpr bool CAUSAL = true;             \
    __VA_ARGS__                               \
  } else {                                    \
    constexpr bool CAUSAL = false;            \
    __VA_ARGS__                               \
  }

#define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...)  \
  if (enable_prefill) {                                               \
    constexpr bool ENABLE_PREFILL = true;                             \
    __VA_ARGS__                                                       \
  } else {                                                            \
    constexpr bool ENABLE_PREFILL = false;                            \
    __VA_ARGS__                                                       \
  }

#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...)    \
  if (block_size == 64) {                                   \
    constexpr size_t BLOCK_SIZE = 64;                       \
    __VA_ARGS__                                             \
  } else {                                                  \
    PD_THROW("not support the block_size: ", block_size);   \
  }

#define DISPATCH_BLOCKSHAPE_Q_SYSTEM(                             \
    block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...)                \
  if (block_shape_q <= 16) {                                      \
    constexpr size_t BLOCK_SHAPE_Q = 16;                          \
    constexpr size_t NUM_WARP_Q = 1;                              \
    __VA_ARGS__                                                   \
  } else if (block_shape_q <= 32) {                               \
    constexpr size_t BLOCK_SHAPE_Q = 32;                          \
    constexpr size_t NUM_WARP_Q = 1;                              \
    __VA_ARGS__                                                   \
  } else {                                                        \
    PD_THROW("not support the block_shape_q: ", block_shape_q);   \
  }

template <typename T>
inline HOSTDEVICE T roundWithTiesToEven(T x) {
  T xLower = floor(x);
  T xUpper = ceil(x);
  // x is in the interval [xLower, xUpper]. Choose the closer of the two
  // bounds, breaking ties to even.
  T dLower = x - xLower;
  T dUpper = xUpper - x;
  return static_cast<T>(
      (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper)
          ? xLower
          : xUpper);
}
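
// Worked examples (illustrative): roundWithTiesToEven(2.5f) == 2.0f,
// roundWithTiesToEven(3.5f) == 4.0f, roundWithTiesToEven(-0.5f) == 0.0f,
// whereas round() would give 3, 4 and -1 respectively.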

// Quantizes one value to a single cache byte. For FP8 (e4m3) the rounded
// value is clamped to [-448, 448] and its bit pattern is stored; for INT8 it
// is clamped to [-127, 127] and stored with a +128 offset so the byte is
// unsigned. NOTE: max_bound/min_bound are currently unused; the clamp bounds
// are hard-coded below.
template <typename T, bool is_need_kv_quant, bool IsFP8, int RoundType = 0>
__host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale,
                                                      const T value,
                                                      const float max_bound,
                                                      const float min_bound) {
  uint8_t eight_bits;
  float quant_value;
  if constexpr (is_need_kv_quant) {
    quant_value = static_cast<float>(scale * value);
  } else {
    quant_value = static_cast<float>(value);
  }
  if constexpr (RoundType == 0) {
    quant_value = roundWithTiesToEven(quant_value);
  } else {
    quant_value = round(quant_value);
  }

  if constexpr (IsFP8) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
    quant_value = quant_value > 448.0f ? 448.0f : quant_value;
    quant_value = quant_value < -448.0f ? -448.0f : quant_value;
    auto tmp = static_cast<__nv_fp8_e4m3>(quant_value);
    eight_bits = *(reinterpret_cast<uint8_t*>(&tmp));
#else
    printf("fp8 is not supported on arch < 890\n");
    asm("trap;");
#endif
  } else {
    quant_value = quant_value > 127.0f ? 127.0f : quant_value;
    quant_value = quant_value < -127.0f ? -127.0f : quant_value;
    eight_bits = static_cast<uint8_t>(quant_value + 128.0f);
  }
  return eight_bits;
}
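
// Worked example (illustrative): with is_need_kv_quant == true, scale = 16.0f
// and value = 3.2f, the INT8 path rounds 51.2f to 51.0f and stores the byte
// 51 + 128 = 179; the FP8 path would instead cast the rounded value to
// __nv_fp8_e4m3 and store its byte.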

// Dispatches a packed 4-byte cache fragment to the fp8 or int8 converter.
template <typename T, bool IsFP8>
inline __device__ static void convert_c8(T* result, const uint32_t& source) {
  if constexpr (IsFP8) {
    convert_fp8(result, source);
  } else {
    convert_int8(result, source);
  }
}
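
// Round-trip sketch (illustrative; `scale`, `inv_scale` and `packed_4_bytes`
// are hypothetical per-head quantization parameters and data): bytes written
// with QuantToC8 are recovered by convert_c8 up to rounding, after
// multiplying by 1 / scale.
//
//   uint8_t byte = QuantToC8<float, true, false>(scale, v, 127.0f, -127.0f);
//   ...
//   __nv_bfloat16 out[4];
//   convert_c8<__nv_bfloat16, false>(out, packed_4_bytes);  // scale * v, rounded
//   float v_hat = __bfloat162float(out[0]) * inv_scale;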

constexpr int kWarpSize = 32;

template <typename T>
inline __device__ void WelfordCombine1(T b_m2, T* m2) {
  *m2 += b_m2;
}

// Reduces the per-thread m2 partial sums across a thread group with a
// butterfly (shfl_xor) pattern; only the m2 accumulator of the Welford
// algorithm is tracked here.
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
  *m2 = thread_m2;
  for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
    T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask);
    WelfordCombine1(b_m2, m2);
  }
}

template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) {
  WelfordWarpReduce<T, thread_group_width>(thread_m2, m2);
}
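
// Note (illustrative): because the xor butterfly leaves every lane holding
// the full sum, WelfordWarpAllReduce is a thin wrapper over WelfordWarpReduce.
// A typical use is summing per-thread sums of squares before an Rsqrt-based
// normalization (`hidden_dim` and `epsilon` below are hypothetical):
//
//   float m2 = local_sum_of_squares;
//   WelfordWarpAllReduce<float, 32>(m2, &m2);
//   float inv_rms = Rsqrt(m2 / hidden_dim + epsilon);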

template <typename T>
__inline__ __device__ T Rsqrt(T x);

template <>
__inline__ __device__ float Rsqrt<float>(float x) {
  return rsqrt(x);
}

template <>
__inline__ __device__ double Rsqrt<double>(double x) {
  return rsqrt(x);
}