Revert "[Feature] add ue8m0 for per_token_quant_fp8 (#5563)" (#5611)

This reverts commit 73e1d6aa90.
This commit is contained in:
Yuanle Liu
2025-12-17 13:59:06 +08:00
committed by GitHub
parent 21fa2baa51
commit cdc0004894
4 changed files with 56 additions and 193 deletions

View File

@@ -284,16 +284,13 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const int token_nums_this_rank_padded);
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
const int block_size,
const bool use_ue8m0);
const int block_size);
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
const int block_size,
const bool use_ue8m0);
const int block_size);
std::vector<paddle::Tensor> MaskedPerTokenQuant(
paddle::Tensor& input,
paddle::Tensor& recv_expert_count,
const int block_size,
const bool use_ue8m0);
const int block_size);
std::vector<paddle::Tensor> EPMoeExpertCombine(
const paddle::Tensor& ffn_out,
@@ -1235,14 +1232,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&PerTokenQuant,
py::arg("input"),
py::arg("block_size"),
py::arg("use_ue8m0") = false,
"per token per block quant");
m.def("per_token_quant_padding",
&PerTokenQuantPadding,
py::arg("input"),
py::arg("block_size"),
py::arg("use_ue8m0") = false,
"per token per block quant and padding transpose scale");
m.def("masked_per_token_quant",
@@ -1250,7 +1245,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("input"),
py::arg("recv_expert_count"),
py::arg("block_size"),
py::arg("use_ue8m0") = false,
"per token per block quant");
#ifdef ENABLE_MACHETE

View File

@@ -16,16 +16,6 @@
constexpr float epsilon = 1e-10;
__device__ __forceinline__ float ceil_to_ue8m0(float s) {
int exp;
frexpf(s, &exp);
float pow2 = ldexpf(1.0f, exp - 1);
if (pow2 < s) {
pow2 = ldexpf(1.0f, exp);
}
return pow2;
}
template <typename T>
__global__ void quant_per_token_per_block(
const T *input,
@@ -34,8 +24,7 @@ __global__ void quant_per_token_per_block(
const int token_num,
const int hidden_size,
const int hidden_size_scale,
const bool use_finegrained_range,
const bool use_ue8m0) {
const bool use_finegrained_range) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
const int warp_id = tid / 32;
@@ -94,14 +83,11 @@ __global__ void quant_per_token_per_block(
}
float scale_to_store = max_value_thread / MAX_VALUE;
if (use_ue8m0) {
scale_to_store = ceil_to_ue8m0(scale_to_store);
}
// quant
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] / scale_to_store);
load_vec_float[vid] * MAX_VALUE / max_value_thread);
}
// store
if (is_valid_data)
@@ -116,8 +102,7 @@ __global__ void quant_per_token_per_block(
}
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
const int block_size,
const bool use_ue8m0) {
const int block_size) {
auto input_dim = input.dims();
const int token_num = input_dim[0];
const int hidden_size = input_dim[1];
@@ -147,8 +132,7 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range,
use_ue8m0);
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
@@ -158,8 +142,7 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range,
use_ue8m0);
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
@@ -176,8 +159,7 @@ __global__ void quant_per_token_per_block_padding(
const int padded_token_num,
const int hidden_size,
const int hidden_size_scale,
const bool use_finegrained_range,
const bool use_ue8m0) {
const bool use_finegrained_range) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
const int warp_id = tid / 32;
@@ -227,14 +209,11 @@ __global__ void quant_per_token_per_block_padding(
}
float scale_to_store = max_value_thread / MAX_VALUE;
if (use_ue8m0) {
scale_to_store = ceil_to_ue8m0(scale_to_store);
}
// quant
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] / scale_to_store);
load_vec_float[vid] * MAX_VALUE / max_value_thread);
}
// store
Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>(
@@ -247,8 +226,7 @@ __global__ void quant_per_token_per_block_padding(
}
std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
const int block_size,
const bool use_ue8m0) {
const int block_size) {
using ScaleDtype = float;
auto input_dim = input.dims();
@@ -291,8 +269,7 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range,
use_ue8m0);
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
@@ -303,8 +280,7 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
padded_token_num,
hidden_size,
hidden_size_scale,
use_finegrained_range,
use_ue8m0);
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
@@ -344,8 +320,7 @@ __global__ void masked_quant_per_token_per_block(
const int hidden_size,
const int hidden_size_scale,
const int num_max_tokens_per_expert,
const bool use_finegrained_range,
const bool use_ue8m0) {
const bool use_finegrained_range) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
const int warp_id = tid / 32;
@@ -407,14 +382,11 @@ __global__ void masked_quant_per_token_per_block(
}
float scale_to_store = max_value_thread / MAX_VALUE;
if (use_ue8m0) {
scale_to_store = ceil_to_ue8m0(scale_to_store);
}
// quant
#pragma unroll
for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
load_vec_float[vid] / scale_to_store);
load_vec_float[vid] * MAX_VALUE / max_value_thread);
}
// store
Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>(
@@ -429,8 +401,7 @@ __global__ void masked_quant_per_token_per_block(
std::vector<paddle::Tensor> MaskedPerTokenQuant(
paddle::Tensor &input,
paddle::Tensor &recv_expert_count,
const int block_size,
const bool use_ue8m0) {
const int block_size) {
auto input_dim = input.dims();
const int num_local_expert = input_dim[0];
const int num_max_tokens_per_expert = input_dim[1];
@@ -468,8 +439,7 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(
hidden_size,
hidden_size_scale,
num_max_tokens_per_expert,
use_finegrained_range,
use_ue8m0);
use_finegrained_range);
break;
case paddle::DataType::FLOAT16:
masked_quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
@@ -481,8 +451,7 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(
hidden_size,
hidden_size_scale,
num_max_tokens_per_expert,
use_finegrained_range,
use_ue8m0);
use_finegrained_range);
break;
default:
PD_THROW("Unsupported data type for PerTokenQuant");
@@ -493,13 +462,13 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(
PD_BUILD_STATIC_OP(per_token_quant)
.Inputs({"input"})
.Outputs({"output", "output_scale"})
.Attrs({"block_size: int", "use_ue8m0: bool"})
.Attrs({"block_size: int"})
.SetKernelFn(PD_KERNEL(PerTokenQuant));
PD_BUILD_STATIC_OP(per_token_quant_padding)
.Inputs({"input"})
.Outputs({"output", "output_scale"})
.Attrs({"block_size: int", "use_ue8m0: bool"})
.Attrs({"block_size: int"})
.SetKernelFn(PD_KERNEL(PerTokenQuantPadding))
.SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantPaddingInferDtype));
@@ -507,5 +476,5 @@ PD_BUILD_STATIC_OP(per_token_quant_padding)
PD_BUILD_STATIC_OP(masked_per_token_quant)
.Inputs({"input", "recv_expert_count"})
.Outputs({"output", "output_scale"})
.Attrs({"block_size: int", "use_ue8m0: bool"})
.Attrs({"block_size: int"})
.SetKernelFn(PD_KERNEL(MaskedPerTokenQuant));

View File

@@ -23,20 +23,7 @@ import paddle
from fastdeploy.model_executor.ops.gpu import masked_per_token_quant
def ceil_to_ue8m0_paddle(x: paddle.Tensor):
"""
x > 0
return 2 ^ ceil(log2(x))
"""
# log2(x)
log2_x = paddle.log(x) / paddle.log(paddle.to_tensor(2.0, dtype=x.dtype))
# ceil
ceil_log2_x = paddle.ceil(log2_x)
# 2^k
return paddle.pow(paddle.to_tensor(2.0, dtype=x.dtype), ceil_log2_x)
def masked_per_token_quant_ref(input_tensor, recv_expert_count, block_size, use_ue8m0):
def masked_per_token_quant_ref(input_tensor, recv_expert_count, block_size):
"""
Paddle API implementation of masked_per_token_quant
@@ -97,9 +84,6 @@ def masked_per_token_quant_ref(input_tensor, recv_expert_count, block_size, use_
# Calculate scale
scale = max_abs_val / MAX_VALUE
if use_ue8m0:
scale = ceil_to_ue8m0_paddle(scale)
# Quantize
quanted_value = reshaped_input / scale
@@ -136,11 +120,10 @@ class TestMaskedPerTokenQuant(unittest.TestCase):
[self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype
)
self.recv_expert_count = paddle.to_tensor([3, 2], dtype="int32")
self.use_ue8m0 = True
# Get reference results from paddle implementation
self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref(
self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0
self.input_tensor, self.recv_expert_count, self.block_size
)
def _mask_invalid_tokens(self, quanted_x, quanted_scale, recv_expert_count):
@@ -166,7 +149,7 @@ class TestMaskedPerTokenQuant(unittest.TestCase):
def test_masked_per_token_quant_basic(self):
"""Test basic functionality against CUDA kernel"""
quanted_x_cuda, quanted_scale_cuda = masked_per_token_quant(
self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0
self.input_tensor, self.recv_expert_count, self.block_size
)
quanted_x_cuda_masked, quanted_scale_cuda_masked = self._mask_invalid_tokens(
@@ -194,28 +177,6 @@ class TestMaskedPerTokenQuant(unittest.TestCase):
self.assertLess(diff_val, 0.01, msg="Quantized values should be close")
class TestMaskedPerTokenQuantWithUe8m0Case1(TestMaskedPerTokenQuant):
"""Test with float16 input"""
def setUp(self) -> None:
paddle.seed(2024)
self.num_local_expert = 3
self.num_max_tokens_per_expert = 6
self.hidden_size = 512
self.block_size = 128
self.dtype = paddle.float16
self.use_ue8m0 = True
self.input_tensor = paddle.randn(
[self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype
)
self.recv_expert_count = paddle.to_tensor([4, 2, 5], dtype="int32")
self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref(
self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0
)
class TestMaskedPerTokenQuantCase1(TestMaskedPerTokenQuant):
"""Test with float16 input"""
@@ -226,7 +187,6 @@ class TestMaskedPerTokenQuantCase1(TestMaskedPerTokenQuant):
self.hidden_size = 512
self.block_size = 128
self.dtype = paddle.float16
self.use_ue8m0 = False
self.input_tensor = paddle.randn(
[self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype
@@ -234,29 +194,7 @@ class TestMaskedPerTokenQuantCase1(TestMaskedPerTokenQuant):
self.recv_expert_count = paddle.to_tensor([4, 2, 5], dtype="int32")
self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref(
self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0
)
class TestMaskedPerTokenQuantWithUe8m0Case2(TestMaskedPerTokenQuant):
"""Test with different hidden size"""
def setUp(self) -> None:
paddle.seed(2024)
self.num_local_expert = 4
self.num_max_tokens_per_expert = 8
self.hidden_size = 384 # 3 * 128
self.block_size = 128
self.dtype = paddle.bfloat16
self.use_ue8m0 = True
self.input_tensor = paddle.randn(
[self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype
)
self.recv_expert_count = paddle.to_tensor([6, 3, 7, 1], dtype="int32")
self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref(
self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0
self.input_tensor, self.recv_expert_count, self.block_size
)
@@ -270,7 +208,6 @@ class TestMaskedPerTokenQuantCase2(TestMaskedPerTokenQuant):
self.hidden_size = 384 # 3 * 128
self.block_size = 128
self.dtype = paddle.bfloat16
self.use_ue8m0 = False
self.input_tensor = paddle.randn(
[self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype
@@ -278,29 +215,7 @@ class TestMaskedPerTokenQuantCase2(TestMaskedPerTokenQuant):
self.recv_expert_count = paddle.to_tensor([6, 3, 7, 1], dtype="int32")
self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref(
self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0
)
class TestMaskedPerTokenQuantWithUe8m0Case3(TestMaskedPerTokenQuant):
"""Test with all experts having max tokens"""
def setUp(self) -> None:
paddle.seed(2024)
self.num_local_expert = 2
self.num_max_tokens_per_expert = 4
self.hidden_size = 256
self.block_size = 128
self.dtype = paddle.bfloat16
self.use_ue8m0 = True
self.input_tensor = paddle.randn(
[self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype
)
# All experts use all tokens
self.recv_expert_count = paddle.to_tensor([4, 4], dtype="int32")
self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref(
self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0
self.input_tensor, self.recv_expert_count, self.block_size
)
@@ -314,7 +229,7 @@ class TestMaskedPerTokenQuantCase3(TestMaskedPerTokenQuant):
self.hidden_size = 256
self.block_size = 128
self.dtype = paddle.bfloat16
self.use_ue8m0 = True
self.input_tensor = paddle.randn(
[self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype
)
@@ -322,7 +237,7 @@ class TestMaskedPerTokenQuantCase3(TestMaskedPerTokenQuant):
self.recv_expert_count = paddle.to_tensor([4, 4], dtype="int32")
self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref(
self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0
self.input_tensor, self.recv_expert_count, self.block_size
)
@@ -335,7 +250,7 @@ class TestMaskedPerTokenQuantEdgeCases(unittest.TestCase):
input_tensor = paddle.randn([2, 4, 256], dtype="bfloat16")
recv_expert_count = paddle.to_tensor([0, 2], dtype="int32") # First expert has no tokens
quanted_x_ref, quanted_scale_ref = masked_per_token_quant_ref(input_tensor, recv_expert_count, 128, False)
quanted_x_ref, quanted_scale_ref = masked_per_token_quant_ref(input_tensor, recv_expert_count, 128)
# First expert should be all zeros - convert to float32 for comparison
expert_0_quanted = quanted_x_ref[0].astype("float32")

View File

@@ -25,20 +25,7 @@ from fastdeploy.model_executor.ops.gpu import per_token_quant, per_token_quant_p
paddle.seed(2024)
def ceil_to_ue8m0_paddle(x: paddle.Tensor):
"""
x > 0
return 2 ^ ceil(log2(x))
"""
# log2(x)
log2_x = paddle.log(x) / paddle.log(paddle.to_tensor(2.0, dtype=x.dtype))
# ceil
ceil_log2_x = paddle.ceil(log2_x)
# 2^k
return paddle.pow(paddle.to_tensor(2.0, dtype=x.dtype), ceil_log2_x)
def per_token_quant_paddle(input_tensor, block_size, use_ue8m0: bool = False):
def per_token_quant_paddle(input_tensor, block_size):
MAX_VALUE = 448.0
epsilon = 1e-10
@@ -46,6 +33,7 @@ def per_token_quant_paddle(input_tensor, block_size, use_ue8m0: bool = False):
token_num = input_shape[0]
hidden_size = input_shape[1]
# According to https://github.com/PaddlePaddle/FastDeploy/pull/3659
padding_size = (block_size - hidden_size % block_size) % block_size
padded_input = input_tensor
@@ -60,8 +48,6 @@ def per_token_quant_paddle(input_tensor, block_size, use_ue8m0: bool = False):
max_abs_val = paddle.max(paddle.abs(reshaped_input), axis=-1, keepdim=True)
max_abs_val = paddle.clip(max_abs_val, min=epsilon)
scale = max_abs_val / MAX_VALUE
if use_ue8m0:
scale = ceil_to_ue8m0_paddle(scale)
quanted_value = reshaped_input / scale
@@ -75,8 +61,8 @@ def per_token_quant_paddle(input_tensor, block_size, use_ue8m0: bool = False):
return quanted_x, quanted_scale
def per_token_quant_padding_paddle(input_tensor, block_size, dtype, use_ue8m0):
quanted_x, intermediate_scale = per_token_quant_paddle(input_tensor, block_size, use_ue8m0)
def per_token_quant_padding_paddle(input_tensor, block_size, dtype):
quanted_x, intermediate_scale = per_token_quant_paddle(input_tensor, block_size)
token_num = input_tensor.shape[0]
tma_alignment_elements = 4
@@ -102,16 +88,16 @@ class TestPerTokenQuant(unittest.TestCase):
self.input_tensor = self.get_input(shape=[self.token_num, self.hidden_size], dtype=self.dtype)
def test_per_token_quant(self):
for use_ue8m0 in [False, True]:
paddle_output, paddle_output_scale = per_token_quant_paddle(self.input_tensor, self.block_size, use_ue8m0)
output, output_scale = per_token_quant(self.input_tensor, self.block_size, use_ue8m0)
paddle_output, paddle_output_scale = per_token_quant_paddle(self.input_tensor, self.block_size)
output, output_scale = per_token_quant(self.input_tensor, self.block_size)
np.testing.assert_allclose(paddle_output_scale.numpy(), output_scale.numpy(), rtol=1e-6)
np.testing.assert_allclose(paddle_output_scale.numpy(), output_scale.numpy(), rtol=1e-6)
output_rel_diff = paddle.mean(
paddle.abs(output.to(paddle.float32) - paddle_output.to(paddle.float32))
) / paddle.mean(paddle.abs(paddle_output.to(paddle.float32)))
assert output_rel_diff < 0.001
output_rel_diff = paddle.mean(
paddle.abs(output.to(paddle.float32) - paddle_output.to(paddle.float32))
) / paddle.mean(paddle.abs(paddle_output.to(paddle.float32)))
assert output_rel_diff < 0.001
class TestPerTokenQuantCase1(TestPerTokenQuant):
@@ -150,25 +136,24 @@ class TestPerTokenQuantPadding(TestPerTokenQuant):
self.input_tensor = self.get_input(shape=[self.token_num, self.hidden_size], dtype=self.dtype)
def test_per_token_quant_padding(self):
for use_ue8m0 in [False, True]:
paddle_output, paddle_output_scale = per_token_quant_padding_paddle(
self.input_tensor, self.block_size, self.dtype, use_ue8m0
)
output, output_scale = per_token_quant_padding(self.input_tensor, self.block_size, use_ue8m0)
paddle_output, paddle_output_scale = per_token_quant_padding_paddle(
self.input_tensor, self.block_size, self.dtype
)
output, output_scale = per_token_quant_padding(self.input_tensor, self.block_size)
self.assertEqual(paddle_output_scale.shape, output_scale.shape)
np.testing.assert_allclose(
paddle_output_scale[0 : self.token_num].numpy(),
output_scale[0 : self.token_num].numpy(),
rtol=1e-5,
atol=1e-5,
)
self.assertEqual(paddle_output_scale.shape, output_scale.shape)
np.testing.assert_allclose(
paddle_output_scale[0 : self.token_num].numpy(),
output_scale[0 : self.token_num].numpy(),
rtol=1e-5,
atol=1e-5,
)
output_rel_diff = paddle.mean(
paddle.abs(output.to(paddle.float32) - paddle_output.to(paddle.float32))
) / paddle.mean(paddle.abs(paddle_output.to(paddle.float32)) + 1e-9)
output_rel_diff = paddle.mean(
paddle.abs(output.to(paddle.float32) - paddle_output.to(paddle.float32))
) / paddle.mean(paddle.abs(paddle_output.to(paddle.float32)) + 1e-9)
assert output_rel_diff < 0.001
assert output_rel_diff < 0.001
class TestPerTokenQuantPaddingCase1(TestPerTokenQuantPadding):