diff --git a/custom_ops/gpu_ops/noaux_tc.cu b/custom_ops/gpu_ops/noaux_tc.cu index c92822eb9..7b6d432c8 100644 --- a/custom_ops/gpu_ops/noaux_tc.cu +++ b/custom_ops/gpu_ops/noaux_tc.cu @@ -33,10 +33,14 @@ std::vector NoauxTc(paddle::Tensor& scores, auto input_type = scores_with_bias.dtype(); auto place = scores_with_bias.place(); auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place); + auto topk_values = paddle::empty({num_tokens, topk}, input_type, place); + auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT32, place); auto stream = scores_with_bias.stream(); - invokeNoAuxTc(reinterpret_cast(scores.data()), + invokeNoAuxTc(reinterpret_cast(scores.data()), reinterpret_cast(group_scores.data()), + reinterpret_cast(topk_values.data()), + reinterpret_cast(topk_indices.data()), reinterpret_cast(scores_with_bias.data()), num_tokens, num_experts, @@ -46,19 +50,23 @@ std::vector NoauxTc(paddle::Tensor& scores, routed_scaling_factor, stream); - return {scores}; + return {scores, topk_values, topk_indices}; } std::vector NoauxTcInferDtype( const paddle::DataType& scores_dtype, const paddle::DataType& scores_with_bias_dtype) { - return {scores_dtype}; + return {scores_dtype, scores_dtype, paddle::DataType::INT32}; } std::vector> NoauxTcInferShape( const std::vector& scores_shape, - const std::vector& gating_output_shape) { - return {scores_shape}; + const std::vector& , + const int topk) { + auto num_tokens = scores_shape[0]; + auto topk_values_shape = std::vector{num_tokens, topk}; + auto topk_indices_shape = std::vector{num_tokens, topk}; + return {scores_shape, topk_values_shape, topk_indices_shape}; } PD_BUILD_STATIC_OP(noaux_tc) diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h index c91d4f5b3..e8a3f4508 100644 --- a/custom_ops/gpu_ops/noauxtc_kernel.h +++ b/custom_ops/gpu_ops/noauxtc_kernel.h @@ -372,10 +372,12 @@ __global__ void topk_with_k2_kernel(T* output, } } -template +template __global__ void group_idx_and_topk_idx_kernel( T* scores, T const* group_scores, + T* topk_values, + IdxT* topk_indices, T* scores_with_bias, int64_t const num_tokens, int64_t const n_group, @@ -391,6 +393,8 @@ __global__ void group_idx_and_topk_idx_kernel( scores_with_bias += case_id * num_experts; scores += case_id * num_experts; group_scores += case_id * n_group; + topk_values += case_id * topk; + topk_indices += case_id * topk; int32_t align_num_experts_per_group = warp_topk::round_up_to_multiple_of(num_experts_per_group); @@ -436,6 +440,7 @@ __global__ void group_idx_and_topk_idx_kernel( queue((int32_t)topk, cuda::std::numeric_limits::min()); int count_equalto_topkth_group = 0; + bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits::min()); if (case_id < num_tokens) { for (int i_group = 0; i_group < n_group; i_group++) { if ((group_scores[i_group] > topk_group_value) || @@ -490,13 +495,23 @@ __global__ void group_idx_and_topk_idx_kernel( for (int i = lane_id; i < topk; i += WARP_SIZE) { float value = s_topk_value[i] / topk_sum * routed_scaling_factor; scores[s_topk_idx[i]] = value; + if (if_proceed_next_topk) { + topk_indices[i] = s_topk_idx[i]; + topk_values[i] = static_cast(value); + } + else { + topk_indices[i] = i; + topk_values[i] = static_cast(1.0f / topk); + } } } } -template +template void invokeNoAuxTc(T* scores, T* group_scores, + T* topk_values, + IdxT* topk_indices, T* scores_with_bias, int64_t const num_tokens, int64_t const num_experts, @@ -526,6 +541,8 @@ void invokeNoAuxTc(T* scores, dynamic_smem_in_bytes, stream>>>(scores, group_scores, + topk_values, + topk_indices, scores_with_bias, num_tokens, n_group, @@ -536,9 +553,11 @@ void invokeNoAuxTc(T* scores, routed_scaling_factor); } -#define INSTANTIATE_NOAUX_TC(T) \ - template void invokeNoAuxTc(T * scores, \ +#define INSTANTIATE_NOAUX_TC(T, IdxT) \ + template void invokeNoAuxTc(T * scores, \ T * group_scores, \ + T* topk_values, \ + IdxT* topk_indices, \ T * scores_with_bias, \ int64_t const num_tokens, \ int64_t const num_experts, \ @@ -548,4 +567,4 @@ void invokeNoAuxTc(T* scores, double const routed_scaling_factor, \ cudaStream_t const stream); -INSTANTIATE_NOAUX_TC(float); +INSTANTIATE_NOAUX_TC(float, int32_t); diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 752ead74f..cb717f963 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -31,6 +31,35 @@ import fastdeploy from fastdeploy.config import MoEPhase from fastdeploy.utils import singleton +try: + from fastdeploy.model_executor.ops.gpu import noaux_tc +except: + logger.warning("import noaux_tc Failed!") + + +def get_moe_scores( + gating_output: paddle.Tensor, + n_group, + topk_group, + top_k, + routed_scaling_factor, + e_score_correction_bias, +) -> paddle.Tensor: + """ + compute moe scores using e_score_correction_bias. + """ + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) + scores, topk_values, topk_idx = noaux_tc( + scores, + scores_with_bias, + n_group, + topk_group, + top_k, + routed_scaling_factor, + ) + return scores, topk_values, topk_idx + @singleton class DeepEPEngine: @@ -284,13 +313,23 @@ class EPRunner: redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1, ) else: - topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( - gate_out, - layer.gate_correction_bias, - self.top_k, - True, # apply_norm_weight, - False, - ) + if layer.topk_method == "noaux_tc": + score, topk_weights, topk_idx = get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + ) + else: + topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + self.top_k, + True, # apply_norm_weight, + False, + ) return topk_idx, topk_weights @abstractmethod diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 87ac8fbcb..458da642f 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -53,7 +53,7 @@ def get_moe_scores( """ scores = paddle.nn.functional.sigmoid(gating_output) scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) - scores = noaux_tc( + scores, topk_values, topk_idx = noaux_tc( scores, scores_with_bias, n_group, @@ -61,7 +61,7 @@ def get_moe_scores( top_k, routed_scaling_factor, ) - return scores + return scores, topk_values, topk_idx class CutlassMoEMethod(MoEMethodBase): @@ -248,7 +248,7 @@ class CutlassMoEMethod(MoEMethodBase): Paddle Cutlass compute Fused MoE. """ if layer.topk_method == "noaux_tc": - gate_out = get_moe_scores( + gate_out, _, _ = get_moe_scores( gate_out, layer.n_group, layer.topk_group, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py index 848f52b95..7866c03d6 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py @@ -41,7 +41,7 @@ def get_moe_scores( """ scores = paddle.nn.functional.sigmoid(gating_output) scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) - scores = noaux_tc( + scores, topk_values, topk_idx = noaux_tc( scores, scores_with_bias, n_group, @@ -49,7 +49,7 @@ def get_moe_scores( top_k, routed_scaling_factor, ) - return scores + return scores, topk_values, topk_idx def gptq_marlin_moe_repack( @@ -233,7 +233,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase): topk_method = layer.topk_method if topk_method == "noaux_tc": - gate_out = get_moe_scores( + gate_out, _, _ = get_moe_scores( gate_out, layer.n_group, layer.topk_group, diff --git a/test/operators/test_noaux_tc.py b/test/operators/test_noaux_tc.py new file mode 100644 index 000000000..06e065673 --- /dev/null +++ b/test/operators/test_noaux_tc.py @@ -0,0 +1,76 @@ +import unittest + +import paddle + +from fastdeploy.model_executor.ops.gpu import noaux_tc + + +class TestMoeRouting(unittest.TestCase): + def setUp(self): + self.num_tokens = 10 + self.num_experts = 64 + self.gating_output = paddle.rand([self.num_tokens, self.num_experts]) + self.e_score_correction_bias = paddle.rand([self.num_experts]) + self.n_group = 8 + self.topk_group = 4 + self.top_k = 8 + self.routed_scaling_factor = 1.5 + + def node_limit_routing(self, gate_probs): + """将所有专家分组, 只在topk_group个group内选择专家""" + assert len(gate_probs.shape) == 2 + seq_length, n_experts = gate_probs.shape + + group_scores = gate_probs.reshape([seq_length, 8, -1]).topk(2, axis=-1)[0].sum(axis=-1) + group_idx = paddle.topk(group_scores, k=4, axis=-1, sorted=True)[1] + group_mask = paddle.zeros_like(group_scores).put_along_axis( + group_idx, paddle.ones([], dtype="float32"), axis=-1 + ) + score_mask = group_mask.unsqueeze(-1).expand([seq_length, 8, n_experts // 8]).reshape([seq_length, -1]) + gate_probs = gate_probs.masked_fill(~score_mask.astype(paddle.bool), float("-inf")) + return gate_probs + + def ref_moe_routing(self): + scores = paddle.nn.functional.sigmoid(self.gating_output) + prob_for_choice = scores + self.e_score_correction_bias.unsqueeze(0) + prob_for_choice = self.node_limit_routing(prob_for_choice) + top_logits, topk_idx_ref = paddle.topk(prob_for_choice, self.top_k, axis=1) + + token_num, top_k = topk_idx_ref.shape + _, num_expert = prob_for_choice.shape + topk_idx_expanded = paddle.unsqueeze(topk_idx_ref, axis=-1) + indices = paddle.concat( + [ + paddle.arange(token_num, dtype="int64").unsqueeze(1).tile([1, top_k]).unsqueeze(-1), + topk_idx_expanded, + ], + axis=-1, + ) + selected_gate_probs = paddle.gather_nd(scores, indices) + + selected_gate_probs_sum = paddle.sum(selected_gate_probs, axis=1, keepdim=True) + topk_weights_ref = selected_gate_probs / selected_gate_probs_sum + topk_weights_ref = topk_weights_ref * self.routed_scaling_factor + return topk_weights_ref, topk_idx_ref + + def test_moe_select(self): + scores = paddle.nn.functional.sigmoid(self.gating_output) + scores_with_bias = scores + self.e_score_correction_bias.unsqueeze(0) + + scores, topk_values, topk_idx = noaux_tc( + scores, + scores_with_bias, + self.n_group, + self.topk_group, + self.top_k, + self.routed_scaling_factor, + ) + + ref_topk_values, ref_topk_idx = self.ref_moe_routing() + + paddle.allclose(topk_values, ref_topk_values) + paddle.allclose(topk_idx.cast(int), ref_topk_idx.cast(int)) + + +if __name__ == "__main__": + unittest.main()