[XPU] refactor moe ffn (#5501)

- remove BKCL_DISPATCH_ALL_GATHER
- support sparse mode
- support moe quant_method
zhupengyang
2025-12-18 14:14:05 +08:00
committed by GitHub
parent d0a7834a17
commit 8735cb5045
12 changed files with 397 additions and 127 deletions
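
The quant_method names this commit introduces spell out weight and activation quantization in one string, as w_<granularity>_<dtype>_a_<granularity>_<dtype>: the old "w4a8" becomes "w_channelwise_int4_a_expertwise_int8" / "w_channelwise_int4_a_tokenwise_int8", and "weight_only_int8" splits into the three w_channelwise_int8_a_* variants that were previously selected through the XFT_MOE_FC_WINT8_TGEMM environment variable. A minimal sketch of the naming layout; ParsedQuantMethod and parse_quant_method are illustrative helpers, not part of this commit:

#include <string>
#include <vector>

// Illustrative decomposition of the new quant_method naming scheme, e.g.
// "w_channelwise_int4_a_tokenwise_int8" -> weights quantized per channel to
// int4, activations quantized per token to int8. Hypothetical helper only.
struct ParsedQuantMethod {
  std::string w_granularity;  // "channelwise"
  std::string w_dtype;        // "int4" / "int8"
  std::string a_granularity;  // "tokenwise" / "expertwise" / "" (none)
  std::string a_dtype;        // "int8" / "int15" / "float16" / "float32"
};

ParsedQuantMethod parse_quant_method(const std::string& name) {
  std::vector<std::string> parts;
  for (size_t begin = 0; begin <= name.size();) {
    size_t end = name.find('_', begin);
    if (end == std::string::npos) end = name.size();
    parts.push_back(name.substr(begin, end - begin));
    begin = end + 1;
  }
  // Layout: {"w", wg, wt, "a", [ag,] at}; "w_channelwise_int8_a_float32"
  // carries no activation granularity, so the last field may stand alone.
  ParsedQuantMethod p;
  p.w_granularity = parts[1];
  p.w_dtype = parts[2];
  if (parts.size() == 6) {
    p.a_granularity = parts[4];
    p.a_dtype = parts[5];
  } else {
    p.a_dtype = parts[4];
  }
  return p;
}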

View File

@@ -57,7 +57,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
   const int64_t ep_size = 1;
   const int64_t ep_rank = 0;
-  if (std::is_same<TY, int8_t>::value) {
+  if (std::is_same<TY, int8_t>::value && !std::is_same<TX, int8_t>::value) {
     permute_input =
         paddle::empty({token_nums_this_rank, n}, paddle::DataType::INT8, place);
     if (token_nums_this_rank > 0) {
@@ -99,7 +99,11 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
         block_num,
         ep_size,
         ep_rank,
-        token_nums_this_rank);
+        token_nums_this_rank,
+        std::is_same<TX, int8_t>::value
+            ? input_scales.get_ptr()->data<float>()
+            : nullptr,
+        expand_input_scales.data<float>());
     PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
   }
 }
@@ -138,10 +142,12 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
   } else if (input_dtype == paddle::DataType::BFLOAT16 &&
              quant_method != "w4a8") {
     APPLY_KERNEL(paddle::bfloat16, paddle::bfloat16);
+  } else if (input_dtype == paddle::DataType::INT8) {
+    APPLY_KERNEL(int8_t, int8_t);
   } else {
     PD_THROW("EPMoeExpertDispatch not support input_dtype=",
              static_cast<int>(input_dtype),
-             "quant_method=",
+             ", quant_method=",
             quant_method);
     return {};
   }
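
The dispatch hunk above separates the output dtype from the input dtype: quantize-on-dispatch only fires when the target TY is int8 and the tokens TX are not already int8; pre-quantized int8 input instead forwards its per-token input_scales into expand_input_scales. A compile-time restatement of just that guard (illustrative, not FastDeploy code):

#include <cstdint>
#include <cstdio>
#include <type_traits>

// The dtype guard above: quantize during dispatch only when the destination
// type TY is int8 *and* the source type TX is not already int8.
template <typename TX, typename TY>
constexpr bool needs_quantization() {
  return std::is_same<TY, int8_t>::value && !std::is_same<TX, int8_t>::value;
}

int main() {
  std::printf("%d\n", needs_quantization<float, int8_t>());   // 1: quantize
  std::printf("%d\n", needs_quantization<int8_t, int8_t>());  // 0: passthrough
  std::printf("%d\n", needs_quantization<float, float>());    // 0: no int8 out
}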

View File

@@ -28,7 +28,6 @@
 #endif
 XPU_DECLARE_BOOL(MOE_FFN_USE_DENSE_INPUT, false);
-XPU_DECLARE_BOOL(BKCL_DISPATCH_ALL_GATHER, false);
 namespace xftblock = baidu::xpu::xftblock;
 namespace api = baidu::xpu::api;
@@ -36,6 +35,7 @@ namespace api = baidu::xpu::api;
 template <typename TX1, typename TX2, typename TW, typename TGEMM>
 void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
                       xftblock::Tensor* token_num_info,
+                      xftblock::Tensor* token_num_lod,
                       xftblock::Tensor* ffn1_weight,
                       xftblock::Tensor* ffn2_weight,
                       xftblock::Tensor* ffn1_bias,
@@ -44,7 +44,8 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
                       float* ffn2_act_scale,
                       TX2* ffn2_shift,
                       TX2* ffn2_smooth,
-                      const int hadamard_blocksize) {
+                      const int hadamard_blocksize,
+                      const int64_t group_size) {
   phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
   auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
   auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
@@ -68,11 +69,11 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
       ffn1_weight,
       &ffn1_out,
       ffn1_bias,
-      is_padding_input ? nullptr : token_num_info,
+      token_num_lod,
+      is_padding_input ? token_num_info : nullptr,
       expert_num,
       1,  // moe_topk
-      0,  // group_size
+      group_size,
       ffn1_out_shape.size() == 2 ? xftblock::MoeFCInputMode::DENSE
                                  : xftblock::MoeFCInputMode::SPARSE);
   PD_CHECK(ret == 0);
@@ -81,13 +82,25 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
   auto swiglu_out_shape = ffn1_out_shape;
   swiglu_out_shape[swiglu_out_shape.size() - 1] /= 2;
   xftblock::Tensor swiglu_out(rt_guard, xftblock_tx2, swiglu_out_shape);
-  ret = api::fast_swiglu<TX2>(xpu_ctx->x_context(),
-                              ffn1_out.data<TX2>(),
-                              swiglu_out.mutable_data<TX2>(),
-                              {token_num, inter_dim},
-                              1,
-                              true);
-  PD_CHECK(ret == 0);
+  if (is_padding_input) {
+    ret = infer_ops::swiglu_unt2<TX2>(xpu_ctx->x_context(),
+                                      ffn1_out.data<TX2>(),
+                                      swiglu_out.mutable_data<TX2>(),
+                                      token_num,
+                                      inter_dim,
+                                      expert_num,
+                                      token_num_info->data<int>(),
+                                      true);
+    PD_CHECK(ret == 0);
+  } else {
+    ret = api::fast_swiglu<TX2>(xpu_ctx->x_context(),
+                                ffn1_out.data<TX2>(),
+                                swiglu_out.mutable_data<TX2>(),
+                                {token_num, inter_dim},
+                                1,
+                                true);
+    PD_CHECK(ret == 0);
+  }
   // TODO(mayang02): use fusion_smooth_transform
   if (ffn2_shift != nullptr) {
     ret = api::broadcast_add<TX2>(xpu_ctx->x_context(),
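
The new is_padding_input branch above runs swiglu per expert block: swiglu_unt2 takes expert_num and the per-expert token counts so padded rows are skipped, while the dense path keeps the single fast_swiglu call. For reference, swiglu itself halves the last dimension; a host-side sketch that assumes the gate half comes first, a convention this diff does not confirm:

#include <cmath>
#include <vector>

// Host-side reference for swiglu: the ffn1 output packs two halves along the
// last dim (inter_dim), and out[j] = silu(gate[j]) * up[j] produces
// outer_dim = inter_dim / 2 values. Which half is the gate is a kernel
// convention; this sketch assumes the first half.
std::vector<float> swiglu_row(const std::vector<float>& row) {
  const size_t half = row.size() / 2;
  std::vector<float> out(half);
  for (size_t j = 0; j < half; ++j) {
    const float gate = row[j];
    const float silu = gate / (1.0f + std::exp(-gate));
    out[j] = silu * row[half + j];
  }
  return out;
}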
@@ -109,14 +122,17 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
   }
   if (hadamard_blocksize > 0) {
-    ret = infer_ops::fast_walsh_transform<TX2>(xpu_ctx->x_context(),
-                                               swiglu_out.data<TX2>(),
-                                               nullptr,
-                                               nullptr,
-                                               swiglu_out.mutable_data<TX2>(),
-                                               hadamard_blocksize,
-                                               token_num,
-                                               outer_dim);
+    ret = infer_ops::fast_walsh_transform<TX2>(
+        xpu_ctx->x_context(),
+        swiglu_out.data<TX2>(),
+        nullptr,
+        nullptr,
+        swiglu_out.mutable_data<TX2>(),
+        hadamard_blocksize,
+        token_num,
+        outer_dim,
+        is_padding_input ? token_num / expert_num : 0,
+        is_padding_input ? token_num_lod->data<int>() : nullptr);
     PD_CHECK(ret == 0);
   }
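
fast_walsh_transform gains two trailing arguments so the padded per-expert layout can be handled in one call: a fixed tokens-per-expert stride and the token lod (zero/null in the dense case). The transform itself is a blockwise Walsh-Hadamard rotation over each row, commonly used to spread activation outliers before low-bit quantization; a host-side reference of the unnormalized butterfly, since the XPU kernel's scaling convention is not visible in this diff:

#include <cstddef>
#include <vector>

// Reference for an (unnormalized) fast Walsh-Hadamard transform applied to
// each contiguous block of `block_size` values; block_size must be a power
// of two. Each butterfly pass replaces pairs (a, b) with (a + b, a - b).
void fwht_blocks(std::vector<float>* data, int block_size) {
  for (std::size_t base = 0; base + block_size <= data->size();
       base += block_size) {
    for (int h = 1; h < block_size; h *= 2) {
      for (int i = 0; i < block_size; i += 2 * h) {
        for (int j = i; j < i + h; ++j) {
          const float a = (*data)[base + j];
          const float b = (*data)[base + j + h];
          (*data)[base + j] = a + b;
          (*data)[base + j + h] = a - b;
        }
      }
    }
  }
}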
@@ -131,11 +147,11 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
       ffn2_weight,
       ffn2_out,
       nullptr,
-      is_padding_input ? nullptr : token_num_info,
+      token_num_lod,
+      is_padding_input ? token_num_info : nullptr,
       expert_num,
       1,  // moe_topk
-      0,  // group_size
+      group_size,
       ffn1_out_shape.size() == 2
           ? xftblock::MoeFCInputMode::DENSE
           : xftblock::MoeFCInputMode::SPARSE);  // bias_mode
@@ -143,23 +159,25 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
 }
 static void convert_to_lod(xftblock::XFTContext* xctx,
-                           xftblock::Tensor* token_num_info) {
-  auto rt_guard = xctx->get_rt_guard();
-  auto ctx = xctx->get_context();
-  const int expert_num = token_num_info->numel();
-  xftblock::Tensor tokens_num_lod(
-      rt_guard, xftblock::DataType::DT_INT32, {expert_num + 1});
-  int ret = api::constant(ctx, tokens_num_lod.data<int>(), expert_num + 1, 0);
-  PD_CHECK(ret == 0);
-  ret = api::cumsum<int>(ctx,
-                         token_num_info->data<int>(),
-                         tokens_num_lod.data<int>() + 1,
-                         {expert_num},
-                         false,
-                         false,
-                         0);
-  PD_CHECK(ret == 0);
-  *token_num_info = std::move(tokens_num_lod);
+                           xftblock::Tensor* token_num_info,
+                           xftblock::Tensor* token_num_lod,
+                           int expert_num) {
+  if (expert_num == token_num_info->numel()) {
+    auto rt_guard = xctx->get_rt_guard();
+    auto ctx = xctx->get_context();
+    int ret = api::constant(ctx, token_num_lod->data<int>(), expert_num + 1, 0);
+    PD_CHECK(ret == 0);
+    ret = api::cumsum<int>(ctx,
+                           token_num_info->data<int>(),
+                           token_num_lod->data<int>() + 1,
+                           {expert_num},
+                           false,
+                           false,
+                           0);
+    PD_CHECK(ret == 0);
+  } else {
+    *token_num_lod = std::move(*token_num_info);
+  }
 }
 template <typename TX1, typename TX2, typename TW>
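
convert_to_lod no longer overwrites token_num_info in place: it fills a separate token_num_lod, so the raw per-expert counts stay available for the sparse kernels above. When token_num_info holds expert_num counts, the lod is the zero-prefixed cumulative sum (api::constant zeroes the buffer, api::cumsum writes offsets 1..expert_num); when it already holds expert_num + 1 offsets, it is moved over unchanged. A host-side reference of the counts-to-lod step:

#include <cstdio>
#include <vector>

// Host-side reference for the counts -> lod conversion done on device via
// api::constant + api::cumsum: lod[0] = 0, lod[i + 1] = lod[i] + counts[i].
std::vector<int> counts_to_lod(const std::vector<int>& token_num_per_expert) {
  std::vector<int> lod(token_num_per_expert.size() + 1, 0);
  for (size_t i = 0; i < token_num_per_expert.size(); ++i) {
    lod[i + 1] = lod[i] + token_num_per_expert[i];
  }
  return lod;
}

int main() {
  // 4 experts holding 3, 0, 5, and 2 tokens respectively.
  const std::vector<int> lod = counts_to_lod({3, 0, 5, 2});
  for (int v : lod) std::printf("%d ", v);  // prints: 0 3 3 8 10
}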
@@ -196,6 +214,12 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
   int inter_dim = ffn1_w_shape[1];
   int outer_dim = inter_dim / 2;
   bool is_padding_input = input_shape.size() == 3;
+  int64_t group_size = 0;
+  if (ffn1_weight_scale.get_ptr() &&
+      ffn1_weight_scale.get_ptr()->numel() > expert_num * inter_dim) {
+    group_size = hidden_dim / (ffn1_weight_scale.get_ptr()->numel() /
+                               expert_num / inter_dim);
+  }
   if (is_padding_input) {
     PD_CHECK(input_shape[0] == expert_num);
     PD_CHECK(token_num_info.numel() == expert_num,
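
The group_size inference above relies on the weight-scale layout: channelwise ffn1 scales hold exactly expert_num * inter_dim floats, while groupwise scales hold hidden_dim / group_size times as many, so the group width is recoverable from numel alone. A worked check with illustrative sizes:

#include <cstdint>
#include <cstdio>

// Recover group_size from the ffn1 weight-scale element count, mirroring the
// logic above: channelwise scales have expert_num * inter_dim elements;
// groupwise scales carry an extra factor of hidden_dim / group_size.
int64_t infer_group_size(int64_t scale_numel, int64_t expert_num,
                         int64_t inter_dim, int64_t hidden_dim) {
  if (scale_numel > expert_num * inter_dim) {
    return hidden_dim / (scale_numel / expert_num / inter_dim);
  }
  return 0;  // channelwise: no grouping
}

int main() {
  // e.g. 8 experts, inter_dim 3072, hidden_dim 4096:
  // channelwise: 8 * 3072 elements         -> group_size 0
  // groupwise-128: 8 * 3072 * (4096 / 128) -> group_size 128
  std::printf("%lld\n", (long long)infer_group_size(8 * 3072, 8, 3072, 4096));
  std::printf("%lld\n",
              (long long)infer_group_size(8LL * 3072 * 32, 8, 3072, 4096));
}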
@@ -206,7 +230,9 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
              expert_num);
   }
-  bool is_w4 = quant_method == "w4a8" || quant_method == "weight_only_int4";
+  bool is_w4 = quant_method == "w_channelwise_int4_a_tokenwise_int15" ||
+               quant_method == "w_channelwise_int4_a_expertwise_int8" ||
+               quant_method == "w_channelwise_int4_a_tokenwise_int8";
   auto xftblock_tx1 = xftblock::DataTypeToEnum<XPU_TX1>::value;
   auto xftblock_tx2 = xftblock::DataTypeToEnum<XPU_TX2>::value;
   auto xftblock_tw = xftblock::DataTypeToEnum<XPU_TW>::value;
@@ -256,6 +282,8 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
   xftblock::Tensor xtoken_num_info(const_cast<int*>(token_num_info.data<int>()),
                                    xftblock::DataType::DT_INT32,
                                    token_num_info.shape());
+  xftblock::Tensor xtoken_num_lod(
+      rt_guard, xftblock::DataType::DT_INT32, {expert_num + 1});
   XPU_TX2* shift_data = nullptr;
   XPU_TX2* smooth_data = nullptr;
   if (ffn2_shift.get_ptr()) {
@@ -272,9 +300,10 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
   xftblock::Tensor xffn2_out;
   paddle::Tensor ffn1_in_dense;
   paddle::Tensor ffn1_in_scale_per_token;
+  convert_to_lod(&xctx, &xtoken_num_info, &xtoken_num_lod, expert_num);
   if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
-    convert_to_lod(&xctx, &xtoken_num_info);
-    if (quant_method == "w4a8") {
+    if (quant_method == "w_channelwise_int4_a_expertwise_int8" ||
+        quant_method == "w_channelwise_int4_a_tokenwise_int8") {
       ffn1_in_scale_per_token = paddle::empty(
           {valid_token_num}, paddle::DataType::FLOAT32, ffn_in.place());
       ffn1_in_dense = paddle::empty({valid_token_num, hidden_dim},
@@ -292,7 +321,7 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
             xpu_ctx->x_context(),
             ffn1_act_scale_data,
             ffn1_in_scale_per_token.data<float>(),
-            xtoken_num_info.data<int>(),
+            xtoken_num_lod.data<int>(),
             expert_num,
             input_shape[1],
             1,
@@ -302,7 +331,7 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
             xpu_ctx->x_context(),
             reinterpret_cast<const int8_t*>(ffn_in.data<int8_t>()),
             reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
-            xtoken_num_info.data<int>(),
+            xtoken_num_lod.data<int>(),
             expert_num,
             input_shape[1],
             input_shape[2],
@@ -313,7 +342,7 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
             xpu_ctx->x_context(),
             reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
             ffn1_act_scale_data,
-            xtoken_num_info.data<int>(),
+            xtoken_num_lod.data<int>(),
             reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
             ffn1_in_scale_per_token.data<float>(),
             expert_num,
@@ -324,6 +353,64 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
             input_shape[1]);
         PD_CHECK(ret == 0);
       }
+    } else if (quant_method == "w_channelwise_int8_a_expertwise_int8" ||
+               quant_method == "w_channelwise_int8_a_tokenwise_int8") {
+      ffn1_in_scale_per_token = paddle::empty(
+          {valid_token_num}, paddle::DataType::FLOAT32, ffn_in.place());
+      ffn1_in_dense = paddle::empty({valid_token_num, hidden_dim},
+                                    paddle::DataType::INT8,
+                                    ffn_in.place());
+      xffn1_in = xftblock::Tensor(ffn1_in_dense.data<int8_t>(),
+                                  nullptr,
+                                  ffn1_in_scale_per_token.data<float>(),
+                                  xftblock::DataType::DT_INT8,
+                                  {valid_token_num, hidden_dim});
+      if (std::is_same<XPU_TX1, int8_t>::value) {
+        PD_CHECK(ffn1_act_scale_data != nullptr,
+                 "need ffn1_act_scale for x int8 per expert input");
+        ret = infer_ops::sequence_unpad<float, int>(
+            xpu_ctx->x_context(),
+            ffn1_act_scale_data,
+            ffn1_in_scale_per_token.data<float>(),
+            xtoken_num_lod.data<int>(),
+            expert_num,
+            input_shape[1],
+            1,
+            true);
+        PD_CHECK(ret == 0);
+        ret = infer_ops::sequence_unpad<int8_t, int>(
+            xpu_ctx->x_context(),
+            reinterpret_cast<const int8_t*>(ffn_in.data<int8_t>()),
+            reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
+            xtoken_num_lod.data<int>(),
+            expert_num,
+            input_shape[1],
+            input_shape[2],
+            true);
+        PD_CHECK(ret == 0);
+      } else {
+        auto ffn1_in_unpad = paddle::empty(
+            {valid_token_num, hidden_dim}, ffn_in.dtype(), ffn_in.place());
+        ret = infer_ops::sequence_unpad<XPU_TX1, int>(
+            xpu_ctx->x_context(),
+            reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
+            reinterpret_cast<XPU_TX1*>(ffn1_in_unpad.data<TX1>()),
+            xtoken_num_lod.data<int>(),
+            expert_num,
+            input_shape[1],
+            input_shape[2],
+            true);
+        PD_CHECK(ret == 0);
+        ret = infer_ops::quant2d_per_token<XPU_TX1, float, int8_t>(
+            xpu_ctx->x_context(),
+            reinterpret_cast<const XPU_TX1*>(ffn1_in_unpad.data<TX1>()),
+            nullptr,
+            reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
+            ffn1_in_scale_per_token.data<float>(),
+            valid_token_num,
+            hidden_dim);
+        PD_CHECK(ret == api::SUCCESS);
+      }
     } else {
       ffn1_in_dense = paddle::empty(
           {valid_token_num, hidden_dim}, ffn_in.dtype(), ffn_in.place());
@@ -336,7 +423,7 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
           xpu_ctx->x_context(),
           reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
           reinterpret_cast<XPU_TX1*>(xffn1_in.data<XPU_TX1>()),
-          xtoken_num_info.data<int>(),
+          xtoken_num_lod.data<int>(),
           expert_num,
           input_shape[1],
           input_shape[2],
@@ -345,31 +432,6 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
     }
     xffn2_out =
         xftblock::Tensor(rt_guard, xftblock_tx2, {valid_token_num, hidden_dim});
-  } else if (FLAGS_BKCL_DISPATCH_ALL_GATHER && !is_padding_input &&
-             quant_method == "w4a8") {
-    convert_to_lod(&xctx, &xtoken_num_info);
-    ffn1_in_scale_per_token = paddle::empty(
-        {valid_token_num}, paddle::DataType::FLOAT32, ffn_in.place());
-    ffn1_in_dense = paddle::empty(
-        {valid_token_num, hidden_dim}, paddle::DataType::INT8, ffn_in.place());
-    xffn1_in = xftblock::Tensor(ffn1_in_dense.data<int8_t>(),
-                                nullptr,
-                                ffn1_in_scale_per_token.data<float>(),
-                                xftblock::DataType::DT_INT8,
-                                {valid_token_num, hidden_dim});
-    ret = infer_ops::quant2d_per_expert<XPU_TX1>(
-        xpu_ctx->x_context(),
-        reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
-        ffn1_act_scale_data,
-        xtoken_num_info.data<int>(),
-        reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
-        ffn1_in_scale_per_token.data<float>(),
-        expert_num,
-        valid_token_num,
-        hidden_dim);
-    PD_CHECK(ret == 0);
-    xffn2_out =
-        xftblock::Tensor(ffn2_out.data<TX2>(), xftblock_tx2, input_shape);
   } else {
     xffn1_in = xftblock::Tensor(const_cast<TX1*>(ffn_in.data<TX1>()),
                                 nullptr,
@@ -383,6 +445,7 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
 #define FFN_IMPL(TX1, TX2, TW, TGEMM)                        \
   MoeExpertFFNImpl<TX1, TX2, TW, TGEMM>(&xffn1_in,           \
                                         &xtoken_num_info,    \
+                                        &xtoken_num_lod,     \
                                         &xffn1_w,            \
                                         &xffn2_w,            \
                                         xffn1_bias.get(),    \
@@ -391,31 +454,32 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
                                         ffn2_act_scale_data, \
                                         shift_data,          \
                                         smooth_data,         \
-                                        hadamard_blocksize)
-  if (quant_method == "weight_only_int8") {
-    static const char* xft_moe_fc_wint8_tgemm =
-        std::getenv("XFT_MOE_FC_WINT8_TGEMM");
-    if (xft_moe_fc_wint8_tgemm != nullptr) {
-      if (std::string(xft_moe_fc_wint8_tgemm) == "INT8") {
-        FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, int8_wo_t);
-      } else if (std::string(xft_moe_fc_wint8_tgemm) == "FLOAT16") {
-        FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, float16);
-      } else {
-        FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, float);
-      }
-    } else {
-      FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, float);
-    }
-  } else if (quant_method == "weight_only_int4") {
+                                        hadamard_blocksize,  \
+                                        group_size)
+  if (quant_method == "w_channelwise_int8_a_float32") {
+    FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, float);
+  } else if (quant_method == "w_channelwise_int8_a_tokenwise_float16") {
+    FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, float16);
+  } else if (quant_method == "w_channelwise_int8_a_tokenwise_int15") {
+    FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, int8_wo_t);
+  } else if (quant_method == "w_channelwise_int4_a_tokenwise_int15") {
     FFN_IMPL(XPU_TX1, XPU_TX2, int4_t, int4_wo_int15);
-  } else if (quant_method == "w4a8") {
+  } else if (quant_method == "w_channelwise_int4_a_expertwise_int8" ||
+             quant_method == "w_channelwise_int4_a_tokenwise_int8") {
+    // a8: per expert
     if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
       FFN_IMPL(int8_t, XPU_TX2, int4_t, int4_wo_int8);
-    } else if (FLAGS_BKCL_DISPATCH_ALL_GATHER && !is_padding_input) {
-      FFN_IMPL(int8_t, XPU_TX2, int4_t, int4_wo_int8);
     } else {
       FFN_IMPL(XPU_TX1, XPU_TX2, int4_t, int4_wo_int8);
     }
+  } else if (quant_method == "w_channelwise_int8_a_expertwise_int8" ||
+             quant_method == "w_channelwise_int8_a_tokenwise_int8") {
+    // a8: per expert
+    if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
+      FFN_IMPL(int8_t, XPU_TX2, int8_t, int8_t);
+    } else {
+      FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, int8_t);
+    }
   } else {
     FFN_IMPL(XPU_TX1, XPU_TX2, XPU_TW, float);
   }
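
The branch ladder above retires both the old weight_only_* names and the XFT_MOE_FC_WINT8_TGEMM override: each new quant_method string now pins FFN_IMPL's weight storage type (TW) and gemm compute tag (TGEMM) explicitly. The same mapping restated as data; ffn_gemm_for is an illustrative helper, not part of the commit:

#include <string>

// Summary of the FFN_IMPL dispatch above (illustrative helper only).
struct FfnGemmChoice {
  const char* weight_type;  // TW in FFN_IMPL
  const char* gemm_type;    // TGEMM in FFN_IMPL
};

FfnGemmChoice ffn_gemm_for(const std::string& m) {
  if (m == "w_channelwise_int8_a_float32") return {"int8_t", "float"};
  if (m == "w_channelwise_int8_a_tokenwise_float16")
    return {"int8_t", "float16"};
  if (m == "w_channelwise_int8_a_tokenwise_int15")
    return {"int8_t", "int8_wo_t"};
  if (m == "w_channelwise_int4_a_tokenwise_int15")
    return {"int4_t", "int4_wo_int15"};
  if (m == "w_channelwise_int4_a_expertwise_int8" ||
      m == "w_channelwise_int4_a_tokenwise_int8")
    return {"int4_t", "int4_wo_int8"};
  if (m == "w_channelwise_int8_a_expertwise_int8" ||
      m == "w_channelwise_int8_a_tokenwise_int8")
    return {"int8_t", "int8_t"};
  return {"XPU_TW", "float"};  // fallback branch
}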
@@ -425,7 +489,7 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
         xpu_ctx->x_context(),
         const_cast<XPU_TX2*>(xffn2_out.data<XPU_TX2>()),
         reinterpret_cast<XPU_TX2*>(ffn2_out.data<TX2>()),
-        xtoken_num_info.data<int>(),
+        xtoken_num_lod.data<int>(),
         input_shape[0],
         input_shape[1],
         input_shape[2],

View File

@@ -143,6 +143,8 @@ std::vector<paddle::Tensor> WeightOnlyLinear(
     const int arch,
     const int group_size);
 
+std::vector<paddle::Tensor> Quant2dPerToken(const paddle::Tensor& x);
+
 std::vector<paddle::Tensor> MoeEPCombine(const paddle::Tensor& ffn_out,
                                          const paddle::Tensor& moe_index,
                                          const paddle::Tensor& weights,
@@ -1275,6 +1277,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         py::arg("arch"),
         py::arg("group_size") = -1);
 
+  m.def(
+      "quant2d_per_token", &Quant2dPerToken, py::arg("x"), "quant x per token");
+
   m.def("xpu_moe_layer",
         &MoeLayer,
         py::arg("x"),

View File

@@ -0,0 +1,88 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <core/check.h>
+#include <core/context.h>
+#include <core/param.h>
+#include <infer_ops.h>
+#include <xft_api.h>
+
+#include "paddle/extension.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "utility/debug.h"
+#include "utility/env.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+namespace xftblock = baidu::xpu::xftblock;
+namespace api = baidu::xpu::api;
+
+template <typename TX>
+std::vector<paddle::Tensor> Quant2dPerTokenKernel(const paddle::Tensor& x) {
+  using XPU_TX = typename XPUTypeTrait<TX>::Type;
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+  xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
+  auto rt_guard = xctx.get_rt_guard();
+
+  auto input_shape = x.shape();
+  auto x_scale =
+      paddle::empty({input_shape[0]}, paddle::DataType::FLOAT32, x.place());
+  auto quant_x = paddle::empty(
+      {input_shape[0], input_shape[1]}, paddle::DataType::INT8, x.place());
+  if (input_shape[0] > 0) {
+    int ret = infer_ops::quant2d_per_token<XPU_TX, float, int8_t>(
+        xpu_ctx->x_context(),
+        reinterpret_cast<const XPU_TX*>(x.data<TX>()),
+        nullptr,
+        reinterpret_cast<int8_t*>(quant_x.data<int8_t>()),
+        reinterpret_cast<float*>(x_scale.data<float>()),
+        input_shape[0],
+        input_shape[1]);
+    PD_CHECK(ret == api::SUCCESS);
+  }
+  return {quant_x, x_scale};
+}
+
+std::vector<paddle::Tensor> Quant2dPerToken(const paddle::Tensor& x) {
+  const auto x_type = x.dtype();
+  if (x_type == paddle::DataType::BFLOAT16) {
+    return Quant2dPerTokenKernel<paddle::bfloat16>(x);
+  } else if (x_type == paddle::DataType::FLOAT16) {
+    return Quant2dPerTokenKernel<paddle::float16>(x);
+  } else {
+    PD_THROW("Quant2dPerToken not support x_type=", static_cast<int>(x_type));
+    return {};
+  }
+}
+
+std::vector<std::vector<int64_t>> Quant2dPerTokenInferShape(
+    const std::vector<int64_t>& x_shape) {
+  // Two outputs: quant_x keeps x's shape, x_scale holds one scale per row.
+  return {x_shape, {x_shape[0]}};
+}
+
+std::vector<paddle::DataType> Quant2dPerTokenInferDtype(
+    const paddle::DataType& x_dtype) {
+  return {paddle::DataType::INT8, paddle::DataType::FLOAT32};
+}
+
+PD_BUILD_STATIC_OP(quant2d_per_token)
+    .Inputs({"x"})
+    .Outputs({"quant_x", "x_scale"})
+    .SetKernelFn(PD_KERNEL(Quant2dPerToken))
+    .SetInferShapeFn(PD_INFER_SHAPE(Quant2dPerTokenInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(Quant2dPerTokenInferDtype));
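
quant2d_per_token returns the int8 tensor plus one float scale per row, which is exactly the pair the tokenwise-int8 paths in MoeExpertFFNKernel consume. The underlying math is conventional symmetric per-token quantization; a host-side reference, with the caveat that the XPU kernel's exact rounding and clamping are not shown in this diff:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Host-side reference for symmetric per-token (per-row) int8 quantization:
// scale[i] = max_j |x[i][j]| / 127, quant_x[i][j] = round(x[i][j] / scale[i]).
void quant2d_per_token_ref(const std::vector<std::vector<float>>& x,
                           std::vector<std::vector<int8_t>>* quant_x,
                           std::vector<float>* x_scale) {
  quant_x->assign(x.size(), {});
  x_scale->assign(x.size(), 0.0f);
  for (size_t i = 0; i < x.size(); ++i) {
    float max_abs = 0.0f;
    for (float v : x[i]) max_abs = std::max(max_abs, std::fabs(v));
    const float scale = max_abs > 0.0f ? max_abs / 127.0f : 1.0f;
    (*x_scale)[i] = scale;
    (*quant_x)[i].reserve(x[i].size());
    for (float v : x[i]) {
      (*quant_x)[i].push_back(static_cast<int8_t>(std::lrintf(v / scale)));
    }
  }
}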