Revert "[Feature] add ue8m0 for per_token_quant_fp8 (#5563)" (#5611)

This reverts commit 73e1d6aa90.
This commit is contained in:
Yuanle Liu
2025-12-17 13:59:06 +08:00
committed by GitHub
parent 21fa2baa51
commit cdc0004894
4 changed files with 56 additions and 193 deletions

View File

@@ -284,16 +284,13 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const int token_nums_this_rank_padded);
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
const int block_size,
const bool use_ue8m0);
const int block_size);
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
const int block_size,
const bool use_ue8m0);
const int block_size);
std::vector<paddle::Tensor> MaskedPerTokenQuant(
paddle::Tensor& input,
paddle::Tensor& recv_expert_count,
const int block_size,
const bool use_ue8m0);
const int block_size);
std::vector<paddle::Tensor> EPMoeExpertCombine(
const paddle::Tensor& ffn_out,
@@ -1235,14 +1232,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&PerTokenQuant,
py::arg("input"),
py::arg("block_size"),
py::arg("use_ue8m0") = false,
"per token per block quant");
m.def("per_token_quant_padding",
&PerTokenQuantPadding,
py::arg("input"),
py::arg("block_size"),
py::arg("use_ue8m0") = false,
"per token per block quant and padding transpose scale");
m.def("masked_per_token_quant",
@@ -1250,7 +1245,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("input"),
py::arg("recv_expert_count"),
py::arg("block_size"),
py::arg("use_ue8m0") = false,
"per token per block quant");
#ifdef ENABLE_MACHETE

View File

@@ -16,16 +16,6 @@
constexpr float epsilon = 1e-10;
__device__ __forceinline__ float ceil_to_ue8m0(float s) {
int exp;
frexpf(s, &exp);
float pow2 = ldexpf(1.0f, exp - 1);
if (pow2 < s) {
pow2 = ldexpf(1.0f, exp);
}
return pow2;
}
template <typename T>
__global__ void quant_per_token_per_block(
const T *input,
@@ -34,8 +24,7 @@ __global__ void quant_per_token_per_block(
const int token_num,
const int hidden_size,
const int hidden_size_scale,
const bool use_finegrained_range,
const bool use_ue8m0) {
const bool use_finegrained_range) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
const int warp_id = tid / 32;
@@ -94,14 +83,11 @@ __global__ void quant_per_token_per_block(
}
float scale_to_store = max_value_thread / MAX_VALUE;
if (use_ue8m0) {
scale_to_store = ceil_to_ue8m0(scale_to_store);
}
// quant
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] / scale_to_store);
load_vec_float[vid] * MAX_VALUE / max_value_thread);
}
// store
if (is_valid_data)
@@ -116,8 +102,7 @@ __global__ void quant_per_token_per_block(
}
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
const int block_size,
const bool use_ue8m0) {
const int block_size) {
auto input_dim = input.dims();
const int token_num = input_dim[0];
const int hidden_size = input_dim[1];
@@ -147,8 +132,7 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range,
use_ue8m0);
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
@@ -158,8 +142,7 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range,
use_ue8m0);
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
@@ -176,8 +159,7 @@ __global__ void quant_per_token_per_block_padding(
const int padded_token_num,
const int hidden_size,
const int hidden_size_scale,
const bool use_finegrained_range,
const bool use_ue8m0) {
const bool use_finegrained_range) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
const int warp_id = tid / 32;
@@ -227,14 +209,11 @@ __global__ void quant_per_token_per_block_padding(
}
float scale_to_store = max_value_thread / MAX_VALUE;
if (use_ue8m0) {
scale_to_store = ceil_to_ue8m0(scale_to_store);
}
// quant
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] / scale_to_store);
load_vec_float[vid] * MAX_VALUE / max_value_thread);
}
// store
Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>(
@@ -247,8 +226,7 @@ __global__ void quant_per_token_per_block_padding(
}
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
const int block_size,
const bool use_ue8m0) {
const int block_size) {
using ScaleDtype = float;
auto input_dim = input.dims();
@@ -291,8 +269,7 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range,
use_ue8m0);
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
@@ -303,8 +280,7 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range,
use_ue8m0);
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
@@ -344,8 +320,7 @@ __global__ void masked_quant_per_token_per_block(
const int hidden_size,
const int hidden_size_scale,
const int num_max_tokens_per_expert,
const bool use_finegrained_range,
const bool use_ue8m0) {
const bool use_finegrained_range) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
const int warp_id = tid / 32;
@@ -407,14 +382,11 @@ __global__ void masked_quant_per_token_per_block(
}
float scale_to_store = max_value_thread / MAX_VALUE;
if (use_ue8m0) {
scale_to_store = ceil_to_ue8m0(scale_to_store);
}
// quant
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] / scale_to_store);
load_vec_float[vid] * MAX_VALUE / max_value_thread);
}
// store
Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>(
@@ -429,8 +401,7 @@ __global__ void masked_quant_per_token_per_block(
std::vector<paddle::Tensor> MaskedPerTokenQuant(
paddle::Tensor &input,
paddle::Tensor &recv_expert_count,
const int block_size,
const bool use_ue8m0) {
const int block_size) {
auto input_dim = input.dims();
const int num_local_expert = input_dim[0];
const int num_max_tokens_per_expert = input_dim[1];
@@ -468,8 +439,7 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(
hidden_size,
hidden_size_scale,
num_max_tokens_per_expert,
use_finegrained_range,
use_ue8m0);
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
masked_quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
@@ -481,8 +451,7 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(
hidden_size,
hidden_size_scale,
num_max_tokens_per_expert,
use_finegrained_range,
use_ue8m0);
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
@@ -493,13 +462,13 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(
PD_BUILD_STATIC_OP(per_token_quant)
.Inputs({"input"})
.Outputs({"output", "output_scale"})
.Attrs({"block_size: int", "use_ue8m0: bool"})
.Attrs({"block_size: int"})
.SetKernelFn(PD_KERNEL(PerTokenQuant));
PD_BUILD_STATIC_OP(per_token_quant_padding)
.Inputs({"input"})
.Outputs({"output", "output_scale"})
.Attrs({"block_size: int", "use_ue8m0: bool"})
.Attrs({"block_size: int"})
.SetKernelFn(PD_KERNEL(PerTokenQuantPadding))
.SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantPaddingInferDtype));
@@ -507,5 +476,5 @@ PD_BUILD_STATIC_OP(per_token_quant_padding)
PD_BUILD_STATIC_OP(masked_per_token_quant)
.Inputs({"input", "recv_expert_count"})
.Outputs({"output", "output_scale"})
.Attrs({"block_size: int", "use_ue8m0: bool"})
.Attrs({"block_size: int"})
.SetKernelFn(PD_KERNEL(MaskedPerTokenQuant));