[XPU] refactor moe ffn (#5501)

- remove BKCL_DISPATCH_ALL_GATHER
- support sparse mode
- support moe quant_method
This commit is contained in:
zhupengyang
2025-12-18 14:14:05 +08:00
committed by GitHub
parent d0a7834a17
commit 8735cb5045
12 changed files with 397 additions and 127 deletions

View File

@@ -57,7 +57,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
const int64_t ep_size = 1;
const int64_t ep_rank = 0;
if (std::is_same<TY, int8_t>::value) {
if (std::is_same<TY, int8_t>::value && !std::is_same<TX, int8_t>::value) {
permute_input =
paddle::empty({token_nums_this_rank, n}, paddle::DataType::INT8, place);
if (token_nums_this_rank > 0) {
@@ -99,7 +99,11 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
block_num,
ep_size,
ep_rank,
token_nums_this_rank);
token_nums_this_rank,
std::is_same<TX, int8_t>::value
? input_scales.get_ptr()->data<float>()
: nullptr,
expand_input_scales.data<float>());
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
}
}
@@ -138,10 +142,12 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
} else if (input_dtype == paddle::DataType::BFLOAT16 &&
quant_method != "w4a8") {
APPLY_KERNEL(paddle::bfloat16, paddle::bfloat16);
} else if (input_dtype == paddle::DataType::INT8) {
APPLY_KERNEL(int8_t, int8_t);
} else {
PD_THROW("EPMoeExpertDispatch not support input_dtype=",
static_cast<int>(input_dtype),
"quant_method=",
", quant_method=",
quant_method);
return {};
}

View File

@@ -28,7 +28,6 @@
#endif
XPU_DECLARE_BOOL(MOE_FFN_USE_DENSE_INPUT, false);
XPU_DECLARE_BOOL(BKCL_DISPATCH_ALL_GATHER, false);
namespace xftblock = baidu::xpu::xftblock;
namespace api = baidu::xpu::api;
@@ -36,6 +35,7 @@ namespace api = baidu::xpu::api;
template <typename TX1, typename TX2, typename TW, typename TGEMM>
void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
xftblock::Tensor* token_num_info,
xftblock::Tensor* token_num_lod,
xftblock::Tensor* ffn1_weight,
xftblock::Tensor* ffn2_weight,
xftblock::Tensor* ffn1_bias,
@@ -44,7 +44,8 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
float* ffn2_act_scale,
TX2* ffn2_shift,
TX2* ffn2_smooth,
const int hadamard_blocksize) {
const int hadamard_blocksize,
const int64_t group_size) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
@@ -68,11 +69,11 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
ffn1_weight,
&ffn1_out,
ffn1_bias,
is_padding_input ? nullptr : token_num_info,
token_num_lod,
is_padding_input ? token_num_info : nullptr,
expert_num,
1, // moe_topk
0, // group_size
group_size,
ffn1_out_shape.size() == 2 ? xftblock::MoeFCInputMode::DENSE
: xftblock::MoeFCInputMode::SPARSE);
PD_CHECK(ret == 0);
@@ -81,13 +82,25 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
auto swiglu_out_shape = ffn1_out_shape;
swiglu_out_shape[swiglu_out_shape.size() - 1] /= 2;
xftblock::Tensor swiglu_out(rt_guard, xftblock_tx2, swiglu_out_shape);
ret = api::fast_swiglu<TX2>(xpu_ctx->x_context(),
ffn1_out.data<TX2>(),
swiglu_out.mutable_data<TX2>(),
{token_num, inter_dim},
1,
true);
PD_CHECK(ret == 0);
if (is_padding_input) {
ret = infer_ops::swiglu_unt2<TX2>(xpu_ctx->x_context(),
ffn1_out.data<TX2>(),
swiglu_out.mutable_data<TX2>(),
token_num,
inter_dim,
expert_num,
token_num_info->data<int>(),
true);
PD_CHECK(ret == 0);
} else {
ret = api::fast_swiglu<TX2>(xpu_ctx->x_context(),
ffn1_out.data<TX2>(),
swiglu_out.mutable_data<TX2>(),
{token_num, inter_dim},
1,
true);
PD_CHECK(ret == 0);
}
// TODO(mayang02): use fusion_smooth_transform
if (ffn2_shift != nullptr) {
ret = api::broadcast_add<TX2>(xpu_ctx->x_context(),
@@ -109,14 +122,17 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
}
if (hadamard_blocksize > 0) {
ret = infer_ops::fast_walsh_transform<TX2>(xpu_ctx->x_context(),
swiglu_out.data<TX2>(),
nullptr,
nullptr,
swiglu_out.mutable_data<TX2>(),
hadamard_blocksize,
token_num,
outer_dim);
ret = infer_ops::fast_walsh_transform<TX2>(
xpu_ctx->x_context(),
swiglu_out.data<TX2>(),
nullptr,
nullptr,
swiglu_out.mutable_data<TX2>(),
hadamard_blocksize,
token_num,
outer_dim,
is_padding_input ? token_num / expert_num : 0,
is_padding_input ? token_num_lod->data<int>() : nullptr);
PD_CHECK(ret == 0);
}
@@ -131,11 +147,11 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
ffn2_weight,
ffn2_out,
nullptr,
is_padding_input ? nullptr : token_num_info,
token_num_lod,
is_padding_input ? token_num_info : nullptr,
expert_num,
1, // moe_topk
0, // group_size
group_size,
ffn1_out_shape.size() == 2
? xftblock::MoeFCInputMode::DENSE
: xftblock::MoeFCInputMode::SPARSE); // bias_mode
@@ -143,23 +159,25 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
}
static void convert_to_lod(xftblock::XFTContext* xctx,
xftblock::Tensor* token_num_info) {
auto rt_guard = xctx->get_rt_guard();
auto ctx = xctx->get_context();
const int expert_num = token_num_info->numel();
xftblock::Tensor tokens_num_lod(
rt_guard, xftblock::DataType::DT_INT32, {expert_num + 1});
int ret = api::constant(ctx, tokens_num_lod.data<int>(), expert_num + 1, 0);
PD_CHECK(ret == 0);
ret = api::cumsum<int>(ctx,
token_num_info->data<int>(),
tokens_num_lod.data<int>() + 1,
{expert_num},
false,
false,
0);
PD_CHECK(ret == 0);
*token_num_info = std::move(tokens_num_lod);
xftblock::Tensor* token_num_info,
xftblock::Tensor* token_num_lod,
int expert_num) {
if (expert_num == token_num_info->numel()) {
auto rt_guard = xctx->get_rt_guard();
auto ctx = xctx->get_context();
int ret = api::constant(ctx, token_num_lod->data<int>(), expert_num + 1, 0);
PD_CHECK(ret == 0);
ret = api::cumsum<int>(ctx,
token_num_info->data<int>(),
token_num_lod->data<int>() + 1,
{expert_num},
false,
false,
0);
PD_CHECK(ret == 0);
} else {
*token_num_lod = std::move(*token_num_info);
}
}
template <typename TX1, typename TX2, typename TW>
@@ -196,6 +214,12 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
int inter_dim = ffn1_w_shape[1];
int outer_dim = inter_dim / 2;
bool is_padding_input = input_shape.size() == 3;
int64_t group_size = 0;
if (ffn1_weight_scale.get_ptr() &&
ffn1_weight_scale.get_ptr()->numel() > expert_num * inter_dim) {
group_size = hidden_dim / (ffn1_weight_scale.get_ptr()->numel() /
expert_num / inter_dim);
}
if (is_padding_input) {
PD_CHECK(input_shape[0] == expert_num);
PD_CHECK(token_num_info.numel() == expert_num,
@@ -206,7 +230,9 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
expert_num);
}
bool is_w4 = quant_method == "w4a8" || quant_method == "weight_only_int4";
bool is_w4 = quant_method == "w_channelwise_int4_a_tokenwise_int15" ||
quant_method == "w_channelwise_int4_a_expertwise_int8" ||
quant_method == "w_channelwise_int4_a_tokenwise_int8";
auto xftblock_tx1 = xftblock::DataTypeToEnum<XPU_TX1>::value;
auto xftblock_tx2 = xftblock::DataTypeToEnum<XPU_TX2>::value;
auto xftblock_tw = xftblock::DataTypeToEnum<XPU_TW>::value;
@@ -256,6 +282,8 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
xftblock::Tensor xtoken_num_info(const_cast<int*>(token_num_info.data<int>()),
xftblock::DataType::DT_INT32,
token_num_info.shape());
xftblock::Tensor xtoken_num_lod(
rt_guard, xftblock::DataType::DT_INT32, {expert_num + 1});
XPU_TX2* shift_data = nullptr;
XPU_TX2* smooth_data = nullptr;
if (ffn2_shift.get_ptr()) {
@@ -272,9 +300,10 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
xftblock::Tensor xffn2_out;
paddle::Tensor ffn1_in_dense;
paddle::Tensor ffn1_in_scale_per_token;
convert_to_lod(&xctx, &xtoken_num_info, &xtoken_num_lod, expert_num);
if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
convert_to_lod(&xctx, &xtoken_num_info);
if (quant_method == "w4a8") {
if (quant_method == "w_channelwise_int4_a_expertwise_int8" ||
quant_method == "w_channelwise_int4_a_tokenwise_int8") {
ffn1_in_scale_per_token = paddle::empty(
{valid_token_num}, paddle::DataType::FLOAT32, ffn_in.place());
ffn1_in_dense = paddle::empty({valid_token_num, hidden_dim},
@@ -292,7 +321,7 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
xpu_ctx->x_context(),
ffn1_act_scale_data,
ffn1_in_scale_per_token.data<float>(),
xtoken_num_info.data<int>(),
xtoken_num_lod.data<int>(),
expert_num,
input_shape[1],
1,
@@ -302,7 +331,7 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
xpu_ctx->x_context(),
reinterpret_cast<const int8_t*>(ffn_in.data<int8_t>()),
reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
xtoken_num_info.data<int>(),
xtoken_num_lod.data<int>(),
expert_num,
input_shape[1],
input_shape[2],
@@ -313,7 +342,7 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
ffn1_act_scale_data,
xtoken_num_info.data<int>(),
xtoken_num_lod.data<int>(),
reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
ffn1_in_scale_per_token.data<float>(),
expert_num,
@@ -324,6 +353,64 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
input_shape[1]);
PD_CHECK(ret == 0);
}
} else if (quant_method == "w_channelwise_int8_a_expertwise_int8" ||
quant_method == "w_channelwise_int8_a_tokenwise_int8") {
ffn1_in_scale_per_token = paddle::empty(
{valid_token_num}, paddle::DataType::FLOAT32, ffn_in.place());
ffn1_in_dense = paddle::empty({valid_token_num, hidden_dim},
paddle::DataType::INT8,
ffn_in.place());
xffn1_in = xftblock::Tensor(ffn1_in_dense.data<int8_t>(),
nullptr,
ffn1_in_scale_per_token.data<float>(),
xftblock::DataType::DT_INT8,
{valid_token_num, hidden_dim});
if (std::is_same<XPU_TX1, int8_t>::value) {
PD_CHECK(ffn1_act_scale_data != nullptr,
"need ffn1_act_scale for x int8 per expert input");
ret = infer_ops::sequence_unpad<float, int>(
xpu_ctx->x_context(),
ffn1_act_scale_data,
ffn1_in_scale_per_token.data<float>(),
xtoken_num_lod.data<int>(),
expert_num,
input_shape[1],
1,
true);
PD_CHECK(ret == 0);
ret = infer_ops::sequence_unpad<int8_t, int>(
xpu_ctx->x_context(),
reinterpret_cast<const int8_t*>(ffn_in.data<int8_t>()),
reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
xtoken_num_lod.data<int>(),
expert_num,
input_shape[1],
input_shape[2],
true);
PD_CHECK(ret == 0);
} else {
auto ffn1_in_unpad = paddle::empty(
{valid_token_num, hidden_dim}, ffn_in.dtype(), ffn_in.place());
ret = infer_ops::sequence_unpad<XPU_TX1, int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
reinterpret_cast<XPU_TX1*>(ffn1_in_unpad.data<TX1>()),
xtoken_num_lod.data<int>(),
expert_num,
input_shape[1],
input_shape[2],
true);
PD_CHECK(ret == 0);
ret = infer_ops::quant2d_per_token<XPU_TX1, float, int8_t>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX1*>(ffn1_in_unpad.data<TX1>()),
nullptr,
reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
ffn1_in_scale_per_token.data<float>(),
valid_token_num,
hidden_dim);
PD_CHECK(ret == api::SUCCESS);
}
} else {
ffn1_in_dense = paddle::empty(
{valid_token_num, hidden_dim}, ffn_in.dtype(), ffn_in.place());
@@ -336,7 +423,7 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
reinterpret_cast<XPU_TX1*>(xffn1_in.data<XPU_TX1>()),
xtoken_num_info.data<int>(),
xtoken_num_lod.data<int>(),
expert_num,
input_shape[1],
input_shape[2],
@@ -345,31 +432,6 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
}
xffn2_out =
xftblock::Tensor(rt_guard, xftblock_tx2, {valid_token_num, hidden_dim});
} else if (FLAGS_BKCL_DISPATCH_ALL_GATHER && !is_padding_input &&
quant_method == "w4a8") {
convert_to_lod(&xctx, &xtoken_num_info);
ffn1_in_scale_per_token = paddle::empty(
{valid_token_num}, paddle::DataType::FLOAT32, ffn_in.place());
ffn1_in_dense = paddle::empty(
{valid_token_num, hidden_dim}, paddle::DataType::INT8, ffn_in.place());
xffn1_in = xftblock::Tensor(ffn1_in_dense.data<int8_t>(),
nullptr,
ffn1_in_scale_per_token.data<float>(),
xftblock::DataType::DT_INT8,
{valid_token_num, hidden_dim});
ret = infer_ops::quant2d_per_expert<XPU_TX1>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
ffn1_act_scale_data,
xtoken_num_info.data<int>(),
reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
ffn1_in_scale_per_token.data<float>(),
expert_num,
valid_token_num,
hidden_dim);
PD_CHECK(ret == 0);
xffn2_out =
xftblock::Tensor(ffn2_out.data<TX2>(), xftblock_tx2, input_shape);
} else {
xffn1_in = xftblock::Tensor(const_cast<TX1*>(ffn_in.data<TX1>()),
nullptr,
@@ -383,6 +445,7 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
#define FFN_IMPL(TX1, TX2, TW, TGEMM) \
MoeExpertFFNImpl<TX1, TX2, TW, TGEMM>(&xffn1_in, \
&xtoken_num_info, \
&xtoken_num_lod, \
&xffn1_w, \
&xffn2_w, \
xffn1_bias.get(), \
@@ -391,31 +454,32 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
ffn2_act_scale_data, \
shift_data, \
smooth_data, \
hadamard_blocksize)
if (quant_method == "weight_only_int8") {
static const char* xft_moe_fc_wint8_tgemm =
std::getenv("XFT_MOE_FC_WINT8_TGEMM");
if (xft_moe_fc_wint8_tgemm != nullptr) {
if (std::string(xft_moe_fc_wint8_tgemm) == "INT8") {
FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, int8_wo_t);
} else if (std::string(xft_moe_fc_wint8_tgemm) == "FLOAT16") {
FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, float16);
} else {
FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, float);
}
} else {
FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, float);
}
} else if (quant_method == "weight_only_int4") {
hadamard_blocksize, \
group_size)
if (quant_method == "w_channelwise_int8_a_float32") {
FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, float);
} else if (quant_method == "w_channelwise_int8_a_tokenwise_float16") {
FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, float16);
} else if (quant_method == "w_channelwise_int8_a_tokenwise_int15") {
FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, int8_wo_t);
} else if (quant_method == "w_channelwise_int4_a_tokenwise_int15") {
FFN_IMPL(XPU_TX1, XPU_TX2, int4_t, int4_wo_int15);
} else if (quant_method == "w4a8") {
} else if (quant_method == "w_channelwise_int4_a_expertwise_int8" ||
quant_method == "w_channelwise_int4_a_tokenwise_int8") {
// a8: per expert
if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
FFN_IMPL(int8_t, XPU_TX2, int4_t, int4_wo_int8);
} else if (FLAGS_BKCL_DISPATCH_ALL_GATHER && !is_padding_input) {
FFN_IMPL(int8_t, XPU_TX2, int4_t, int4_wo_int8);
} else {
FFN_IMPL(XPU_TX1, XPU_TX2, int4_t, int4_wo_int8);
}
} else if (quant_method == "w_channelwise_int8_a_expertwise_int8" ||
quant_method == "w_channelwise_int8_a_tokenwise_int8") {
// a8: per expert
if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
FFN_IMPL(int8_t, XPU_TX2, int8_t, int8_t);
} else {
FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, int8_t);
}
} else {
FFN_IMPL(XPU_TX1, XPU_TX2, XPU_TW, float);
}
@@ -425,7 +489,7 @@ std::vector<paddle::Tensor> MoeExpertFFNKernel(
xpu_ctx->x_context(),
const_cast<XPU_TX2*>(xffn2_out.data<XPU_TX2>()),
reinterpret_cast<XPU_TX2*>(ffn2_out.data<TX2>()),
xtoken_num_info.data<int>(),
xtoken_num_lod.data<int>(),
input_shape[0],
input_shape[1],
input_shape[2],

View File

@@ -143,6 +143,8 @@ std::vector<paddle::Tensor> WeightOnlyLinear(
const int arch,
const int group_size);
std::vector<paddle::Tensor> Quant2dPerToken(const paddle::Tensor& x);
std::vector<paddle::Tensor> MoeEPCombine(const paddle::Tensor& ffn_out,
const paddle::Tensor& moe_index,
const paddle::Tensor& weights,
@@ -1275,6 +1277,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("arch"),
py::arg("group_size") = -1);
m.def(
"quant2d_per_token", &Quant2dPerToken, py::arg("x"), "quant x per token");
m.def("xpu_moe_layer",
&MoeLayer,
py::arg("x"),

View File

@@ -0,0 +1,88 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <core/check.h>
#include <core/context.h>
#include <core/param.h>
#include <infer_ops.h>
#include <xft_api.h>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"
#include "utility/env.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
namespace xftblock = baidu::xpu::xftblock;
namespace api = baidu::xpu::api;
// Per-token (row-wise) int8 quantization of a 2-D tensor `x`.
// Returns {quant_x, x_scale}: quant_x is int8 with x's shape, x_scale holds
// one float32 scale per row. Rows == 0 is handled by skipping the kernel.
template <typename TX>
std::vector<paddle::Tensor> Quant2dPerTokenKernel(const paddle::Tensor& x) {
  using XPU_TX = typename XPUTypeTrait<TX>::Type;
  // Bind to the current XPU device context.
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto* base_ctx =
      paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto* xpu_ctx = static_cast<const phi::XPUContext*>(base_ctx);
  xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
  // NOTE(review): rt_guard appears to scope xftblock runtime resources —
  // kept alive for the duration of the call; confirm it is required here.
  auto rt_guard = xctx.get_rt_guard();

  const auto x_shape = x.shape();
  const int64_t token_num = x_shape[0];
  const int64_t hidden_dim = x_shape[1];
  // One int8 value per element, one float32 scale per token (row).
  auto quant_x =
      paddle::empty({token_num, hidden_dim}, paddle::DataType::INT8, x.place());
  auto x_scale =
      paddle::empty({token_num}, paddle::DataType::FLOAT32, x.place());
  if (token_num > 0) {
    int ret = infer_ops::quant2d_per_token<XPU_TX, float, int8_t>(
        xpu_ctx->x_context(),
        reinterpret_cast<const XPU_TX*>(x.data<TX>()),
        nullptr,  // no precomputed max — kernel derives the scale itself
        reinterpret_cast<int8_t*>(quant_x.data<int8_t>()),
        reinterpret_cast<float*>(x_scale.data<float>()),
        token_num,
        hidden_dim);
    PD_CHECK(ret == api::SUCCESS);
  }
  return {quant_x, x_scale};
}
// Dtype dispatcher for the quant2d_per_token op.
// Only bfloat16 and float16 inputs are supported; anything else throws.
std::vector<paddle::Tensor> Quant2dPerToken(const paddle::Tensor& x) {
  const auto x_type = x.dtype();
  switch (x_type) {
    case paddle::DataType::BFLOAT16:
      return Quant2dPerTokenKernel<paddle::bfloat16>(x);
    case paddle::DataType::FLOAT16:
      return Quant2dPerTokenKernel<paddle::float16>(x);
    default:
      PD_THROW("Quant2dPerToken not support x_type=", static_cast<int>(x_type));
      return {};
  }
}
// Infer output shapes for quant2d_per_token.
//
// The op registers TWO outputs ({"quant_x", "x_scale"}), so two shapes must
// be returned — the original returned only one, which mismatches the output
// count at static-graph build time. quant_x keeps x's shape; x_scale holds
// one float32 scale per token, i.e. one per row of the 2-D input
// (the kernel requires x to be 2-D).
std::vector<std::vector<int64_t>> Quant2dPerTokenInferShape(
    const std::vector<int64_t>& x_shape) {
  return {x_shape, {x_shape[0]}};
}
// Infer output dtypes for quant2d_per_token.
//
// Must mirror the op's two registered outputs — the original returned only
// one dtype. quant_x is INT8 and x_scale is FLOAT32, matching the tensors
// allocated in Quant2dPerTokenKernel.
std::vector<paddle::DataType> Quant2dPerTokenInferDtype(
    const paddle::DataType& x_dtype) {
  return {paddle::DataType::INT8, paddle::DataType::FLOAT32};
}
// Register the custom op: takes fp16/bf16 `x`, produces int8 `quant_x`
// plus a per-token float32 `x_scale` (two outputs).
PD_BUILD_STATIC_OP(quant2d_per_token)
.Inputs({"x"})
.Outputs({"quant_x", "x_scale"})
.SetKernelFn(PD_KERNEL(Quant2dPerToken))
.SetInferShapeFn(PD_INFER_SHAPE(Quant2dPerTokenInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(Quant2dPerTokenInferDtype));

View File

@@ -157,6 +157,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS": lambda: int(os.getenv("FD_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", "500")),
"FD_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE": lambda: int(os.getenv("FD_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", "64")),
"FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT": lambda: int(os.getenv("FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT", "120")),
"FD_XPU_MOE_FFN_QUANT_TYPE_MAP": lambda: os.getenv("FD_XPU_MOE_FFN_QUANT_TYPE_MAP", ""),
}

View File

@@ -352,9 +352,9 @@ class XPUEPPrefillRunner(XPUEPRunner):
**kwargs,
):
self.num_combined_tokens = x.shape[0]
x_scale_tensor = kwargs.get("x_scale_tensor", None)
x_scale = kwargs.get("x_scale", None)
dispatch_args = {
"x": (x, x_scale_tensor) if x_scale_tensor is not None else x,
"x": (x, x_scale) if x_scale is not None else x,
"topk_idx": topk_idx,
"topk_weights": topk_weights,
}
@@ -428,11 +428,27 @@ class XPUEPDecoderRunner(XPUEPRunner):
dispatch_hook,
valid_token_num,
) = self.ep_engine.low_latency_dispatch(x, topk_idx, expertwise_scale, use_fp8)
# no need to call dispatch_hook here, because it has already been done in xDeepEP
# if dispatch_hook is not None:
# dispatch_hook()
# valid_token_num is optional:
# - if valid_token_num is None, it means that we CANNOT accurately know
# the size of the tensor, but the advantage is that it can reduce
# the overhead of kernel launch.
# - if valid_token_num is NOT None, it means that we CAN accurately know
# the size of the tensor, but the disadvantage is that it will interrupt
# the process of kernel launch.
if valid_token_num is None and dispatch_hook is not None:
dispatch_hook()
return recv_hidden_states, recv_expert_count, handle, valid_token_num
if valid_token_num is None:
valid_token_num = -1
if isinstance(recv_hidden_states, tuple):
recv_x = recv_hidden_states[0]
recv_x_scale = recv_hidden_states[1]
else:
recv_x = recv_hidden_states
recv_x_scale = None
return recv_x, recv_x_scale, recv_expert_count, handle, valid_token_num
def combine(self, ffn_out, topk_idx, topk_weights, handle):
combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(

View File

@@ -14,11 +14,14 @@
# limitations under the License.
"""
import os
from typing import Callable
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig
from fastdeploy.model_executor.layers.utils import get_tensor
@@ -27,6 +30,7 @@ from fastdeploy.model_executor.ops.xpu import (
ep_moe_expert_dispatch,
moe_expert_ffn,
moe_topk_select,
quant2d_per_token,
weight_quantize_xpu,
xpu_moe_layer,
)
@@ -46,8 +50,10 @@ class XPUMoEMethod(MoEMethodBase):
def __init__(
self,
quant_config: WeightOnlyConfig,
layer,
) -> None:
super().__init__(quant_config)
self.layer_idx = getattr(layer, "layer_idx", -1)
if self.moe_quant_type in ["w16a16"]:
self.weight_dtype = "bfloat16"
@@ -57,6 +63,68 @@ class XPUMoEMethod(MoEMethodBase):
raise ValueError(f"Unsupported moe quant type: {self.moe_quant_type}")
self.scale_dtype = "float32"
self.bias_dtype = "float32"
self._set_xpu_moe_quant_type()
def _set_xpu_moe_quant_type(self):
"""
XPU_MOE_FFN_QUANT_TYPE_MAP options:
- default:
- w16a16 -> w_bfloat16_a_bfloat16
- weight_only_int8 -> w_channelwise_int8_a_tokenwise_float16
- weight_only_int4 -> w_channelwise_int4_a_tokenwise_int15
- w4a8 -> w_channelwise_int4_a_expertwise_int8
- w_bfloat16_a_bfloat16
- w_channelwise_int8_a_float32
- w_channelwise_int8_a_tokenwise_float16
- w_channelwise_int8_a_tokenwise_int15
- w_channelwise_int8_a_expertwise_int8
- w_channelwise_int8_a_tokenwise_int8
- w_channelwise_int4_a_tokenwise_int15
- w_channelwise_int4_a_expertwise_int8
- w_channelwise_int4_a_tokenwise_int8
- TODO:
- w_groupwise_int4_a_expertwise_int8
- w_groupwise_int4_a_tokenwise_int8
for example: XPU_MOE_FFN_QUANT_TYPE_MAP="w_channelwise_int8_a_tokenwise_float16:3->5,7->9;w_channelwise_int8_a_tokenwise_int8:6,10->20"
"""
if self.layer_idx < 0:
return
xpu_moe_ffn_quant_type_map = envs.FD_XPU_MOE_FFN_QUANT_TYPE_MAP
self.xpu_moe_quant_type = "default"
for quant_type_map in xpu_moe_ffn_quant_type_map.split(";"):
quant_type_info = quant_type_map.split(":")
if len(quant_type_info) != 2:
continue
for ids_info in quant_type_info[1].split(","):
ids = ids_info.split("->")
id_min = int(ids[0])
id_max = int(ids[-1])
if id_min <= self.layer_idx <= id_max:
self.xpu_moe_quant_type = quant_type_info[0]
if self.xpu_moe_quant_type == "default":
default_quant_type_map = {
"w16a16": "w_bfloat16_a_bfloat16",
"weight_only_int8": "w_channelwise_int8_a_tokenwise_float16",
"weight_only_int4": "w_channelwise_int4_a_tokenwise_int15",
"w4a8": "w_channelwise_int4_a_expertwise_int8",
}
assert (
self.moe_quant_type in default_quant_type_map.keys()
), f"Unsupported moe quant type: {self.moe_quant_type}"
self.xpu_moe_quant_type = default_quant_type_map[self.moe_quant_type]
# TODO(zhupengyang): remove XFT_MOE_FC_WINT8_TGEMM later
if self.moe_quant_type == "weight_only_int8":
xft_moe_fc_wint8_tgemm = os.environ.get("XFT_MOE_FC_WINT8_TGEMM", "")
if xft_moe_fc_wint8_tgemm == "FLOAT16":
self.xpu_moe_quant_type = "w_channelwise_int8_a_tokenwise_float16"
elif xft_moe_fc_wint8_tgemm == "INT8":
self.xpu_moe_quant_type = "w_channelwise_int8_a_tokenwise_int8"
else:
assert xft_moe_fc_wint8_tgemm == "", f"Unsupported XFT_MOE_FC_WINT8_TGEMM={xft_moe_fc_wint8_tgemm}"
logger.info(f"moe_layer_idx: {self.layer_idx}; xpu_moe_quant_type: {self.xpu_moe_quant_type}")
def import_backend_ep_runner(self) -> None:
from .ep import XPUEPDecoderRunner, XPUEPPrefillRunner
@@ -283,7 +351,7 @@ class XPUMoEMethod(MoEMethodBase):
permute_indices_per_token,
token_num_lod,
dst_weights,
ffn1_act_scale_per_token,
ffn1_x_scale_per_token,
) = ep_moe_expert_dispatch(
x,
topk_idx,
@@ -294,14 +362,16 @@ class XPUMoEMethod(MoEMethodBase):
self.moe_quant_type,
)
if not hasattr(layer, self.added_in_scale_attrs[0]):
ffn1_act_scale_per_token = None
if hasattr(layer, self.added_in_scale_attrs[0]):
ffn1_x_scale = ffn1_x_scale_per_token
else:
ffn1_x_scale = None
ffn_out = self.compute_ffn(
layer,
permute_input,
ffn1_x_scale,
token_num_lod,
x.shape[0] * layer.top_k,
ffn1_act_scale_per_token,
)
topk_weights_bf16 = topk_weights.astype("bfloat16")
@@ -337,10 +407,10 @@ class XPUMoEMethod(MoEMethodBase):
def compute_ffn(
self,
layer: nn.Layer,
permute_input,
ffn1_x,
ffn1_x_scale,
token_num_lod,
valid_token_num,
ffn1_act_scale_per_token=None,
):
"""
Calculate moe
@@ -350,19 +420,19 @@ class XPUMoEMethod(MoEMethodBase):
else:
hadamard_block_size = -1
ffn_out = moe_expert_ffn(
permute_input,
ffn1_x,
token_num_lod,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
None,
None,
ffn1_act_scale_per_token,
ffn1_x_scale,
getattr(layer, self.added_in_scale_attrs[1], None),
getattr(layer, self.added_scale_attrs[0], None),
getattr(layer, self.added_scale_attrs[1], None),
None,
None,
self.moe_quant_type,
self.xpu_moe_quant_type,
hadamard_block_size,
valid_token_num,
)
@@ -381,9 +451,12 @@ class XPUMoEMethod(MoEMethodBase):
gate_out = gate(x.cast("float32"))
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
# 2. Dynamic compute blockwise quantization scales
# x, x_scale_tensor = fastdeploy.model_executor.ops.xpu.per_token_quant(x)
x_scale_tensor = None
if "a_tokenwise_int8" in self.xpu_moe_quant_type:
x, x_scale = quant2d_per_token(x)
else:
x_scale = None
# 3. EP Dispatch
(
recv_x,
@@ -396,20 +469,24 @@ class XPUMoEMethod(MoEMethodBase):
x,
topk_idx,
topk_weights,
x_scale_tensor=x_scale_tensor,
x_scale=x_scale,
)
# 4. Compute ffn
token_num_per_expert = recv_num_tokens_per_expert_list.numpy().tolist()
token_all_num = sum(token_num_per_expert)
# 4. Compute ffn
moe_dispatch_scale = None
if "a_expertwise_int8" in self.xpu_moe_quant_type:
moe_dispatch_scale = getattr(layer, self.added_in_scale_attrs[0])
elif "a_tokenwise_int8" in self.xpu_moe_quant_type:
moe_dispatch_scale = recv_x_scales
else:
moe_dispatch_scale = None
(
permute_input,
permute_indices_per_token,
token_num_lod,
dst_weights,
ffn1_act_scale_per_token,
ffn1_x_scale_per_token,
) = ep_moe_expert_dispatch(
recv_x,
recv_topk_idx,
@@ -420,14 +497,18 @@ class XPUMoEMethod(MoEMethodBase):
self.moe_quant_type,
)
if "a_expertwise_int8" in self.xpu_moe_quant_type or "a_tokenwise_int8" in self.xpu_moe_quant_type:
ffn1_x_scale = ffn1_x_scale_per_token
else:
ffn1_x_scale = None
ffn_out = self.compute_ffn(
layer,
permute_input,
ffn1_x_scale,
token_num_lod,
token_all_num,
)
# prmt back per rank
recv_topk_weights_bf16 = recv_topk_weights.astype("bfloat16")
tmp_ffn_out = ep_moe_expert_combine(
ffn_out,
@@ -459,10 +540,15 @@ class XPUMoEMethod(MoEMethodBase):
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
# 2. EP Dispatch
expertwise_scale = None
use_fp8 = False
if "a_tokenwise_int8" in self.xpu_moe_quant_type:
use_fp8 = True
expertwise_scale = None
else:
use_fp8 = False
expertwise_scale = None
(
permute_input,
recv_x,
recv_x_scale,
token_nums_per_expert,
handle,
valid_token_num,
@@ -470,14 +556,15 @@ class XPUMoEMethod(MoEMethodBase):
x,
topk_idx,
topk_weights,
expertwise_scale=expertwise_scale,
use_fp8=use_fp8,
expertwise_scale=expertwise_scale,
)
# 3. Compute ffn
ffn_out = self.compute_ffn(
layer,
permute_input,
recv_x,
recv_x_scale,
token_nums_per_expert,
valid_token_num,
)

View File

@@ -836,6 +836,7 @@ class RowParallelLinear(LinearBase):
self.tp_group = fd_config.parallel_config.tp_group
self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.layer_id = layer_id
self.split_token = (
fd_config.parallel_config.use_sequence_parallel_moe
and layer_id >= fd_config.model_config.moe_layer_start_index

View File

@@ -42,7 +42,7 @@ except:
import numpy as np
def get_moe_method():
def get_moe_method(layer=None):
"""
return moe method based on device platform
"""
@@ -54,7 +54,7 @@ def get_moe_method():
elif current_platform.is_xpu():
from fastdeploy.model_executor.layers.backends import XPUMoEMethod
return XPUMoEMethod(None)
return XPUMoEMethod(None, layer)
elif current_platform.is_gcu():
from fastdeploy.model_executor.layers.backends import GCUFusedMoeMethod
@@ -223,7 +223,7 @@ class FusedMoE(nn.Layer):
self.moe_quant_type = moe_quant_config.name()
else:
# unquantized quant_method
self.quant_method = get_moe_method()
self.quant_method = get_moe_method(self)
assert self.quant_method is not None, "self.quant_method should not be None"
self.redundant_table_manger = redundant_table_manger
self.is_rearrange = False

View File

@@ -52,6 +52,6 @@ class W4A8Config(QuantConfigBase):
XPUW4A8MoEMethod,
)
return XPUW4A8MoEMethod(self)
return XPUW4A8MoEMethod(self, layer)
else:
raise ValueError(f"Unsupported layer type {type(layer)} for w4a8")

View File

@@ -101,7 +101,7 @@ class WeightOnlyConfig(QuantConfigBase):
XPUWeightOnlyMoEMethod,
)
return XPUWeightOnlyMoEMethod(self)
return XPUWeightOnlyMoEMethod(self, layer)
else:
from fastdeploy.model_executor.layers.backends import (
XPUWeightOnlyLinearMethod,

View File

@@ -632,6 +632,7 @@ export XSHMEM_MODE=1
export XSHMEM_QP_NUM_PER_RANK=32
export BKCL_RDMA_VERBS=1
export MOE_FFN_USE_DENSE_INPUT=1
export FD_XPU_MOE_FFN_QUANT_TYPE_MAP="w_channelwise_int8_a_tokenwise_int8:8->53"
export port_num=$((8188 + XPU_ID * 100))
# 启动服务
@@ -643,7 +644,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--data-parallel-size 1 \
--max-model-len 32768 \
--max-num-seqs 64 \
--quantization "wint4" \
--quantization "wint8" \
--engine-worker-queue-port $((port_num + 10)) \
--metrics-port $((port_num + 2)) \
--cache-queue-port $((port_num + 47873)) \
@@ -692,6 +693,7 @@ unset XSHMEM_MODE
unset XSHMEM_QP_NUM_PER_RANK
unset BKCL_RDMA_VERBS
unset MOE_FFN_USE_DENSE_INPUT
unset FD_XPU_MOE_FFN_QUANT_TYPE_MAP
stop_processes >kill.log 2>&1
if [ ${ep_online_exit_code} -ne 0 ]; then