mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
@@ -151,7 +151,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
|||||||
const paddle::Tensor &input, const paddle::Tensor &gating_output,
|
const paddle::Tensor &input, const paddle::Tensor &gating_output,
|
||||||
const paddle::optional<paddle::Tensor> &gating_correction_bias,
|
const paddle::optional<paddle::Tensor> &gating_correction_bias,
|
||||||
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
|
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
|
||||||
const bool group_moe, const bool topk_only_mode);
|
const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode);
|
||||||
|
|
||||||
std::vector<paddle::Tensor>
|
std::vector<paddle::Tensor>
|
||||||
MoETopKSelectKernel(const paddle::Tensor &gating_logits,
|
MoETopKSelectKernel(const paddle::Tensor &gating_logits,
|
||||||
@@ -912,7 +912,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
m.def("moe_expert_dispatch", &MoeExpertDispatch, py::arg("input"),
|
m.def("moe_expert_dispatch", &MoeExpertDispatch, py::arg("input"),
|
||||||
py::arg("gating_output"), py::arg("gating_correction_bias"),
|
py::arg("gating_output"), py::arg("gating_correction_bias"),
|
||||||
py::arg("w4a8_in_scale"), py::arg("moe_topk"), py::arg("group_moe"),
|
py::arg("w4a8_in_scale"), py::arg("moe_topk"), py::arg("group_moe"),
|
||||||
py::arg("topk_only_mode"), "moe export dispatch function");
|
py::arg("moe_quant_type"), py::arg("topk_only_mode"), "moe export dispatch function");
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* moe/fused_moe/ep_moe_prefill_func.cu
|
* moe/fused_moe/ep_moe_prefill_func.cu
|
||||||
|
@@ -1296,6 +1296,18 @@ __global__ void initialize_moe_routing_kernel(
|
|||||||
dest_vec[j] = static_cast<int8_t>(round(quant_value));
|
dest_vec[j] = static_cast<int8_t>(round(quant_value));
|
||||||
}
|
}
|
||||||
Store<OutT, VecSize>(dest_vec, &dest_row_ptr[tid]);
|
Store<OutT, VecSize>(dest_vec, &dest_row_ptr[tid]);
|
||||||
|
} else if constexpr (std::is_same<OutT, phi::dtype::float8_e4m3fn>::value) {
|
||||||
|
using StoreT = AlignedVector<OutT, VecSize>;
|
||||||
|
StoreT dest_vec;
|
||||||
|
const float max_bound = 448.f;
|
||||||
|
const float min_bound = -448.f;
|
||||||
|
for (int j = 0; j < VecSize; j++) {
|
||||||
|
float quant_value = max_bound * scale * static_cast<float>(src_vec[j]);
|
||||||
|
quant_value = quant_value > max_bound ? max_bound : quant_value;
|
||||||
|
quant_value = quant_value < min_bound ? min_bound : quant_value;
|
||||||
|
dest_vec[j] = static_cast<phi::dtype::float8_e4m3fn>(quant_value);
|
||||||
|
}
|
||||||
|
Store<phi::dtype::float8_e4m3fn, VecSize>(dest_vec, &dest_row_ptr[tid]);
|
||||||
} else {
|
} else {
|
||||||
Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
|
Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
|
||||||
}
|
}
|
||||||
|
@@ -113,11 +113,20 @@ void MoeDispatchKernel(
|
|||||||
permuted_rows_, moe_topk * num_rows, false, stream);
|
permuted_rows_, moe_topk * num_rows, false, stream);
|
||||||
|
|
||||||
if (w4a8_in_scale) {
|
if (w4a8_in_scale) {
|
||||||
|
if (permute_input->dtype() == paddle::DataType::INT8) {
|
||||||
initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
|
initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
|
||||||
input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
|
input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
|
||||||
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
|
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
|
||||||
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||||
hidden_size, moe_topk, stream);
|
hidden_size, moe_topk, stream);
|
||||||
|
} else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
|
||||||
|
initialize_moe_routing_kernelLauncher<data_t, float8_e4m3fn>::run(
|
||||||
|
input.data<data_t>(), permute_input->data<float8_e4m3fn>(),
|
||||||
|
permuted_rows_, expert_idx_per_token->data<int32_t>(),
|
||||||
|
w4a8_in_scale->data<float>(),
|
||||||
|
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||||
|
hidden_size, moe_topk, stream);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
initialize_moe_routing_kernelLauncher<data_t>::run(
|
initialize_moe_routing_kernelLauncher<data_t>::run(
|
||||||
input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
|
input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
|
||||||
@@ -135,7 +144,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
|||||||
const paddle::Tensor &input, const paddle::Tensor &gating_output,
|
const paddle::Tensor &input, const paddle::Tensor &gating_output,
|
||||||
const paddle::optional<paddle::Tensor> &gating_correction_bias,
|
const paddle::optional<paddle::Tensor> &gating_correction_bias,
|
||||||
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
|
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
|
||||||
const bool group_moe, const bool topk_only_mode) {
|
const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode) {
|
||||||
const auto input_type = input.dtype();
|
const auto input_type = input.dtype();
|
||||||
auto place = input.place();
|
auto place = input.place();
|
||||||
int token_rows = 0;
|
int token_rows = 0;
|
||||||
@@ -151,8 +160,14 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
|||||||
const int num_rows = token_rows;
|
const int num_rows = token_rows;
|
||||||
const int hidden_size = input.dims()[input_dims.size() - 1];
|
const int hidden_size = input.dims()[input_dims.size() - 1];
|
||||||
|
|
||||||
auto permute_input_dtype =
|
auto permute_input_dtype = input_type;
|
||||||
w4a8_in_scale ? paddle::DataType::INT8 : input_type;
|
if (w4a8_in_scale) {
|
||||||
|
if (moe_quant_type == "w4a8") {
|
||||||
|
permute_input_dtype = paddle::DataType::INT8;
|
||||||
|
} else if (moe_quant_type == "w4afp8") {
|
||||||
|
permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto permute_input = GetEmptyTensor({moe_topk * num_rows, hidden_size},
|
auto permute_input = GetEmptyTensor({moe_topk * num_rows, hidden_size},
|
||||||
permute_input_dtype, place);
|
permute_input_dtype, place);
|
||||||
@@ -285,7 +300,7 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch)
|
|||||||
.Outputs({"permute_input", "tokens_expert_prefix_sum",
|
.Outputs({"permute_input", "tokens_expert_prefix_sum",
|
||||||
"permute_indices_per_token", "topk_weight", "topk_idx",
|
"permute_indices_per_token", "topk_weight", "topk_idx",
|
||||||
"expert_idx_per_token"})
|
"expert_idx_per_token"})
|
||||||
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
|
.Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"})
|
||||||
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
||||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
|
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
|
||||||
|
@@ -204,7 +204,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
|||||||
->data<float>(),
|
->data<float>(),
|
||||||
reinterpret_cast<NvType *>(fc1_out),
|
reinterpret_cast<NvType *>(fc1_out),
|
||||||
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
|
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
|
||||||
num_max_tokens_per_expert,
|
used_in_ep_low_latency ? num_max_tokens_per_expert : permute_input.dims()[0],
|
||||||
num_experts,
|
num_experts,
|
||||||
inter_size,
|
inter_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@@ -369,7 +369,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
|||||||
->data<float>(),
|
->data<float>(),
|
||||||
reinterpret_cast<NvType*>(ffn_out_data),
|
reinterpret_cast<NvType*>(ffn_out_data),
|
||||||
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
|
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
|
||||||
num_max_tokens_per_expert,
|
used_in_ep_low_latency ? num_max_tokens_per_expert : act_out_tensor.dims()[0],
|
||||||
num_experts,
|
num_experts,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
inter_size / 2,
|
inter_size / 2,
|
||||||
|
@@ -225,20 +225,9 @@ struct CollectiveMainloopFwd {
|
|||||||
const int actual_token,
|
const int actual_token,
|
||||||
const int bidn) const {
|
const int bidn) const {
|
||||||
|
|
||||||
auto g_offset = local_tile(
|
auto g_tensor = domain_offset(make_coord(pre_fix_token, _0{}), mB(_, _, 0));
|
||||||
mB(_, _, 0),
|
|
||||||
cute::make_shape(1, size<1>(mB)),
|
|
||||||
make_coord(pre_fix_token, _0{}));
|
|
||||||
|
|
||||||
auto g_tensor = make_tensor(
|
|
||||||
g_offset.data(),
|
|
||||||
make_layout(
|
|
||||||
cute::make_shape(actual_token, size<2>(mB)),
|
|
||||||
g_offset.stride()
|
|
||||||
));
|
|
||||||
|
|
||||||
Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
|
Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
|
||||||
|
|
||||||
return gB;
|
return gB;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -222,7 +222,7 @@ void run_gemm(const InputType * A, const InputType * B, OutputType * C, const fl
|
|||||||
static_cast<Element const*>(A),
|
static_cast<Element const*>(A),
|
||||||
get_gmem_layout<Batch>(M, K / 2),
|
get_gmem_layout<Batch>(M, K / 2),
|
||||||
static_cast<Element const*>(B),
|
static_cast<Element const*>(B),
|
||||||
get_gmem_layout<Batch>(TokenPackSize == 0 ? max_tokens * Batch : TokenPackSize, K),
|
get_gmem_layout<Batch>(TokenPackSize == 0 ? max_tokens: TokenPackSize, K),
|
||||||
static_cast<ElementOutput*>(C),
|
static_cast<ElementOutput*>(C),
|
||||||
get_gmem_layout<Batch>(M, TokenPackSize == 0 ? max_tokens : TokenPackSize),
|
get_gmem_layout<Batch>(M, TokenPackSize == 0 ? max_tokens : TokenPackSize),
|
||||||
weight_scale,
|
weight_scale,
|
||||||
|
@@ -276,6 +276,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
), # if set, permute_input will be int8_t
|
), # if set, permute_input will be int8_t
|
||||||
layer.top_k,
|
layer.top_k,
|
||||||
False,
|
False,
|
||||||
|
self.moe_quant_type,
|
||||||
topk_only_mode=True,
|
topk_only_mode=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -295,6 +296,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
), # if set, permute_input will be int8_t
|
), # if set, permute_input will be int8_t
|
||||||
layer.top_k,
|
layer.top_k,
|
||||||
False,
|
False,
|
||||||
|
self.moe_quant_type,
|
||||||
topk_only_mode=False,
|
topk_only_mode=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -284,6 +284,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
|||||||
), # if set, permute_input will be int8_t
|
), # if set, permute_input will be int8_t
|
||||||
layer.top_k,
|
layer.top_k,
|
||||||
False,
|
False,
|
||||||
|
self.moe_quant_type,
|
||||||
topk_only_mode=False,
|
topk_only_mode=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user