diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index f78a1a4ce..d29e55286 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -284,13 +284,16 @@ std::vector EPMoeExpertDispatchFP8( const int token_nums_this_rank_padded); std::vector PerTokenQuant(paddle::Tensor& input, - const int block_size); + const int block_size, + const bool use_ue8m0); std::vector PerTokenQuantPadding(paddle::Tensor& input, - const int block_size); + const int block_size, + const bool use_ue8m0); std::vector MaskedPerTokenQuant( paddle::Tensor& input, paddle::Tensor& recv_expert_count, - const int block_size); + const int block_size, + const bool use_ue8m0); std::vector EPMoeExpertCombine( const paddle::Tensor& ffn_out, @@ -1234,12 +1237,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) { &PerTokenQuant, py::arg("input"), py::arg("block_size"), + py::arg("use_ue8m0") = false, "per token per block quant"); m.def("per_token_quant_padding", &PerTokenQuantPadding, py::arg("input"), py::arg("block_size"), + py::arg("use_ue8m0") = false, "per token per block quant and padding transpose scale"); m.def("masked_per_token_quant", @@ -1247,6 +1252,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("input"), py::arg("recv_expert_count"), py::arg("block_size"), + py::arg("use_ue8m0") = false, "per token per block quant"); #ifdef ENABLE_MACHETE diff --git a/custom_ops/gpu_ops/per_token_quant_fp8.cu b/custom_ops/gpu_ops/per_token_quant_fp8.cu index bd783df81..047e95f74 100644 --- a/custom_ops/gpu_ops/per_token_quant_fp8.cu +++ b/custom_ops/gpu_ops/per_token_quant_fp8.cu @@ -16,6 +16,16 @@ constexpr float epsilon = 1e-10; +__device__ __forceinline__ float ceil_to_ue8m0(float s) { + int exp; + frexpf(s, &exp); + float pow2 = ldexpf(1.0f, exp - 1); + if (pow2 < s) { + pow2 = ldexpf(1.0f, exp); + } + return pow2; +} + template __global__ void quant_per_token_per_block( const T *input, @@ -24,7 +34,8 @@ __global__ void 
quant_per_token_per_block( const int token_num, const int hidden_size, const int hidden_size_scale, - const bool use_finegrained_range) { + const bool use_finegrained_range, + const bool use_ue8m0) { const int bid = blockIdx.x; const int tid = threadIdx.x; const int warp_id = tid / 32; @@ -83,11 +94,14 @@ __global__ void quant_per_token_per_block( } float scale_to_store = max_value_thread / MAX_VALUE; + if (use_ue8m0) { + scale_to_store = ceil_to_ue8m0(scale_to_store); + } // quant #pragma unroll for (int vid = 0; vid < NUM_PER_THREADS; vid++) { res_vec[vid] = static_cast( - load_vec_float[vid] * MAX_VALUE / max_value_thread); + load_vec_float[vid] / scale_to_store); } // store if (is_valid_data) @@ -102,7 +116,8 @@ __global__ void quant_per_token_per_block( } std::vector PerTokenQuant(paddle::Tensor &input, - const int block_size) { + const int block_size, + const bool use_ue8m0) { auto input_dim = input.dims(); const int token_num = input_dim[0]; const int hidden_size = input_dim[1]; @@ -132,7 +147,8 @@ std::vector PerTokenQuant(paddle::Tensor &input, token_num, hidden_size, hidden_size_scale, - use_finegrained_range); + use_finegrained_range, + use_ue8m0); break; case paddle::DataType::FLOAT16: quant_per_token_per_block<<>>( @@ -142,7 +158,8 @@ std::vector PerTokenQuant(paddle::Tensor &input, token_num, hidden_size, hidden_size_scale, - use_finegrained_range); + use_finegrained_range, + use_ue8m0); break; default: PD_THROW("Unsupported data type for PerTokenQuant"); @@ -159,7 +176,8 @@ __global__ void quant_per_token_per_block_padding( const int padded_token_num, const int hidden_size, const int hidden_size_scale, - const bool use_finegrained_range) { + const bool use_finegrained_range, + const bool use_ue8m0) { const int bid = blockIdx.x; const int tid = threadIdx.x; const int warp_id = tid / 32; @@ -209,11 +227,14 @@ __global__ void quant_per_token_per_block_padding( } float scale_to_store = max_value_thread / MAX_VALUE; + if (use_ue8m0) { + scale_to_store = 
ceil_to_ue8m0(scale_to_store); + } // quant #pragma unroll for (int vid = 0; vid < NUM_PER_THREADS; vid++) { res_vec[vid] = static_cast( - load_vec_float[vid] * MAX_VALUE / max_value_thread); + load_vec_float[vid] / scale_to_store); } // store Store( @@ -226,7 +247,8 @@ __global__ void quant_per_token_per_block_padding( } std::vector PerTokenQuantPadding(paddle::Tensor &input, - const int block_size) { + const int block_size, + const bool use_ue8m0) { using ScaleDtype = float; auto input_dim = input.dims(); @@ -269,7 +291,8 @@ std::vector PerTokenQuantPadding(paddle::Tensor &input, padded_token_num, hidden_size, hidden_size_scale, - use_finegrained_range); + use_finegrained_range, + use_ue8m0); break; case paddle::DataType::FLOAT16: quant_per_token_per_block_padding<<>>( @@ -280,7 +303,8 @@ std::vector PerTokenQuantPadding(paddle::Tensor &input, padded_token_num, hidden_size, hidden_size_scale, - use_finegrained_range); + use_finegrained_range, + use_ue8m0); break; default: PD_THROW("Unsupported data type for PerTokenQuant"); @@ -320,7 +344,8 @@ __global__ void masked_quant_per_token_per_block( const int hidden_size, const int hidden_size_scale, const int num_max_tokens_per_expert, - const bool use_finegrained_range) { + const bool use_finegrained_range, + const bool use_ue8m0) { const int bid = blockIdx.x; const int tid = threadIdx.x; const int warp_id = tid / 32; @@ -382,11 +407,14 @@ __global__ void masked_quant_per_token_per_block( } float scale_to_store = max_value_thread / MAX_VALUE; + if (use_ue8m0) { + scale_to_store = ceil_to_ue8m0(scale_to_store); + } // quant #pragma unroll for (int vid = 0; vid < NUM_PER_THREADS; vid++) { res_vec[vid] = static_cast( - load_vec_float[vid] * MAX_VALUE / max_value_thread); + load_vec_float[vid] / scale_to_store); } // store Store( @@ -401,7 +429,8 @@ __global__ void masked_quant_per_token_per_block( std::vector MaskedPerTokenQuant( paddle::Tensor &input, paddle::Tensor &recv_expert_count, - const int block_size) { + const 
int block_size, + const bool use_ue8m0) { auto input_dim = input.dims(); const int num_local_expert = input_dim[0]; const int num_max_tokens_per_expert = input_dim[1]; @@ -439,7 +468,8 @@ std::vector MaskedPerTokenQuant( hidden_size, hidden_size_scale, num_max_tokens_per_expert, - use_finegrained_range); + use_finegrained_range, + use_ue8m0); break; case paddle::DataType::FLOAT16: masked_quant_per_token_per_block<<>>( @@ -451,7 +481,8 @@ std::vector MaskedPerTokenQuant( hidden_size, hidden_size_scale, num_max_tokens_per_expert, - use_finegrained_range); + use_finegrained_range, + use_ue8m0); break; default: PD_THROW("Unsupported data type for PerTokenQuant"); @@ -462,13 +493,13 @@ std::vector MaskedPerTokenQuant( PD_BUILD_STATIC_OP(per_token_quant) .Inputs({"input"}) .Outputs({"output", "output_scale"}) - .Attrs({"block_size: int"}) + .Attrs({"block_size: int", "use_ue8m0: bool"}) .SetKernelFn(PD_KERNEL(PerTokenQuant)); PD_BUILD_STATIC_OP(per_token_quant_padding) .Inputs({"input"}) .Outputs({"output", "output_scale"}) - .Attrs({"block_size: int"}) + .Attrs({"block_size: int", "use_ue8m0: bool"}) .SetKernelFn(PD_KERNEL(PerTokenQuantPadding)) .SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantPaddingInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantPaddingInferDtype)); @@ -476,5 +507,5 @@ PD_BUILD_STATIC_OP(per_token_quant_padding) PD_BUILD_STATIC_OP(masked_per_token_quant) .Inputs({"input", "recv_expert_count"}) .Outputs({"output", "output_scale"}) - .Attrs({"block_size: int"}) + .Attrs({"block_size: int", "use_ue8m0: bool"}) .SetKernelFn(PD_KERNEL(MaskedPerTokenQuant)); diff --git a/tests/operators/test_masked_per_token_quant.py b/tests/operators/test_masked_per_token_quant.py index f0bdf525e..ccfe45681 100644 --- a/tests/operators/test_masked_per_token_quant.py +++ b/tests/operators/test_masked_per_token_quant.py @@ -23,7 +23,20 @@ import paddle from fastdeploy.model_executor.ops.gpu import masked_per_token_quant -def masked_per_token_quant_ref(input_tensor, 
recv_expert_count, block_size): +def ceil_to_ue8m0_paddle(x: paddle.Tensor): + """ + x > 0 + return 2 ^ ceil(log2(x)) + """ + # log2(x) + log2_x = paddle.log(x) / paddle.log(paddle.to_tensor(2.0, dtype=x.dtype)) + # ceil + ceil_log2_x = paddle.ceil(log2_x) + # 2^k + return paddle.pow(paddle.to_tensor(2.0, dtype=x.dtype), ceil_log2_x) + + +def masked_per_token_quant_ref(input_tensor, recv_expert_count, block_size, use_ue8m0): """ Paddle API implementation of masked_per_token_quant @@ -84,6 +97,9 @@ def masked_per_token_quant_ref(input_tensor, recv_expert_count, block_size): # Calculate scale scale = max_abs_val / MAX_VALUE + if use_ue8m0: + scale = ceil_to_ue8m0_paddle(scale) + # Quantize quanted_value = reshaped_input / scale @@ -120,10 +136,11 @@ class TestMaskedPerTokenQuant(unittest.TestCase): [self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype ) self.recv_expert_count = paddle.to_tensor([3, 2], dtype="int32") + self.use_ue8m0 = True # Get reference results from paddle implementation self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref( - self.input_tensor, self.recv_expert_count, self.block_size + self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0 ) def _mask_invalid_tokens(self, quanted_x, quanted_scale, recv_expert_count): @@ -149,7 +166,7 @@ class TestMaskedPerTokenQuant(unittest.TestCase): def test_masked_per_token_quant_basic(self): """Test basic functionality against CUDA kernel""" quanted_x_cuda, quanted_scale_cuda = masked_per_token_quant( - self.input_tensor, self.recv_expert_count, self.block_size + self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0 ) quanted_x_cuda_masked, quanted_scale_cuda_masked = self._mask_invalid_tokens( @@ -177,6 +194,28 @@ class TestMaskedPerTokenQuant(unittest.TestCase): self.assertLess(diff_val, 0.01, msg="Quantized values should be close") +class TestMaskedPerTokenQuantWithUe8m0Case1(TestMaskedPerTokenQuant): + 
"""Test with float16 input""" + + def setUp(self) -> None: + paddle.seed(2024) + self.num_local_expert = 3 + self.num_max_tokens_per_expert = 6 + self.hidden_size = 512 + self.block_size = 128 + self.dtype = paddle.float16 + self.use_ue8m0 = True + + self.input_tensor = paddle.randn( + [self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype + ) + self.recv_expert_count = paddle.to_tensor([4, 2, 5], dtype="int32") + + self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref( + self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0 + ) + + class TestMaskedPerTokenQuantCase1(TestMaskedPerTokenQuant): """Test with float16 input""" @@ -187,6 +226,7 @@ class TestMaskedPerTokenQuantCase1(TestMaskedPerTokenQuant): self.hidden_size = 512 self.block_size = 128 self.dtype = paddle.float16 + self.use_ue8m0 = False self.input_tensor = paddle.randn( [self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype @@ -194,7 +234,29 @@ class TestMaskedPerTokenQuantCase1(TestMaskedPerTokenQuant): self.recv_expert_count = paddle.to_tensor([4, 2, 5], dtype="int32") self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref( - self.input_tensor, self.recv_expert_count, self.block_size + self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0 + ) + + +class TestMaskedPerTokenQuantWithUe8m0Case2(TestMaskedPerTokenQuant): + """Test with different hidden size""" + + def setUp(self) -> None: + paddle.seed(2024) + self.num_local_expert = 4 + self.num_max_tokens_per_expert = 8 + self.hidden_size = 384 # 3 * 128 + self.block_size = 128 + self.dtype = paddle.bfloat16 + self.use_ue8m0 = True + + self.input_tensor = paddle.randn( + [self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype + ) + self.recv_expert_count = paddle.to_tensor([6, 3, 7, 1], dtype="int32") + + self.quanted_x_ref, self.quanted_scale_ref = 
masked_per_token_quant_ref( + self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0 ) @@ -208,6 +270,7 @@ class TestMaskedPerTokenQuantCase2(TestMaskedPerTokenQuant): self.hidden_size = 384 # 3 * 128 self.block_size = 128 self.dtype = paddle.bfloat16 + self.use_ue8m0 = False self.input_tensor = paddle.randn( [self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype @@ -215,7 +278,29 @@ class TestMaskedPerTokenQuantCase2(TestMaskedPerTokenQuant): self.recv_expert_count = paddle.to_tensor([6, 3, 7, 1], dtype="int32") self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref( - self.input_tensor, self.recv_expert_count, self.block_size + self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0 + ) + + +class TestMaskedPerTokenQuantWithUe8m0Case3(TestMaskedPerTokenQuant): + """Test with all experts having max tokens""" + + def setUp(self) -> None: + paddle.seed(2024) + self.num_local_expert = 2 + self.num_max_tokens_per_expert = 4 + self.hidden_size = 256 + self.block_size = 128 + self.dtype = paddle.bfloat16 + self.use_ue8m0 = True + self.input_tensor = paddle.randn( + [self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype + ) + # All experts use all tokens + self.recv_expert_count = paddle.to_tensor([4, 4], dtype="int32") + + self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref( + self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0 ) @@ -229,7 +314,7 @@ class TestMaskedPerTokenQuantCase3(TestMaskedPerTokenQuant): self.hidden_size = 256 self.block_size = 128 self.dtype = paddle.bfloat16 - + self.use_ue8m0 = False self.input_tensor = paddle.randn( [self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype ) @@ -237,7 +322,7 @@ class TestMaskedPerTokenQuantCase3(TestMaskedPerTokenQuant): self.recv_expert_count = paddle.to_tensor([4, 4], dtype="int32") self.quanted_x_ref,
self.quanted_scale_ref = masked_per_token_quant_ref( - self.input_tensor, self.recv_expert_count, self.block_size + self.input_tensor, self.recv_expert_count, self.block_size, self.use_ue8m0 ) @@ -250,7 +335,7 @@ class TestMaskedPerTokenQuantEdgeCases(unittest.TestCase): input_tensor = paddle.randn([2, 4, 256], dtype="bfloat16") recv_expert_count = paddle.to_tensor([0, 2], dtype="int32") # First expert has no tokens - quanted_x_ref, quanted_scale_ref = masked_per_token_quant_ref(input_tensor, recv_expert_count, 128) + quanted_x_ref, quanted_scale_ref = masked_per_token_quant_ref(input_tensor, recv_expert_count, 128, False) # First expert should be all zeros - convert to float32 for comparison expert_0_quanted = quanted_x_ref[0].astype("float32") diff --git a/tests/operators/test_per_token_quant.py b/tests/operators/test_per_token_quant.py index 23972ce53..7ba92f767 100644 --- a/tests/operators/test_per_token_quant.py +++ b/tests/operators/test_per_token_quant.py @@ -25,7 +25,20 @@ from fastdeploy.model_executor.ops.gpu import per_token_quant, per_token_quant_p paddle.seed(2024) -def per_token_quant_paddle(input_tensor, block_size): +def ceil_to_ue8m0_paddle(x: paddle.Tensor): + """ + x > 0 + return 2 ^ ceil(log2(x)) + """ + # log2(x) + log2_x = paddle.log(x) / paddle.log(paddle.to_tensor(2.0, dtype=x.dtype)) + # ceil + ceil_log2_x = paddle.ceil(log2_x) + # 2^k + return paddle.pow(paddle.to_tensor(2.0, dtype=x.dtype), ceil_log2_x) + + +def per_token_quant_paddle(input_tensor, block_size, use_ue8m0: bool = False): MAX_VALUE = 448.0 epsilon = 1e-10 @@ -33,7 +46,6 @@ def per_token_quant_paddle(input_tensor, block_size): token_num = input_shape[0] hidden_size = input_shape[1] - # According to https://github.com/PaddlePaddle/FastDeploy/pull/3659 padding_size = (block_size - hidden_size % block_size) % block_size padded_input = input_tensor @@ -48,6 +60,8 @@ def per_token_quant_paddle(input_tensor, block_size): max_abs_val = paddle.max(paddle.abs(reshaped_input), axis=-1, 
keepdim=True) max_abs_val = paddle.clip(max_abs_val, min=epsilon) scale = max_abs_val / MAX_VALUE + if use_ue8m0: + scale = ceil_to_ue8m0_paddle(scale) quanted_value = reshaped_input / scale @@ -61,8 +75,8 @@ def per_token_quant_paddle(input_tensor, block_size): return quanted_x, quanted_scale -def per_token_quant_padding_paddle(input_tensor, block_size, dtype): - quanted_x, intermediate_scale = per_token_quant_paddle(input_tensor, block_size) +def per_token_quant_padding_paddle(input_tensor, block_size, dtype, use_ue8m0): + quanted_x, intermediate_scale = per_token_quant_paddle(input_tensor, block_size, use_ue8m0) token_num = input_tensor.shape[0] tma_alignment_elements = 4 @@ -88,16 +102,16 @@ class TestPerTokenQuant(unittest.TestCase): self.input_tensor = self.get_input(shape=[self.token_num, self.hidden_size], dtype=self.dtype) def test_per_token_quant(self): - paddle_output, paddle_output_scale = per_token_quant_paddle(self.input_tensor, self.block_size) - output, output_scale = per_token_quant(self.input_tensor, self.block_size) + for use_ue8m0 in [False, True]: + paddle_output, paddle_output_scale = per_token_quant_paddle(self.input_tensor, self.block_size, use_ue8m0) + output, output_scale = per_token_quant(self.input_tensor, self.block_size, use_ue8m0) - np.testing.assert_allclose(paddle_output_scale.numpy(), output_scale.numpy(), rtol=1e-6) + np.testing.assert_allclose(paddle_output_scale.numpy(), output_scale.numpy(), rtol=1e-6) - output_rel_diff = paddle.mean( - paddle.abs(output.to(paddle.float32) - paddle_output.to(paddle.float32)) - ) / paddle.mean(paddle.abs(paddle_output.to(paddle.float32))) - - assert output_rel_diff < 0.001 + output_rel_diff = paddle.mean( + paddle.abs(output.to(paddle.float32) - paddle_output.to(paddle.float32)) + ) / paddle.mean(paddle.abs(paddle_output.to(paddle.float32))) + assert output_rel_diff < 0.001 class TestPerTokenQuantCase1(TestPerTokenQuant): @@ -136,24 +150,25 @@ class TestPerTokenQuantPadding(TestPerTokenQuant): 
self.input_tensor = self.get_input(shape=[self.token_num, self.hidden_size], dtype=self.dtype) def test_per_token_quant_padding(self): - paddle_output, paddle_output_scale = per_token_quant_padding_paddle( - self.input_tensor, self.block_size, self.dtype - ) - output, output_scale = per_token_quant_padding(self.input_tensor, self.block_size) + for use_ue8m0 in [False, True]: + paddle_output, paddle_output_scale = per_token_quant_padding_paddle( + self.input_tensor, self.block_size, self.dtype, use_ue8m0 + ) + output, output_scale = per_token_quant_padding(self.input_tensor, self.block_size, use_ue8m0) - self.assertEqual(paddle_output_scale.shape, output_scale.shape) - np.testing.assert_allclose( - paddle_output_scale[0 : self.token_num].numpy(), - output_scale[0 : self.token_num].numpy(), - rtol=1e-5, - atol=1e-5, - ) + self.assertEqual(paddle_output_scale.shape, output_scale.shape) + np.testing.assert_allclose( + paddle_output_scale[0 : self.token_num].numpy(), + output_scale[0 : self.token_num].numpy(), + rtol=1e-5, + atol=1e-5, + ) - output_rel_diff = paddle.mean( - paddle.abs(output.to(paddle.float32) - paddle_output.to(paddle.float32)) - ) / paddle.mean(paddle.abs(paddle_output.to(paddle.float32)) + 1e-9) + output_rel_diff = paddle.mean( + paddle.abs(output.to(paddle.float32) - paddle_output.to(paddle.float32)) + ) / paddle.mean(paddle.abs(paddle_output.to(paddle.float32)) + 1e-9) - assert output_rel_diff < 0.001 + assert output_rel_diff < 0.001 class TestPerTokenQuantPaddingCase1(TestPerTokenQuantPadding):