[NewFeature] support eplb noaux (#4725)

* support eplb noaux

* support eplb noaux

* add eplb noaux test
xiaoxiaohehe001
2025-11-05 20:59:12 +08:00
committed by GitHub
parent 1e88754133
commit ee37882a26
7 changed files with 477 additions and 20 deletions

View File

@@ -570,6 +570,18 @@ std::vector<paddle::Tensor> NoauxTc(
    int topk,
    float routed_scaling_factor);

std::vector<paddle::Tensor> NoauxTcRedundant(
    paddle::Tensor& scores,
    paddle::Tensor& scores_with_bias,
    paddle::Tensor& expert_id_to_ep_rank_array,
    paddle::Tensor& expert_in_rank_num_list,
    paddle::Tensor& tokens_per_expert_stats_list,
    int n_group,
    int topk_group,
    int topk,
    float routed_scaling_factor,
    int redundant_ep_rank_num_plus_one);
#ifdef ENABLE_FP8
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
    const paddle::Tensor& x,
@@ -1251,6 +1263,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
m.def("noaux_tc_redunant",&NoauxTcRedundant, "noaux_tc_redundant for MoE compute");
#ifdef ENABLE_FP8
m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),

View File

@@ -0,0 +1,92 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <algorithm>
#include <optional>

#include "helper.h"
#include "noauxtc_kernel.h"

std::vector<paddle::Tensor> NoauxTcRedundant(
    paddle::Tensor& scores,
    paddle::Tensor& scores_with_bias,
    paddle::Tensor& expert_id_to_ep_rank_array,
    paddle::Tensor& expert_in_rank_num_list,
    paddle::Tensor& tokens_per_expert_stats_list,
    int n_group,
    int topk_group,
    int topk,
    float routed_scaling_factor,
    int redundant_ep_rank_num_plus_one) {
  auto input_shape = scores_with_bias.shape();
  PD_CHECK(input_shape.size() == 2);
  int64_t num_tokens = input_shape[0];
  int64_t num_experts = input_shape[1];
  auto input_type = scores_with_bias.dtype();
  auto place = scores_with_bias.place();
  auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
  auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
  auto topk_indices =
      paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place);
  auto stream = scores_with_bias.stream();

  invokeNoAuxTcRedundant<float, int64_t>(scores.data<float>(),
                                         group_scores.data<float>(),
                                         topk_values.data<float>(),
                                         topk_indices.data<int64_t>(),
                                         scores_with_bias.data<float>(),
                                         expert_id_to_ep_rank_array.data<int>(),
                                         expert_in_rank_num_list.data<int>(),
                                         tokens_per_expert_stats_list.data<int>(),
                                         num_tokens,
                                         num_experts,
                                         n_group,
                                         topk_group,
                                         topk,
                                         routed_scaling_factor,
                                         redundant_ep_rank_num_plus_one,
                                         stream);

  return {scores, topk_values, topk_indices};
}

std::vector<paddle::DataType> NoauxTcRedundantInferDtype(
    const paddle::DataType& scores_dtype,
    const paddle::DataType& scores_with_bias_dtype) {
  return {scores_dtype, scores_dtype, paddle::DataType::INT64};
}

std::vector<std::vector<int64_t>> NoauxTcRedundantInferShape(
    const std::vector<int64_t>& scores_shape,
    const std::vector<int64_t>& scores_with_bias_shape,
    const int topk) {
  auto num_tokens = scores_shape[0];
  auto topk_values_shape = std::vector<int64_t>{num_tokens, topk};
  auto topk_indices_shape = std::vector<int64_t>{num_tokens, topk};
  return {scores_shape, topk_values_shape, topk_indices_shape};
}

PD_BUILD_STATIC_OP(noaux_tc_redundant)
    .Inputs({"scores",
             "scores_with_bias",
             "expert_id_to_ep_rank_array",
             "expert_in_rank_num_list",
             "tokens_per_expert_stats_list"})
    .Outputs({"output_tensor",
              "topk_values",
              "topk_indices",
              "tokens_per_expert_stats_list_out"})
    .Attrs({"n_group: int",
            "topk_group: int",
            "topk: int",
            "routed_scaling_factor: float",
            "redundant_ep_rank_num_plus_one: int"})
    .SetInplaceMap({{"tokens_per_expert_stats_list",
                     "tokens_per_expert_stats_list_out"}})
    .SetKernelFn(PD_KERNEL(NoauxTcRedundant))
    .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcRedundantInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcRedundantInferDtype));
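For reference, a minimal Python sketch of invoking the op registered above (illustrative values only; shapes mirror the unit test added later in this PR). Note that the op declares four outputs, and the fourth aliases tokens_per_expert_stats_list through SetInplaceMap, so callers usually discard it:

import paddle
from fastdeploy.model_executor.ops.gpu import noaux_tc_redundant

num_tokens, num_experts = 10, 64
scores = paddle.nn.functional.sigmoid(paddle.rand([num_tokens, num_experts]))
scores_with_bias = scores + paddle.rand([num_experts])
# Identity table: one EP rank per expert (redundant_ep_rank_num_plus_one == 1).
expert_id_to_ep_rank_array = paddle.arange(num_experts, dtype="int32").reshape([num_experts, 1])
expert_in_rank_num_list = paddle.ones([num_experts, 1], dtype="int32")
tokens_per_expert_stats_list = paddle.zeros([num_experts], dtype="int32")

scores_out, topk_values, topk_indices, _ = noaux_tc_redundant(
    scores, scores_with_bias,
    expert_id_to_ep_rank_array, expert_in_rank_num_list, tokens_per_expert_stats_list,
    8, 4, 8,  # n_group, topk_group, topk
    1.5,      # routed_scaling_factor
    1,        # redundant_ep_rank_num_plus_one
)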

View File

@@ -306,6 +306,14 @@ private:
}; // end class WarpSelect
} // namespace warp_topk
inline __device__ unsigned int xorwow_moe(unsigned int& state) {
  // Cheap xorshift-style PRNG; seeded per token with its case_id and used to
  // spread a selected expert's tokens across its redundant EP ranks.
  state ^= state >> 7;
  state ^= state << 9;
  state ^= state >> 13;
  return state;
}
template <typename T>
__device__ void topk_with_k2(T* output,
                             T const* input,
@@ -507,6 +515,156 @@ __global__ void group_idx_and_topk_idx_kernel(
}
}
template <typename T, typename IdxT>
__global__ void group_idx_and_topk_idx_redundant_kernel(
    T* scores,
    T const* group_scores,
    T* topk_values,
    IdxT* topk_indices,
    T* scores_with_bias,
    int32_t* expert_id_to_ep_rank_array,
    int32_t* expert_in_rank_num_list,
    int32_t* tokens_per_expert_stats_list,
    int64_t const num_tokens,
    int64_t const n_group,
    int64_t const topk_group,
    int64_t const topk,
    int64_t const num_experts,
    int64_t const num_experts_per_group,
    double routed_scaling_factor,
    int64_t const redundant_ep_rank_num_plus_one) {
  int32_t warp_id = threadIdx.x / WARP_SIZE;
  int32_t lane_id = threadIdx.x % WARP_SIZE;
  int32_t case_id =
      blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;  // one warp per token
  unsigned int state = case_id;  // per-token PRNG seed for replica selection
  scores_with_bias += case_id * num_experts;
  scores += case_id * num_experts;
  group_scores += case_id * n_group;
  topk_values += case_id * topk;
  topk_indices += case_id * topk;
  int32_t align_num_experts_per_group =
      warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);

  cg::thread_block block = cg::this_thread_block();
  cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

  extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to
                                      // store the target topk idx
  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
  T* s_topk_value =
      reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
      warp_id * topk;

  T value = cuda::std::numeric_limits<T>::min();
  T topk_group_value = cuda::std::numeric_limits<T>::min();
  int32_t num_equalto_topkth_group = 0;

  if ((n_group > topk_group) && (case_id < num_tokens)) {
    // calculate group_idx
    int32_t target_num_min = WARP_SIZE - n_group + topk_group;
    if (lane_id < n_group) {
      value = group_scores[lane_id];
    }
    int count_equal_to_top_value = WARP_SIZE - n_group;
    int pre_count_equal_to_top_value = 0;
    // Use a loop to find the largest topk_group-th group score
    while (count_equal_to_top_value < target_num_min) {
      __syncwarp();  // Ensure all threads have valid data before reduction
      topk_group_value = cg::reduce(tile, value, cg::greater<T>());
      if (value == topk_group_value) {
        value = cuda::std::numeric_limits<T>::min();
      }
      pre_count_equal_to_top_value = count_equal_to_top_value;
      count_equal_to_top_value = __popc(__ballot_sync(
          FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
    }
    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
  }
  __syncthreads();

  warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t>
      queue((int32_t)topk, cuda::std::numeric_limits<T>::min());

  int count_equalto_topkth_group = 0;
  bool if_proceed_next_topk =
      (topk_group_value != cuda::std::numeric_limits<T>::min());
  if (case_id < num_tokens) {
    for (int i_group = 0; i_group < n_group; i_group++) {
      if ((group_scores[i_group] > topk_group_value) ||
          ((group_scores[i_group] == topk_group_value) &&
           (count_equalto_topkth_group < num_equalto_topkth_group))) {
        int32_t offset = i_group * num_experts_per_group;
        for (int32_t i = lane_id; i < align_num_experts_per_group;
             i += WARP_SIZE) {
          T candidates = i < num_experts_per_group
                             ? scores_with_bias[offset + i]
                             : cuda::std::numeric_limits<T>::min();
          queue.add(candidates, offset + i);
        }
        if (group_scores[i_group] == topk_group_value) {
          count_equalto_topkth_group++;
        }
      }
    }
    queue.done();
    __syncwarp();
    // Get the topk_idx
    queue.dumpIdx(s_topk_idx);
    __syncwarp();
  }

  // Load the valid score values and compute their sum for normalization
  float topk_sum = 1e-20;
  if (case_id < num_tokens) {
    for (int i = lane_id;
         i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
         i += WARP_SIZE) {
      T value = i < topk ? scores[s_topk_idx[i]]
                         : 0.0f;  // Load the valid value of expert
      if (i < topk) {
        s_topk_value[i] = value;
      }
      topk_sum += cg::reduce(tile, value, cg::plus<float>());
    }
  }
  __syncthreads();

  if (case_id < num_tokens) {
    for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
      scores[i] = 0;
    }
  }
  __threadfence();
  __syncthreads();

  if (case_id < num_tokens) {
    for (int i = lane_id; i < topk; i += WARP_SIZE) {
      float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
      scores[s_topk_idx[i]] = value;
      if (if_proceed_next_topk) {
        // Map the selected expert to one of its redundant EP ranks at random
        // and bump the per-expert token counter.
        int expert_topk = s_topk_idx[i];
        int len = expert_in_rank_num_list[expert_topk];
        int select = xorwow_moe(state) % len;
        int selected_rank = expert_id_to_ep_rank_array
            [expert_topk * redundant_ep_rank_num_plus_one + select];
        atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1);
        topk_indices[i] = (IdxT)selected_rank;
        topk_values[i] = static_cast<T>(value);
      } else {
        // Degenerate case (no valid group selection): fall back to uniform
        // weights over the first topk experts.
        int expert_topk = i;
        int len = expert_in_rank_num_list[expert_topk];
        int select = xorwow_moe(state) % len;
        int selected_rank = expert_id_to_ep_rank_array
            [expert_topk * redundant_ep_rank_num_plus_one + select];
        atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1);
        topk_indices[i] = (IdxT)selected_rank;
        topk_values[i] = static_cast<T>(1.0f / topk);
      }
    }
  }
}
template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores,
                   T* group_scores,
@@ -553,6 +711,60 @@ void invokeNoAuxTc(T* scores,
routed_scaling_factor);
}
template <typename T, typename IdxT>
void invokeNoAuxTcRedundant(T* scores,
                            T* group_scores,
                            T* topk_values,
                            IdxT* topk_indices,
                            T* scores_with_bias,
                            int32_t* expert_id_to_ep_rank_array,
                            int32_t* expert_in_rank_num_list,
                            int32_t* tokens_per_expert_stats_list,
                            int64_t const num_tokens,
                            int64_t const num_experts,
                            int64_t const n_group,
                            int64_t const topk_group,
                            int64_t const topk,
                            double const routed_scaling_factor,
                            int64_t const redundant_ep_rank_num_plus_one,
                            cudaStream_t const stream) {
  // Stage 1: per-(token, group) top-2 reduction producing group_scores.
  int64_t num_cases = num_tokens * n_group;
  int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
  topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
      group_scores,
      scores_with_bias,
      num_tokens,
      num_cases,
      n_group,
      num_experts / n_group);

  // Stage 2: group selection, per-token top-k, and redundant-rank mapping.
  int64_t topk_with_k_group_num_blocks =
      (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
  size_t dynamic_smem_in_bytes =
      warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
                                                           topk);
  group_idx_and_topk_idx_redundant_kernel<T>
      <<<topk_with_k_group_num_blocks,
         BLOCK_SIZE,
         dynamic_smem_in_bytes,
         stream>>>(scores,
                   group_scores,
                   topk_values,
                   topk_indices,
                   scores_with_bias,
                   expert_id_to_ep_rank_array,
                   expert_in_rank_num_list,
                   tokens_per_expert_stats_list,
                   num_tokens,
                   n_group,
                   topk_group,
                   topk,
                   num_experts,
                   num_experts / n_group,
                   routed_scaling_factor,
                   redundant_ep_rank_num_plus_one);
}
#define INSTANTIATE_NOAUX_TC(T, IdxT)              \
  template void invokeNoAuxTc<T, IdxT>(T* scores,  \
                                       T* group_scores, \
@@ -568,3 +780,23 @@ void invokeNoAuxTc(T* scores,
cudaStream_t const stream);
INSTANTIATE_NOAUX_TC(float, int32_t);
#define INSTANTIATE_NOAUX_TC_Redundant(T, IdxT)                   \
  template void invokeNoAuxTcRedundant<T, IdxT>(                  \
      T* scores,                                                  \
      T* group_scores,                                            \
      T* topk_values,                                             \
      IdxT* topk_indices,                                         \
      T* scores_with_bias,                                        \
      int32_t* expert_id_to_ep_rank_array,                        \
      int32_t* expert_in_rank_num_list,                           \
      int32_t* tokens_per_expert_stats_list,                      \
      int64_t const num_tokens,                                   \
      int64_t const num_experts,                                  \
      int64_t const n_group,                                      \
      int64_t const topk_group,                                   \
      int64_t const topk,                                         \
      double const routed_scaling_factor,                         \
      int64_t const redundant_ep_rank_num_plus_one,               \
      cudaStream_t const stream);
INSTANTIATE_NOAUX_TC_Redundant(float, int32_t);
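To make the kernel's contract concrete, here is a NumPy sketch of the per-token math (an illustration only, not part of the PR; it ignores tie-breaking details and assumes expert_id_to_ep_rank_array is a flattened row-major [num_experts, redundant_ep_rank_num_plus_one] table):

import numpy as np

def xorwow_moe(state: int) -> int:
    # Same xorshift steps as the device helper, with 32-bit wraparound.
    state ^= state >> 7
    state = (state ^ (state << 9)) & 0xFFFFFFFF
    state ^= state >> 13
    return state

def route_one_token(scores, scores_with_bias, n_group, topk_group, topk,
                    routed_scaling_factor, expert_id_to_ep_rank_array,
                    expert_in_rank_num_list, redundant_plus_one, case_id):
    num_experts = scores.shape[0]
    experts_per_group = num_experts // n_group
    grouped = scores_with_bias.reshape(n_group, experts_per_group)
    # Group score = sum of the two largest biased scores in the group
    # (what topk_with_k2_kernel computes).
    group_scores = np.sort(grouped, axis=1)[:, -2:].sum(axis=1)
    kept_groups = np.argsort(-group_scores)[:topk_group]
    masked = np.full(num_experts, -np.inf)
    for g in kept_groups:
        lo = g * experts_per_group
        masked[lo:lo + experts_per_group] = scores_with_bias[lo:lo + experts_per_group]
    topk_idx = np.argsort(-masked)[:topk]
    # Weights use the unbiased sigmoid scores, normalized and scaled.
    w = scores[topk_idx]
    w = w / (w.sum() + 1e-20) * routed_scaling_factor
    # Every lane is seeded with case_id, so for topk <= 32 all of a token's
    # selections share one xorshift draw; only `% len` differs per expert.
    draw = xorwow_moe(case_id & 0xFFFFFFFF)
    ranks = np.array([
        expert_id_to_ep_rank_array[e * redundant_plus_one + draw % expert_in_rank_num_list[e]]
        for e in topk_idx
    ])
    return ranks, w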

View File

@@ -298,6 +298,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu",
"gpu_ops/noaux_tc_redundant.cu",
"gpu_ops/custom_all_reduce/all_reduce.cu",
"gpu_ops/merge_prefill_decode_output.cu",
]

View File

@@ -437,17 +437,33 @@ class EPRunner:
                tokens_per_expert_stats_list,
            ) = layer.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(layer.layer_idx)
            if layer.topk_method == "noaux_tc":
                from .moe import get_moe_scores

                score, topk_weights, topk_idx = get_moe_scores(
                    gate_out,
                    layer.n_group,
                    layer.topk_group,
                    layer.top_k,
                    layer.routed_scaling_factor,
                    layer.gate_correction_bias,
                    expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
                    expert_in_rank_num_list=expert_in_rank_num_list,
                    tokens_per_expert_stats_list=tokens_per_expert_stats_list,
                    redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
                )
            else:
                topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
                    gating_logits=gate_out,
                    expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
                    expert_in_rank_num_list=expert_in_rank_num_list,
                    tokens_per_expert_stats_list=tokens_per_expert_stats_list,
                    bias=layer.gate_correction_bias,
                    moe_topk=self.top_k,
                    apply_norm_weight=True,
                    enable_softmax_top_k_fused=False,
                    redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
                )
        else:
            if layer.topk_method == "noaux_tc":
                from .moe import get_moe_scores

View File

@@ -28,7 +28,7 @@ from fastdeploy.platforms import current_platform
from fastdeploy.worker.experts_manager import RedundantExpertManger
try:
    from fastdeploy.model_executor.ops.gpu import noaux_tc, noaux_tc_redundant
except:
    logger.warning("import noaux_tc Failed!")
@@ -66,6 +66,10 @@ def get_moe_scores(
    top_k,
    routed_scaling_factor,
    e_score_correction_bias,
    expert_id_to_ep_rank_array=None,
    expert_in_rank_num_list=None,
    tokens_per_expert_stats_list=None,
    redundant_ep_rank_num_plus_one=1,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
@@ -73,14 +77,28 @@ def get_moe_scores(
    scores = paddle.nn.functional.sigmoid(gating_output)
    assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
    scores_with_bias = scores + e_score_correction_bias
    if expert_id_to_ep_rank_array is None:
        scores, topk_values, topk_idx = noaux_tc(
            scores,
            scores_with_bias,
            n_group if n_group > 0 else 1,
            topk_group if topk_group > 0 else 1,
            top_k,
            routed_scaling_factor,
        )
    else:
        scores, topk_values, topk_idx, _ = noaux_tc_redundant(
            scores,
            scores_with_bias,
            expert_id_to_ep_rank_array,
            expert_in_rank_num_list,
            tokens_per_expert_stats_list,
            n_group if n_group > 0 else 1,
            topk_group if topk_group > 0 else 1,
            top_k,
            routed_scaling_factor,
            redundant_ep_rank_num_plus_one,
        )
    return scores, topk_values, topk_idx
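Callers opt into the redundant path simply by supplying the EP-rank tables; a sketch of both call forms (assuming the tables come from RedundantExpertManger and redundant_experts_num from the model config, as in the EPRunner change above):

# Dense path (behaviour unchanged):
scores, topk_w, topk_ids = get_moe_scores(
    gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias
)

# Redundant-experts path: same call plus the EP-rank tables.
scores, topk_w, topk_ranks = get_moe_scores(
    gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias,
    expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
    expert_in_rank_num_list=expert_in_rank_num_list,
    tokens_per_expert_stats_list=tokens_per_expert_stats_list,
    redundant_ep_rank_num_plus_one=redundant_experts_num + 1,
)

In the redundant path the returned indices are EP ranks drawn from expert_id_to_ep_rank_array rather than raw expert ids, matching what moe_redundant_topk_select produces.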

View File

@@ -0,0 +1,84 @@
import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import noaux_tc_redundant


class TestMoeRouting(unittest.TestCase):
    def setUp(self):
        self.num_tokens = 10
        self.num_experts = 64
        self.gating_output = paddle.rand([self.num_tokens, self.num_experts])
        self.e_score_correction_bias = paddle.rand([self.num_experts])
        self.n_group = 8
        self.topk_group = 4
        self.top_k = 8
        self.routed_scaling_factor = 1.5
        self.redundant_ep_rank_num_plus_one = 1

    def node_limit_routing(self, gate_probs):
        """Group all experts and select experts only within the topk_group best groups."""
        assert len(gate_probs.shape) == 2
        seq_length, n_experts = gate_probs.shape

        group_scores = (
            gate_probs.reshape([seq_length, self.n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
        )
        group_idx = paddle.topk(group_scores, k=self.topk_group, axis=-1, sorted=True)[1]
        group_mask = paddle.zeros_like(group_scores).put_along_axis(
            group_idx, paddle.ones([], dtype="float32"), axis=-1
        )
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand([seq_length, self.n_group, n_experts // self.n_group])
            .reshape([seq_length, -1])
        )
        gate_probs = gate_probs.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))
        return gate_probs

    def ref_moe_routing(self):
        scores = paddle.nn.functional.sigmoid(self.gating_output)
        prob_for_choice = scores + self.e_score_correction_bias.unsqueeze(0)
        prob_for_choice = self.node_limit_routing(prob_for_choice)
        top_logits, topk_idx_ref = paddle.topk(prob_for_choice, self.top_k, axis=1)

        token_num, top_k = topk_idx_ref.shape
        topk_idx_expanded = paddle.unsqueeze(topk_idx_ref, axis=-1)
        indices = paddle.concat(
            [
                paddle.arange(token_num, dtype="int64").unsqueeze(1).tile([1, top_k]).unsqueeze(-1),
                topk_idx_expanded,
            ],
            axis=-1,
        )
        selected_gate_probs = paddle.gather_nd(scores, indices)
        selected_gate_probs_sum = paddle.sum(selected_gate_probs, axis=1, keepdim=True)
        topk_weights_ref = selected_gate_probs / selected_gate_probs_sum
        topk_weights_ref = topk_weights_ref * self.routed_scaling_factor
        return topk_weights_ref, topk_idx_ref

    def test_moe_select(self):
        scores = paddle.nn.functional.sigmoid(self.gating_output)
        scores_with_bias = scores + self.e_score_correction_bias.unsqueeze(0)

        # With no redundancy (redundant_ep_rank_num_plus_one == 1) each expert
        # lives on exactly one EP rank, so an identity mapping makes the
        # returned topk_indices directly comparable with reference expert ids.
        expert_id_to_ep_rank_array = paddle.arange(self.num_experts, dtype="int32").reshape(
            [self.num_experts, 1]
        )
        expert_in_rank_num_list = paddle.ones([self.num_experts, 1], dtype="int32")
        tokens_per_expert_stats_list = paddle.zeros([self.num_experts], dtype="int32")

        scores, topk_values, topk_idx, _ = noaux_tc_redundant(
            scores,
            scores_with_bias,
            expert_id_to_ep_rank_array,
            expert_in_rank_num_list,
            tokens_per_expert_stats_list,
            self.n_group,
            self.topk_group,
            self.top_k,
            self.routed_scaling_factor,
            self.redundant_ep_rank_num_plus_one,
        )

        ref_topk_values, ref_topk_idx = self.ref_moe_routing()
        np.testing.assert_allclose(topk_values.numpy(), ref_topk_values.numpy(), rtol=1e-5, atol=1e-6)
        np.testing.assert_array_equal(topk_idx.numpy(), ref_topk_idx.numpy())


if __name__ == "__main__":
    unittest.main()