diff --git a/custom_ops/gpu_ops/per_token_quant_fp8.cu b/custom_ops/gpu_ops/per_token_quant_fp8.cu
index 9a16d4d36..3199b2be9 100644
--- a/custom_ops/gpu_ops/per_token_quant_fp8.cu
+++ b/custom_ops/gpu_ops/per_token_quant_fp8.cu
@@ -22,7 +22,8 @@ __global__ void quant_per_token_per_block(const T *input,
                                           float *quanted_scale,
                                           const int token_num,
                                           const int hidden_size,
-                                          const int hidden_size_scale) {
+                                          const int hidden_size_scale,
+                                          const bool use_finegrained_range) {
   const int bid = blockIdx.x;
   const int tid = threadIdx.x;
   const int warp_id = tid / 32;
@@ -58,6 +59,11 @@ __global__ void quant_per_token_per_block(const T *input,
   // broadcast max_value
   max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
   max_value_thread = max(max_value_thread, epsilon);
+
+  if (use_finegrained_range) {
+    max_value_thread *= 7.0f;
+  }
+
   float scale_to_store = max_value_thread / MAX_VALUE;
   // quant
 #pragma unroll
@@ -89,6 +95,13 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
                                  input.place());
   const int gridx = min(132 * 8, token_num);
   const int blockx = min(1024, hidden_size / 128 * 32);
+
+  bool use_finegrained_range = false;
+  char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
+  if (env_var) {
+    use_finegrained_range = static_cast<bool>(std::stoi(env_var));
+  }
+
   switch (input.dtype()) {
     case paddle::DataType::BFLOAT16:
       quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
@@ -97,7 +110,8 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
           quanted_scale.data<float>(),
           token_num,
           hidden_size,
-          hidden_size_scale
+          hidden_size_scale,
+          use_finegrained_range
       );
       break;
     case paddle::DataType::FLOAT16:
@@ -107,7 +121,8 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
           quanted_scale.data<float>(),
           token_num,
           hidden_size,
-          hidden_size_scale
+          hidden_size_scale,
+          use_finegrained_range
       );
       break;
     default:
@@ -124,7 +139,8 @@ __global__ void quant_per_token_per_block_padding(const T *input,
                                                   const int token_num,
                                                   const int padded_token_num,
                                                   const int hidden_size,
-                                                  const int hidden_size_scale) {
+                                                  const int hidden_size_scale,
+                                                  const bool use_finegrained_range) {
   const int bid = blockIdx.x;
   const int tid = threadIdx.x;
   const int warp_id = tid / 32;
@@ -160,6 +176,11 @@ __global__ void quant_per_token_per_block_padding(const T *input,
   // broadcast max_value
   max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
   max_value_thread = max(max_value_thread, epsilon);
+
+  if (use_finegrained_range) {
+    max_value_thread *= 7.0f;
+  }
+
   float scale_to_store = max_value_thread / MAX_VALUE;
   // quant
 #pragma unroll
@@ -198,6 +219,13 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
                                  input.place());
   const int gridx = min(132 * 8, token_num);
   const int blockx = min(1024, hidden_size / 128 * 32);
+
+  bool use_finegrained_range = false;
+  char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
+  if (env_var) {
+    use_finegrained_range = static_cast<bool>(std::stoi(env_var));
+  }
+
   switch (input.dtype()) {
     case paddle::DataType::BFLOAT16:
       quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
@@ -207,7 +235,8 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
           token_num,
           padded_token_num,
           hidden_size,
-          hidden_size_scale
+          hidden_size_scale,
+          use_finegrained_range
       );
       break;
     case paddle::DataType::FLOAT16:
@@ -218,7 +247,8 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
           token_num,
           padded_token_num,
           hidden_size,
-          hidden_size_scale
+          hidden_size_scale,
+          use_finegrained_range
       );
       break;
     default:
@@ -236,7 +266,8 @@ __global__ void masked_quant_per_token_per_block(const T *input,
                                                  const int token_num,
                                                  const int hidden_size,
                                                  const int hidden_size_scale,
-                                                 const int num_max_tokens_per_expert) {
+                                                 const int num_max_tokens_per_expert,
+                                                 const bool use_finegrained_range) {
   const int bid = blockIdx.x;
   const int tid = threadIdx.x;
   const int warp_id = tid / 32;
@@ -281,6 +312,11 @@ __global__ void masked_quant_per_token_per_block(const T *input,
   // broadcast max_value
   max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
   max_value_thread = max(max_value_thread, epsilon);
+
+  if (use_finegrained_range) {
+    max_value_thread *= 7.0f;
+  }
+
   float scale_to_store = max_value_thread / MAX_VALUE;
   // quant
 #pragma unroll
@@ -317,6 +353,12 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(paddle::Tensor& input,
   const int gridx = min(132 * 2, token_num);
   const int blockx = min(1024, hidden_size / 128 * 32);
 
+  bool use_finegrained_range = false;
+  char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
+  if (env_var) {
+    use_finegrained_range = static_cast<bool>(std::stoi(env_var));
+  }
+
   switch (input.dtype()) {
     case paddle::DataType::BFLOAT16:
       masked_quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
@@ -327,7 +369,8 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(paddle::Tensor& input,
           token_num,
           hidden_size,
           hidden_size_scale,
-          num_max_tokens_per_expert
+          num_max_tokens_per_expert,
+          use_finegrained_range
       );
       break;
     case paddle::DataType::FLOAT16:
@@ -339,7 +382,8 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(paddle::Tensor& input,
           token_num,
           hidden_size,
           hidden_size_scale,
-          num_max_tokens_per_expert
+          num_max_tokens_per_expert,
+          use_finegrained_range
      );
       break;
     default:
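
Note on the new flag: all three kernel variants apply the same rescaling. When the PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE environment variable is set to a nonzero integer, the per-block absmax is multiplied by 7.0f before the scale is computed. Assuming MAX_VALUE is the FP8 E4M3 maximum of 448.0f (the constant is defined outside this diff), the stored scale becomes 7x larger, so quantized values span only [-64, 64] (448 / 7) rather than the full [-448, 448] E4M3 range. A minimal host-side sketch of that scale math, with the 448.0f assumption made explicit:

// Standalone sketch of the rescaling above; MAX_VALUE = 448.0f is an
// assumption (the E4M3 max), as it is defined outside this diff.
#include <cstdio>

int main() {
  const float MAX_VALUE = 448.0f;           // assumed E4M3 max
  float max_value_thread = 3.2f;            // example per-block absmax
  const bool use_finegrained_range = true;  // what the env var toggles

  if (use_finegrained_range) {
    max_value_thread *= 7.0f;               // same rescaling as the kernels
  }
  const float scale_to_store = max_value_thread / MAX_VALUE;

  // Largest quantized magnitude = absmax / scale: 448 without the flag,
  // 448 / 7 = 64 with it, i.e. well below E4M3 saturation.
  printf("scale = %g, max |q| = %g\n", scale_to_store, 3.2f / scale_to_store);
  return 0;
}

Because the flag is re-read via getenv() on every call rather than cached, it can be toggled per process without recompiling, e.g. by exporting PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE=1 before launch.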