FastDeploy/custom_ops/gpu_ops/quantization/common.cuh

// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
#pragma once
#include "cuda.h"
#include "helper.h"
#include <cmath>
#include <cub/cub.cuh>
#include <cuda_runtime.h>
namespace fastdeploy {
// Vectorization containers
template <typename scalar_t> struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};
template <typename quant_type_t> struct __align__(4) q8x4_t {
  static_assert(std::is_same_v<quant_type_t, int8_t> ||
                std::is_same_v<quant_type_t, phi::dtype::float8_e4m3fn>);
  quant_type_t x;
  quant_type_t y;
  quant_type_t z;
  quant_type_t w;
};
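// Atomic max on a float implemented with integer atomics: for non-negative
// values the IEEE-754 bit pattern orders the same way as a signed int, while
// for negative values the ordering is reversed when viewed as an unsigned
// int, so atomicMin on the unsigned view yields the float maximum.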
__device__ __forceinline__ float atomicMaxFloat(float *addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int *)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int *)addr, __float_as_uint(value)));
  return old;
}
// Quantize a single float to fp8. If is_scale_inverted is true, `scale` is
// the reciprocal scale and is multiplied in; otherwise the value is divided
// by `scale`. The result is clamped to the finite float8_e4m3 range (+/-448).
template <bool is_scale_inverted, typename fp8_type>
__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }
  float r = fmax(-448.f, fmin(x, 448.f));
  return static_cast<fp8_type>(r);
}
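// For example, with the per-tensor scale absmax / 448.f produced by
// segmented_max_reduction below and is_scale_inverted = false, an input equal
// to absmax maps exactly to 448.f, the largest finite float8_e4m3 value.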
// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t, typename fp8_type>
__global__ void segmented_max_reduction(float *__restrict__ scale,
                                        const scalar_t *__restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];
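  // Note: the kernel must be launched with blockDim.x <= 1024 so every thread
  // owns a slot in the shared-memory cache.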
  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
  // First store the maximum of all values processed by
  // the current thread in cache[threadIdx.x]
  float tmp = 0.0f;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = fmax(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;
  __syncthreads();
  // Now perform parallel reduction within the thread block. The halving tree
  // assumes blockDim.x is a power of two.
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / 448.f);
  }
}
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const *__restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const *vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const *>(input);
  int64_t const num_vec_elems = num_elems >> 2;
  float absmax_val = 0.0f;
#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    absmax_val = max(absmax_val, fabs(static_cast<float>(in_vec.x)));
    absmax_val = max(absmax_val, fabs(static_cast<float>(in_vec.y)));
    absmax_val = max(absmax_val, fabs(static_cast<float>(in_vec.z)));
    absmax_val = max(absmax_val, fabs(static_cast<float>(in_vec.w)));
  }
  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    absmax_val = max(absmax_val, fabs(static_cast<float>(input[i])));
  }
  return absmax_val;
}
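// Usage sketch, not part of the original header: thread_max_vec produces a
// per-thread partial absmax, which a caller would typically combine with a
// block-level reduction, e.g. via CUB (included above). The helper below is
// illustrative only and assumes a launch with exactly kBlockSize threads;
// only thread 0 of the block receives the reduced value.
template <typename scalar_t, int kBlockSize = 256>
__device__ float block_absmax_sketch(scalar_t const *__restrict__ input,
                                     int64_t const num_elems) {
  using BlockReduce = cub::BlockReduce<float, kBlockSize>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  float thread_absmax =
      thread_max_vec(input, num_elems, threadIdx.x, kBlockSize);
  // The result is only valid in thread 0; broadcast via shared memory if
  // other threads need it.
  return BlockReduce(temp_storage).Reduce(thread_absmax, cub::Max());
}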
template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
__device__ void scaled_fp8_conversion_vec(fp8_type *__restrict__ out,
                                          scalar_t const *__restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  using float8x4_t = q8x4_t<fp8_type>;
  // Vectorized input/output to better utilize memory bandwidth.
  auto const *vectorized_in = reinterpret_cast<vec4_t<scalar_t> const *>(input);
  auto *vectorized_out = reinterpret_cast<float8x4_t *>(out);
  int64_t const num_vec_elems = num_elems >> 2;
#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    float8x4_t out_vec;
    out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.x), scale);
    out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.y), scale);
    out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.z), scale);
    out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.w), scale);
    vectorized_out[i] = out_vec;
  }
  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(input[i]), scale);
  }
}
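// Usage sketch, not part of the original header: dynamic per-tensor FP8
// quantization is typically done in two launches. First,
// segmented_max_reduction fills *scale (pre-initialized to a value <= 0.0)
// with absmax / 448.f; then a second kernel converts the tensor with that
// scale. The kernel name below is illustrative only.
template <typename scalar_t, typename fp8_type>
__global__ void example_fp8_quant_kernel(fp8_type *__restrict__ out,
                                         scalar_t const *__restrict__ input,
                                         float const *__restrict__ scale,
                                         int64_t num_elems) {
  int const tid = blockDim.x * blockIdx.x + threadIdx.x;
  int const step = blockDim.x * gridDim.x;
  // is_scale_inverted = false: each element is divided by *scale.
  scaled_fp8_conversion_vec<scalar_t, false, fp8_type>(out, input, *scale,
                                                       num_elems, tid, step);
}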
} // namespace fastdeploy