【New Feature】Centralized support for w4afp8 (#3644)

* Support w4afp8 under tensor parallelism (tp)

* Code style
yangjianfengo1 authored on 2025-08-28 10:53:24 +08:00 (committed by GitHub)
parent 76513f6416
commit e81046fdad
8 changed files with 41 additions and 22 deletions
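
Scope of the change, as read from the diff below: the existing w4a8 path quantizes the dispatched MoE activations to int8, and this PR adds a parallel w4afp8 path that quantizes them to float8_e4m3fn instead. A new moe_quant_type string attribute is threaded through the moe_expert_dispatch op so the permuted input tensor can be allocated and quantized as either int8 or FP8, and the grouped-GEMM plumbing is adjusted for centralized (non-expert-parallel) token counts.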


@@ -151,7 +151,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
     const paddle::Tensor &input, const paddle::Tensor &gating_output,
     const paddle::optional<paddle::Tensor> &gating_correction_bias,
     const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
-    const bool group_moe, const bool topk_only_mode);
+    const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode);
 std::vector<paddle::Tensor>
 MoETopKSelectKernel(const paddle::Tensor &gating_logits,
@@ -912,7 +912,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("moe_expert_dispatch", &MoeExpertDispatch, py::arg("input"),
         py::arg("gating_output"), py::arg("gating_correction_bias"),
         py::arg("w4a8_in_scale"), py::arg("moe_topk"), py::arg("group_moe"),
-        py::arg("topk_only_mode"), "moe export dispatch function");
+        py::arg("moe_quant_type"), py::arg("topk_only_mode"), "moe export dispatch function");
   /**
    * moe/fused_moe/ep_moe_prefill_func.cu
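
These two hunks widen the C++ declaration and its Python binding with the new moe_quant_type string. For readers unfamiliar with the binding layer, here is a minimal, self-contained sketch of the same pybind11 pattern; the module and function names are invented for the sketch, not FastDeploy's:

#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

int dispatch_demo(int moe_topk, bool group_moe,
                  const std::string &moe_quant_type, bool topk_only_mode) {
  // Pretend dispatch: just show the quant type influencing behavior.
  if (moe_quant_type == "w4afp8") return moe_topk * 2;
  return moe_topk;
}

PYBIND11_MODULE(demo_ops, m) {
  // The new py::arg must sit in the same position as the C++ parameter,
  // before topk_only_mode, exactly as the diff above does.
  m.def("dispatch_demo", &dispatch_demo, py::arg("moe_topk"),
        py::arg("group_moe"), py::arg("moe_quant_type"),
        py::arg("topk_only_mode"), "demo dispatch function");
}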


@@ -1296,6 +1296,18 @@ __global__ void initialize_moe_routing_kernel(
         dest_vec[j] = static_cast<int8_t>(round(quant_value));
       }
       Store<OutT, VecSize>(dest_vec, &dest_row_ptr[tid]);
+    } else if constexpr (std::is_same<OutT, phi::dtype::float8_e4m3fn>::value) {
+      using StoreT = AlignedVector<OutT, VecSize>;
+      StoreT dest_vec;
+      const float max_bound = 448.f;
+      const float min_bound = -448.f;
+      for (int j = 0; j < VecSize; j++) {
+        float quant_value = max_bound * scale * static_cast<float>(src_vec[j]);
+        quant_value = quant_value > max_bound ? max_bound : quant_value;
+        quant_value = quant_value < min_bound ? min_bound : quant_value;
+        dest_vec[j] = static_cast<phi::dtype::float8_e4m3fn>(quant_value);
+      }
+      Store<phi::dtype::float8_e4m3fn, VecSize>(dest_vec, &dest_row_ptr[tid]);
     } else {
       Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
     }
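
448 is the largest finite magnitude representable in FP8 e4m3, so the new branch scales and then saturates before the narrowing cast. A host-side C++ sketch of the same arithmetic, with float standing in for phi::dtype::float8_e4m3fn (only the scale-and-clamp is shown, not the actual 8-bit encoding):

#include <algorithm>
#include <cstdio>

// Scale-then-saturate quantization as in the kernel above: multiply by
// max_bound * scale, then clamp to the finite e4m3 range.
float quantize_e4m3_sketch(float x, float scale) {
  const float max_bound = 448.f;   // largest finite e4m3 value
  const float min_bound = -448.f;
  float q = max_bound * scale * x;
  return std::min(std::max(q, min_bound), max_bound);
}

int main() {
  // With scale = 1/448, inputs in [-1, 1] map onto the full e4m3 range.
  std::printf("%f\n", quantize_e4m3_sketch(0.5f, 1.f / 448.f));  // 0.5
  std::printf("%f\n", quantize_e4m3_sketch(3.0f, 1.f));          // saturates to 448
}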


@@ -113,11 +113,20 @@ void MoeDispatchKernel(
       permuted_rows_, moe_topk * num_rows, false, stream);
   if (w4a8_in_scale) {
-    initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
+    if (permute_input->dtype() == paddle::DataType::INT8) {
+      initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
         input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
         expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
         permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
         hidden_size, moe_topk, stream);
+    } else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
+      initialize_moe_routing_kernelLauncher<data_t, float8_e4m3fn>::run(
+          input.data<data_t>(), permute_input->data<float8_e4m3fn>(),
+          permuted_rows_, expert_idx_per_token->data<int32_t>(),
+          w4a8_in_scale->data<float>(),
+          permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
+          hidden_size, moe_topk, stream);
+    }
   } else {
     initialize_moe_routing_kernelLauncher<data_t>::run(
         input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
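
The dispatch kernel now selects the quantizing launcher from the dtype the output tensor was allocated with, instead of hard-coding int8. The shape of that runtime branch in a self-contained sketch; the enum and launcher are stand-ins, and unlike the original (which falls through silently) the sketch rejects unexpected dtypes:

#include <cstdint>
#include <stdexcept>

// Stand-ins for paddle::DataType and the templated launcher above.
enum class DType { INT8, FLOAT8_E4M3FN };

template <typename OutT>
void launch_routing_quant() { /* would launch the quantizing kernel */ }

// Select the launcher from the dtype permute_input was allocated with.
void run_routing(DType out_dtype) {
  if (out_dtype == DType::INT8) {
    launch_routing_quant<std::int8_t>();
  } else if (out_dtype == DType::FLOAT8_E4M3FN) {
    launch_routing_quant<std::uint8_t>();  // byte-sized stand-in for fp8 storage
  } else {
    throw std::runtime_error("unsupported moe_quant_type dtype");
  }
}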
@@ -135,7 +144,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
     const paddle::Tensor &input, const paddle::Tensor &gating_output,
     const paddle::optional<paddle::Tensor> &gating_correction_bias,
     const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
-    const bool group_moe, const bool topk_only_mode) {
+    const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode) {
   const auto input_type = input.dtype();
   auto place = input.place();
   int token_rows = 0;
@@ -151,8 +160,14 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
   const int num_rows = token_rows;
   const int hidden_size = input.dims()[input_dims.size() - 1];
-  auto permute_input_dtype =
-      w4a8_in_scale ? paddle::DataType::INT8 : input_type;
+  auto permute_input_dtype = input_type;
+  if (w4a8_in_scale) {
+    if (moe_quant_type == "w4a8") {
+      permute_input_dtype = paddle::DataType::INT8;
+    } else if (moe_quant_type == "w4afp8") {
+      permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
+    }
+  }
   auto permute_input = GetEmptyTensor({moe_topk * num_rows, hidden_size},
                                       permute_input_dtype, place);
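
Allocation-side counterpart of the branch in MoeDispatchKernel: the dtype of permute_input is now derived from the moe_quant_type string rather than fixed to int8 whenever a scale is present. The mapping, reduced to a free function (DType is a stand-in for paddle::DataType):

#include <string>

// Stand-in for paddle::DataType.
enum class DType { BFLOAT16, INT8, FLOAT8_E4M3FN };

// Mirror of the selection above: without an activation scale the input dtype
// is kept; with one, "w4a8" dispatches as int8 and "w4afp8" as FP8 e4m3. Any
// other string keeps the input dtype, matching the original's silent default.
DType permute_dtype(bool has_in_scale, const std::string &moe_quant_type,
                    DType input_type) {
  if (!has_in_scale) return input_type;
  if (moe_quant_type == "w4a8") return DType::INT8;
  if (moe_quant_type == "w4afp8") return DType::FLOAT8_E4M3FN;
  return input_type;
}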
@@ -285,7 +300,7 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch)
     .Outputs({"permute_input", "tokens_expert_prefix_sum",
               "permute_indices_per_token", "topk_weight", "topk_idx",
               "expert_idx_per_token"})
-    .Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
+    .Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"})
     .SetKernelFn(PD_KERNEL(MoeExpertDispatch))
     .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));


@@ -204,7 +204,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
             ->data<float>(),
         reinterpret_cast<NvType *>(fc1_out),
         used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
-        num_max_tokens_per_expert,
+        used_in_ep_low_latency ? num_max_tokens_per_expert : permute_input.dims()[0],
         num_experts,
         inter_size,
         hidden_size,
@@ -369,7 +369,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
             ->data<float>(),
         reinterpret_cast<NvType*>(ffn_out_data),
         used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
-        num_max_tokens_per_expert,
+        used_in_ep_low_latency ? num_max_tokens_per_expert : act_out_tensor.dims()[0],
         num_experts,
         hidden_size,
         inter_size / 2,
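
Both MoeFFNKernel call sites previously passed num_max_tokens_per_expert as the row count unconditionally; after this change the centralized path passes the live row count instead (permute_input.dims()[0] for the first GEMM, the activation tensor's rows for the second), since only expert-parallel low-latency mode has a fixed per-expert token budget. The selection, isolated as a sketch:

#include <cstdint>

// Row count handed to the grouped GEMM, as at both call sites above: the
// expert-parallel low-latency path uses a fixed per-expert capacity; the
// centralized path passes the actual number of permuted rows.
std::int64_t gemm_rows(bool used_in_ep_low_latency,
                       std::int64_t num_max_tokens_per_expert,
                       std::int64_t permuted_rows) {
  return used_in_ep_low_latency ? num_max_tokens_per_expert : permuted_rows;
}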


@@ -225,20 +225,9 @@ struct CollectiveMainloopFwd {
       const int actual_token,
       const int bidn) const {
-    auto g_offset = local_tile(
-        mB(_, _, 0),
-        cute::make_shape(1, size<1>(mB)),
-        make_coord(pre_fix_token, _0{}));
-    auto g_tensor = make_tensor(
-        g_offset.data(),
-        make_layout(
-            cute::make_shape(actual_token, size<2>(mB)),
-            g_offset.stride()
-        ));
+    auto g_tensor = domain_offset(make_coord(pre_fix_token, _0{}), mB(_, _, 0));
     Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
     return gB;
   }
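
The ten-line local_tile/make_tensor construction collapses into a single cute::domain_offset call, which shifts a tensor's origin by a coordinate while leaving the layout, and therefore the strides the subsequent local_tile relies on, untouched. A plain-pointer analogue on a hypothetical row-major matrix (types and names are illustrative, not CuTe's):

#include <cstddef>

// The returned view shares the buffer and strides; only its origin moves by
// (row0, col0). This is what the old local_tile/make_tensor pair rebuilt by hand.
struct View2D {
  const float *data;
  std::size_t rows, cols, row_stride;
};

View2D domain_offset_sketch(View2D t, std::size_t row0, std::size_t col0) {
  return {t.data + row0 * t.row_stride + col0,
          t.rows - row0, t.cols - col0, t.row_stride};
}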


@@ -222,7 +222,7 @@ void run_gemm(const InputType * A, const InputType * B, OutputType * C, const fl
         static_cast<Element const*>(A),
         get_gmem_layout<Batch>(M, K / 2),
         static_cast<Element const*>(B),
-        get_gmem_layout<Batch>(TokenPackSize == 0 ? max_tokens * Batch : TokenPackSize, K),
+        get_gmem_layout<Batch>(TokenPackSize == 0 ? max_tokens: TokenPackSize, K),
         static_cast<ElementOutput*>(C),
         get_gmem_layout<Batch>(M, TokenPackSize == 0 ? max_tokens : TokenPackSize),
         weight_scale,
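
This fix makes the B operand's row count consistent with the C operand two lines below: get_gmem_layout<Batch> is already batched over the experts, so when tokens are unpacked (TokenPackSize == 0) each batch entry should span max_tokens rows, not max_tokens * Batch. Reduced to the selection itself, as a sketch:

// Per-expert row extent of the activation operand, as fixed above, matching
// the C operand's layout. The parameter mirrors the TokenPackSize template
// argument of the original.
int rows_per_expert(int token_pack_size, int max_tokens) {
  return token_pack_size == 0 ? max_tokens : token_pack_size;
}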


@@ -276,6 +276,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
                 ),  # if set, permute_input will be int8_t
                 layer.top_k,
                 False,
+                self.moe_quant_type,
                 topk_only_mode=True,
             )
         else:

@@ -295,6 +296,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
                 ),  # if set, permute_input will be int8_t
                 layer.top_k,
                 False,
+                self.moe_quant_type,
                 topk_only_mode=False,
             )


@@ -284,6 +284,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
                 ),  # if set, permute_input will be int8_t
                 layer.top_k,
                 False,
+                self.moe_quant_type,
                 topk_only_mode=False,
             )
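
On the Python side, all three call sites pass self.moe_quant_type positionally, after the group_moe flag (False) and before topk_only_mode, mirroring the new parameter position in the C++ signature and the op's attribute list. Note that the adjacent "# if set, permute_input will be int8_t" comments predate this PR; with moe_quant_type == "w4afp8" the permuted input is float8_e4m3fn instead.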