Mirror of https://github.com/PaddlePaddle/FastDeploy.git
Commit: Sync v2.0 version of code to github repo

custom_ops/gpu_ops/quantization/common.cu (new file, 235 lines)
@@ -0,0 +1,235 @@
// Adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh

#include "quantization/common.cuh"

namespace fastdeploy {

// Static (per-tensor) FP8 quantization: each thread walks a grid-stride slice
// of the flattened input and converts it with a single precomputed scale.
template <typename scalar_t, typename fp8_type>
__global__ void scaled_fp8_quant_kernel(fp8_type *__restrict__ out,
                                        const scalar_t *__restrict__ input,
                                        const float *__restrict__ scale,
                                        int64_t num_elems) {
  int tid = blockDim.x * blockIdx.x + threadIdx.x;

  // Invert the scale so that we can use multiplications to avoid expensive
  // division.
  const float inverted_scale = 1.0f / (*scale);
  scaled_fp8_conversion_vec<scalar_t, true>(
      out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
}
// Dynamic per-token FP8 quantization: one thread block handles one token,
// computes that token's absolute maximum, derives a per-token scale, and then
// quantizes the token's row.
template <typename scalar_t, typename fp8_type>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
    fp8_type *__restrict__ out, float *__restrict__ scale,
    scalar_t const *__restrict__ input, float scale_ub, const int hidden_size) {
  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;

  // Use int64 to avoid overflowing an int32 when calculating this offset.
  int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
  scalar_t const *__restrict__ token_input = &input[offset];
  fp8_type *__restrict__ token_output = &out[offset];

  // For vectorization, token_input and token_output pointers need to be
  // aligned at 8-byte and 4-byte addresses respectively.
  bool const can_vectorize = hidden_size % 4 == 0;

  float absmax_val = 0.0f;
  if (can_vectorize) {
    absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      float const x = static_cast<float>(token_input[i]);
      absmax_val = max(absmax_val, fabs(x));
    }
  }

  // Reduce the per-thread maxima to a single per-token maximum.
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStorage;
  float const block_absmax_val_maybe =
      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
  __shared__ float token_scale;
  if (tid == 0) {
    if (scale_ub > 0) {
      token_scale = min(block_absmax_val_maybe, scale_ub);
    } else {
      token_scale = block_absmax_val_maybe;
    }
    // Token scale computation: 448.f is the largest finite value of
    // float8_e4m3fn.
    // token_scale = max(token_scale / 448.f,
    //                   min_scaling_factor<fp8_type>::val());
    token_scale = token_scale / 448.f;
    scale[token_idx] = token_scale;
  }
  __syncthreads();

  // Note that we don't use inverted scales so we can match the FBGemm impl.
  if (can_vectorize) {
    scaled_fp8_conversion_vec<scalar_t, false>(
        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      token_output[i] = scaled_fp8_conversion<false, fp8_type>(
          static_cast<float>(token_input[i]), token_scale);
    }
  }
}
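
// --------------------------------------------------------------------------
// Illustration only (not part of this commit): a minimal, self-contained
// sketch of the cub::BlockReduce max pattern used above. The kernel name and
// the block size of 256 are hypothetical; as in the kernel above, only
// thread 0 receives the valid block-wide result.
// --------------------------------------------------------------------------
__global__ void block_absmax_demo(const float *__restrict__ in,
                                  float *__restrict__ out, int n) {
  using BlockReduce = cub::BlockReduce<float, 256>;
  __shared__ typename BlockReduce::TempStorage tmp;
  float local = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    local = fmaxf(local, fabsf(in[i]));
  }
  float block_max = BlockReduce(tmp).Reduce(local, cub::Max{});
  if (threadIdx.x == 0) {
    out[blockIdx.x] = block_max;  // one absolute maximum per block
  }
}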

} // namespace fastdeploy

// Per-tensor quantization with a precomputed (static) scale.
void StaticScaledFp8Quant(paddle::Tensor &out,          // [..., d]
                          paddle::Tensor const &input,  // [..., d]
                          paddle::Tensor const &scale)  // [1]
{
  PD_CHECK(out.dtype() == paddle::DataType::FLOAT8_E4M3FN);
  using fp8_t = phi::dtype::float8_e4m3fn;
  auto rank = input.dims().size();
  int64_t num_tokens = input.numel() / input.dims()[rank - 1];
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);

  cudaStream_t stream = input.stream();

  switch (input.dtype()) {
  case paddle::DataType::FLOAT32: {
    using scalar_t = float;
    fastdeploy::scaled_fp8_quant_kernel<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(out.data<fp8_t>(), input.data<scalar_t>(),
                                     scale.data<float>(), num_elems);
    break;
  }
  case paddle::DataType::FLOAT16: {
    using scalar_t = phi::dtype::float16;
    fastdeploy::scaled_fp8_quant_kernel<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(out.data<fp8_t>(), input.data<scalar_t>(),
                                     scale.data<float>(), num_elems);
    break;
  }
  case paddle::DataType::BFLOAT16: {
    using scalar_t = phi::dtype::bfloat16;
    fastdeploy::scaled_fp8_quant_kernel<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(out.data<fp8_t>(), input.data<scalar_t>(),
                                     scale.data<float>(), num_elems);
    break;
  }
  default:
    PD_THROW("Only input dtypes fp32, fp16 and bf16 are supported.");
  }
}

// Per-tensor quantization with a dynamically computed scale: first reduce the
// global absolute maximum of `input` into *scale, then quantize with it.
// `scale` is expected to arrive initialized to a value <= 0.0 (see the
// contract of segmented_max_reduction in common.cuh).
void DynamicScaledFp8Quant(paddle::Tensor &out,          // [..., d]
                           paddle::Tensor const &input,  // [..., d]
                           paddle::Tensor &scale)        // [1]
{
  PD_CHECK(out.dtype() == paddle::DataType::FLOAT8_E4M3FN);
  using fp8_t = phi::dtype::float8_e4m3fn;
  auto rank = input.dims().size();
  int64_t num_tokens = input.numel() / input.dims()[rank - 1];
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);

  cudaStream_t stream = input.stream();

  switch (input.dtype()) {
  case paddle::DataType::FLOAT32: {
    using scalar_t = float;
    fastdeploy::segmented_max_reduction<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(scale.data<float>(),
                                     input.data<scalar_t>(), num_elems);
    fastdeploy::scaled_fp8_quant_kernel<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(out.data<fp8_t>(), input.data<scalar_t>(),
                                     scale.data<float>(), num_elems);
    break;
  }
  case paddle::DataType::FLOAT16: {
    using scalar_t = phi::dtype::float16;
    fastdeploy::segmented_max_reduction<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(scale.data<float>(),
                                     input.data<scalar_t>(), num_elems);
    fastdeploy::scaled_fp8_quant_kernel<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(out.data<fp8_t>(), input.data<scalar_t>(),
                                     scale.data<float>(), num_elems);
    break;
  }
  case paddle::DataType::BFLOAT16: {
    using scalar_t = phi::dtype::bfloat16;
    fastdeploy::segmented_max_reduction<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(scale.data<float>(),
                                     input.data<scalar_t>(), num_elems);
    fastdeploy::scaled_fp8_quant_kernel<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(out.data<fp8_t>(), input.data<scalar_t>(),
                                     scale.data<float>(), num_elems);
    break;
  }
  default:
    PD_THROW("Only input dtypes fp32, fp16 and bf16 are supported.");
  }
}

// Per-token quantization: each token gets its own scale, written to `scales`.
// A scale_ub <= 0 disables the upper bound on the per-token absolute maximum.
void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out,          // [..., d]
                                   paddle::Tensor const &input,  // [..., d]
                                   paddle::Tensor &scales, float scale_ub) {
  PD_CHECK(input.is_contiguous());
  PD_CHECK(out.is_contiguous());
  PD_CHECK(out.dtype() == paddle::DataType::FLOAT8_E4M3FN);
  using fp8_t = phi::dtype::float8_e4m3fn;
  auto rank = input.dims().size();
  int const hidden_size = input.dims()[rank - 1];
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));

  cudaStream_t stream = input.stream();

  switch (input.dtype()) {
  case paddle::DataType::FLOAT32: {
    using scalar_t = float;
    fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
                                     input.data<scalar_t>(), scale_ub,
                                     hidden_size);
    break;
  }
  case paddle::DataType::FLOAT16: {
    using scalar_t = phi::dtype::float16;
    fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
                                     input.data<scalar_t>(), scale_ub,
                                     hidden_size);
    break;
  }
  case paddle::DataType::BFLOAT16: {
    using scalar_t = phi::dtype::bfloat16;
    fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
                                     input.data<scalar_t>(), scale_ub,
                                     hidden_size);
    break;
  }
  default:
    PD_THROW("Only input dtypes fp32, fp16 and bf16 are supported.");
  }
}

PD_BUILD_STATIC_OP(static_scaled_fp8_quant)
    .Inputs({"out", "input", "scale"})
    .Outputs({"out_q"})
    .SetInplaceMap({{"out", "out_q"}})
    .SetKernelFn(PD_KERNEL(StaticScaledFp8Quant));

PD_BUILD_STATIC_OP(dynamic_scaled_fp8_quant)
    .Inputs({"out", "input", "scale"})
    .Outputs({"out_q", "out_scale"})
    .SetInplaceMap({{"out", "out_q"},
                    {"scale", "out_scale"}})
    .SetKernelFn(PD_KERNEL(DynamicScaledFp8Quant));

PD_BUILD_STATIC_OP(dynamic_per_token_scaled_fp8_quant)
    .Inputs({"out", "input", "scale"})
    .Attrs({"scale_ub: float"})
    .Outputs({"out_q"})
    .SetInplaceMap({{"out", "out_q"}})
    .SetKernelFn(PD_KERNEL(DynamicPerTokenScaledFp8Quant));
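
For orientation, a minimal host-side sketch of how the static-scale path might be driven from C++ is shown below. This is an illustration only and not part of the commit: the helper name is hypothetical, and `paddle::empty` / `paddle::full` are assumed to be available from Paddle's custom-op C++ API (the surrounding file's headers).

// Hypothetical helper: quantize `input` to FP8 using a fixed scale of 1.0f.
paddle::Tensor StaticFp8QuantExample(const paddle::Tensor &input) {
  auto out = paddle::empty(input.shape(), paddle::DataType::FLOAT8_E4M3FN,
                           input.place());
  auto scale = paddle::full({1}, 1.0f, paddle::DataType::FLOAT32,
                            input.place());
  StaticScaledFp8Quant(out, input, scale);  // fills `out` with FP8 values
  return out;
}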
custom_ops/gpu_ops/quantization/common.cuh (new file, 159 lines)
@@ -0,0 +1,159 @@
// Adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh

#pragma once

#include "cuda.h"
#include "helper.h"
#include <cmath>
#include <cub/cub.cuh>
#include <cuda_runtime.h>

namespace fastdeploy {

// Vectorization containers
template <typename scalar_t> struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

// Packed 4-element vector of 8-bit quantized values (int8 or FP8 E4M3).
template <typename quant_type_t> struct __align__(4) q8x4_t {
  static_assert(std::is_same_v<quant_type_t, int8_t> ||
                std::is_same_v<quant_type_t, phi::dtype::float8_e4m3fn>);
  quant_type_t x;
  quant_type_t y;
  quant_type_t z;
  quant_type_t w;
};
// Atomic max on a float stored at *addr. For non-negative values the IEEE-754
// bit pattern preserves ordering when reinterpreted as a signed int, so
// atomicMax on int works; for negative values the ordering is reversed, so
// atomicMin on the unsigned reinterpretation is used instead.
__device__ __forceinline__ float atomicMaxFloat(float *addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int *)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int *)addr, __float_as_uint(value)));

  return old;
}
// Convert a single float to fp8_type, either multiplying by an inverted scale
// or dividing by a regular scale, and clamping to the float8_e4m3fn finite
// range [-448, 448].
template <bool is_scale_inverted, typename fp8_type>
__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }

  float r = fmax(-448, fmin(x, 448));
  return static_cast<fp8_type>(r);
}
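// --------------------------------------------------------------------------
// Illustration only (not part of this header): a scalar round-trip sketch of
// the scheme above, with a hypothetical helper name. With a non-inverted
// scale chosen as absmax / 448.f, the largest-magnitude input maps to +/-448
// and dequantizes (q * scale) back to approximately its original value; FP8
// rounding error is ignored here.
// --------------------------------------------------------------------------
__host__ __device__ inline float fp8_round_trip_sketch(float x, float scale) {
  float q = fmaxf(-448.f, fminf(x / scale, 448.f));  // quantize (no FP8 rounding)
  return q * scale;                                  // dequantize
}
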
// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() (i.e. m / 448.f) in *scale. Each thread block
// performs a reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t, typename fp8_type>
__global__ void segmented_max_reduction(float *__restrict__ scale,
                                        const scalar_t *__restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];
  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store the maximum of all values processed by
  // the current thread in cache[threadIdx.x]
  float tmp = 0.0f;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = fmax(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / 448.f);
  }
}
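
// --------------------------------------------------------------------------
// Illustration only (not part of this header): per the contract above, *scale
// must hold a value <= 0.0 before segmented_max_reduction runs. A minimal
// host-side launch helper (hypothetical name) that enforces this on the same
// stream could look like:
// --------------------------------------------------------------------------
template <typename scalar_t, typename fp8_type>
inline void launch_segmented_max_reduction(float *scale, const scalar_t *input,
                                           int64_t num_elems, dim3 grid,
                                           dim3 block, cudaStream_t stream) {
  // Clear *scale so the atomicMaxFloat updates start from 0.0f.
  cudaMemsetAsync(scale, 0, sizeof(float), stream);
  segmented_max_reduction<scalar_t, fp8_type>
      <<<grid, block, 0, stream>>>(scale, input, num_elems);
}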

template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const *__restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const *vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const *>(input);

  int64_t const num_vec_elems = num_elems >> 2;
  float absmax_val = 0.0f;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    absmax_val = max(absmax_val, fabs(static_cast<float>(in_vec.x)));
    absmax_val = max(absmax_val, fabs(static_cast<float>(in_vec.y)));
    absmax_val = max(absmax_val, fabs(static_cast<float>(in_vec.z)));
    absmax_val = max(absmax_val, fabs(static_cast<float>(in_vec.w)));
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    absmax_val = max(absmax_val, fabs(static_cast<float>(input[i])));
  }

  return absmax_val;
}
template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
__device__ void scaled_fp8_conversion_vec(fp8_type *__restrict__ out,
                                          scalar_t const *__restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  using float8x4_t = q8x4_t<fp8_type>;
  // Vectorized input/output to better utilize memory bandwidth.
  auto const *vectorized_in = reinterpret_cast<vec4_t<scalar_t> const *>(input);
  auto *vectorized_out = reinterpret_cast<float8x4_t *>(out);

  int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    float8x4_t out_vec;

    out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.x), scale);
    out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.y), scale);
    out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.z), scale);
    out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.w), scale);
    vectorized_out[i] = out_vec;
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(input[i]), scale);
  }
}

} // namespace fastdeploy