Add a switch to apply fine-grained per-token FP8 quantization (#3192)

Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
This commit is contained in the history of: RichardWooSJTU
Authored 2025-08-05 10:54:03 +08:00, committed by GitHub
parent 88596c0c63
commit e39159f3bd

View File

@@ -22,7 +22,8 @@ __global__ void quant_per_token_per_block(const T *input,
float *quanted_scale,
const int token_num,
const int hidden_size,
const int hidden_size_scale) {
const int hidden_size_scale,
const bool use_finegrained_range) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
const int warp_id = tid / 32;
@@ -58,6 +59,11 @@ __global__ void quant_per_token_per_block(const T *input,
// broadcast max_value
max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
max_value_thread = max(max_value_thread, epsilon);
if (use_finegrained_range) {
max_value_thread *= 7.0f;
}
float scale_to_store = max_value_thread / MAX_VALUE;
// quant
#pragma unroll
@@ -89,6 +95,13 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
input.place());
const int gridx = min(132 * 8, token_num);
const int blockx = min(1024, hidden_size / 128 * 32);
bool use_finegrained_range = false;
char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
if (env_var) {
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
}
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
@@ -97,7 +110,8 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
quanted_scale.data<float>(),
token_num,
hidden_size,
hidden_size_scale
hidden_size_scale,
use_finegrained_range
);
break;
case paddle::DataType::FLOAT16:
@@ -107,7 +121,8 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
quanted_scale.data<float>(),
token_num,
hidden_size,
hidden_size_scale
hidden_size_scale,
use_finegrained_range
);
break;
default:
@@ -124,7 +139,8 @@ __global__ void quant_per_token_per_block_padding(const T *input,
const int token_num,
const int padded_token_num,
const int hidden_size,
const int hidden_size_scale) {
const int hidden_size_scale,
const bool use_finegrained_range) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
const int warp_id = tid / 32;
@@ -160,6 +176,11 @@ __global__ void quant_per_token_per_block_padding(const T *input,
// broadcast max_value
max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
max_value_thread = max(max_value_thread, epsilon);
if (use_finegrained_range) {
max_value_thread *= 7.0f;
}
float scale_to_store = max_value_thread / MAX_VALUE;
// quant
#pragma unroll
@@ -198,6 +219,13 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
input.place());
const int gridx = min(132 * 8, token_num);
const int blockx = min(1024, hidden_size / 128 * 32);
bool use_finegrained_range = false;
char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
if (env_var) {
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
}
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
@@ -207,7 +235,8 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
token_num,
padded_token_num,
hidden_size,
hidden_size_scale
hidden_size_scale,
use_finegrained_range
);
break;
case paddle::DataType::FLOAT16:
@@ -218,7 +247,8 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
token_num,
padded_token_num,
hidden_size,
hidden_size_scale
hidden_size_scale,
use_finegrained_range
);
break;
default:
@@ -236,7 +266,8 @@ __global__ void masked_quant_per_token_per_block(const T *input,
const int token_num,
const int hidden_size,
const int hidden_size_scale,
const int num_max_tokens_per_expert) {
const int num_max_tokens_per_expert,
const bool use_finegrained_range) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
const int warp_id = tid / 32;
@@ -281,6 +312,11 @@ __global__ void masked_quant_per_token_per_block(const T *input,
// broadcast max_value
max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
max_value_thread = max(max_value_thread, epsilon);
if (use_finegrained_range) {
max_value_thread *= 7.0f;
}
float scale_to_store = max_value_thread / MAX_VALUE;
// quant
#pragma unroll
@@ -317,6 +353,12 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(paddle::Tensor& input,
const int gridx = min(132 * 2, token_num);
const int blockx = min(1024, hidden_size / 128 * 32);
bool use_finegrained_range = false;
char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
if (env_var) {
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
}
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
masked_quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
@@ -327,7 +369,8 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(paddle::Tensor& input,
token_num,
hidden_size,
hidden_size_scale,
num_max_tokens_per_expert
num_max_tokens_per_expert,
use_finegrained_range
);
break;
case paddle::DataType::FLOAT16:
@@ -339,7 +382,8 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(paddle::Tensor& input,
token_num,
hidden_size,
hidden_size_scale,
num_max_tokens_per_expert
num_max_tokens_per_expert,
use_finegrained_range
);
break;
default: