diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 1d977f50a..85f88cf12 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -571,6 +571,7 @@ std::vector<paddle::Tensor> NoauxTc(
     int n_group,
     int topk_group,
     int topk,
+    bool renormalize,
     float routed_scaling_factor);

 #ifdef ENABLE_FP8
diff --git a/custom_ops/gpu_ops/noaux_tc.cu b/custom_ops/gpu_ops/noaux_tc.cu
index 19a9e380f..7b239b8cf 100644
--- a/custom_ops/gpu_ops/noaux_tc.cu
+++ b/custom_ops/gpu_ops/noaux_tc.cu
@@ -26,6 +26,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
                                     int n_group,
                                     int topk_group,
                                     int topk,
+                                    bool renormalize,
                                     float routed_scaling_factor) {
   auto input_shape = scores_with_bias.shape();
   PD_CHECK(input_shape.size() == 2);
@@ -48,6 +49,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
                     n_group,
                     topk_group,
                     topk,
+                    renormalize,
                     routed_scaling_factor,
                     stream);

@@ -76,6 +78,7 @@ PD_BUILD_STATIC_OP(noaux_tc)
     .Attrs({"n_group: int",
             "topk_group: int",
             "topk:int",
+            "renormalize: bool",
             "routed_scaling_factor: float"})
     .SetKernelFn(PD_KERNEL(NoauxTc))
     .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h
index e8a3f4508..392dbfe3b 100644
--- a/custom_ops/gpu_ops/noauxtc_kernel.h
+++ b/custom_ops/gpu_ops/noauxtc_kernel.h
@@ -25,6 +25,23 @@ constexpr unsigned FULL_WARP_MASK = 0xffffffff;
 constexpr int32_t BLOCK_SIZE = 512;
 constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

+template <typename T_OUT, typename T_IN>
+__device__ inline T_OUT cuda_cast(T_IN val) {
+  return val;
+}
+
+template <>
+__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
+  return __bfloat162float(val);
+}
+
+template <typename T>
+__device__ inline T neg_inf() {
+  // cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
+  // so we need to cast from fp32
+  return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
+}
+
 namespace warp_topk {

 template <int size_, typename T>
@@ -41,10 +58,21 @@ constexpr __host__ __device__ bool isPowerOf2(T v) {
 }

 template <typename T, bool greater>
-__device__ bool is_better_than(T val, T baseline) {
+__forceinline__ __device__ bool is_better_than(T val, T baseline) {
   return (val > baseline && greater) || (val < baseline && !greater);
 }

+template <typename T, bool greater, typename idxT>
+__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
+                                               idxT baseline_index) {
+  bool res = (val > baseline && greater) || (val < baseline && !greater);
+  if (val == baseline) {
+    res = (index < baseline_index && greater) ||
+          (index < baseline_index && !greater);
+  }
+  return res;
+}
+
 template <typename T, typename idxT>
 int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
   int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
@@ -53,7 +81,8 @@ int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
              round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
 }

-template <int size, bool ascending, typename T, typename idxT>
+template <int size, bool ascending, bool reverse, typename T, typename idxT,
+          bool is_stable>
 struct BitonicMerge {
   // input should be a bitonic sequence, and sort it to be a monotonic sequence
   __device__ static void merge(T* __restrict__ val_arr,
@@ -67,7 +96,15 @@ struct BitonicMerge {
       int const other_i = i + stride;
       T& val = val_arr[i];
       T& other_val = val_arr[other_i];
-      if ((val > other_val && ascending) || (val < other_val && !ascending)) {
+      bool is_better;
+      if constexpr (is_stable) {
+        is_better = is_better_than<T, ascending, idxT>(val, other_val, idx_arr[i],
+                                                       idx_arr[other_i]);
+      } else {
+        is_better = is_better_than<T, ascending>(val, other_val);
+      }
+
+      if (is_better) {
         T tmp = val;
         val = other_val;
         other_val = tmp;
@@ -78,13 +115,14 @@ struct BitonicMerge {
       }
     }

-    BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr, idx_arr);
-    BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr + arr_len / 2,
-                                                      idx_arr + arr_len / 2);
+    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
+        val_arr, idx_arr);
+    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
+        val_arr + arr_len / 2, idx_arr + arr_len / 2);
   }
 };

-template <int size, bool ascending, typename T, typename idxT>
+template <int size, bool ascending, typename T, typename idxT, bool is_stable>
 struct BitonicSort {
   __device__ static void sort(T* __restrict__ val_arr,
                               idxT* __restrict__ idx_arr) {
@@ -92,15 +130,16 @@ struct BitonicSort {
     static_assert(size >= 2 * WARP_SIZE);
     constexpr int arr_len = size / WARP_SIZE;

-    BitonicSort<size / 2, ascending, T, idxT>::sort(val_arr, idx_arr);
-    BitonicSort<size / 2, ascending, T, idxT>::sort(val_arr + arr_len / 2,
-                                                    idx_arr + arr_len / 2);
-    BitonicMerge<size, ascending, T, idxT>::merge(val_arr, idx_arr);
+    BitonicSort<size / 2, ascending, T, idxT, is_stable>::sort(val_arr, idx_arr);
+    BitonicSort<size / 2, ascending, T, idxT, is_stable>::sort(
+        val_arr + arr_len / 2, idx_arr + arr_len / 2);
+    BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
+        val_arr, idx_arr);
   }
 };

-template <bool ascending, typename T, typename idxT>
-struct BitonicSort<32, ascending, T, idxT> {
+template <bool ascending, typename T, typename idxT, bool is_stable>
+struct BitonicSort<32, ascending, T, idxT, is_stable> {
   __device__ static void sort(T* __restrict__ val_arr,
                               idxT* __restrict__ idx_arr) {
     int const lane = threadIdx.x % WARP_SIZE;
@@ -114,19 +153,37 @@ struct BitonicSort<32, ascending, T, idxT> {

         T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
         idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
-        if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) {
+
+        bool is_better;
+        if constexpr (is_stable) {
+          if constexpr (ascending) {
+            is_better = ((*val_arr > other) ||
+                         ((*val_arr == other) && (*idx_arr < other_idx))) !=
+                        (reverse != is_second);
+          } else {
+            is_better = ((*val_arr > other) ||
+                         ((*val_arr == other) && (*idx_arr > other_idx))) !=
+                        (reverse != is_second);
+          }
+        } else {
+          is_better = (*val_arr != other &&
+                       (*val_arr > other) != (reverse != is_second));
+        }
+        if (is_better) {
           *val_arr = other;
           *idx_arr = other_idx;
         }
       }
     }

-    BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr);
+    BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr,
+                                                                      idx_arr);
   }
 };

-template <bool ascending, typename T, typename idxT>
-struct BitonicMerge<32, ascending, T, idxT> {
+template <bool ascending, bool reverse, typename T, typename idxT,
+          bool is_stable>
+struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
   __device__ static void merge(T* __restrict__ val_arr,
                                idxT* __restrict__ idx_arr) {
     int const lane = threadIdx.x % WARP_SIZE;
@@ -136,7 +193,24 @@ struct BitonicMerge<32, ascending, T, idxT> {
       T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
       idxT& idx = *idx_arr;
       idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
-      if (val != other && ((val > other) == (ascending != is_second))) {
+
+      bool is_better;
+      if constexpr (is_stable) {
+        if constexpr (ascending) {
+          is_better = ((*val_arr > other) ||
+                       ((*val_arr == other) && (*idx_arr < other_idx))) ==
+                      (reverse != is_second);  // for min
+        } else {
+          is_better = ((*val_arr > other) ||
+                       ((*val_arr == other) && (*idx_arr > other_idx))) ==
+                      (reverse != is_second);  // for max
+        }
+      } else {
+        is_better =
+            (val != other && ((val > other) == (ascending != is_second)));
+      }
+
+      if (is_better) {
         val = other;
         idx = other_idx;
       }
@@ -144,34 +218,42 @@ struct BitonicMerge<32, ascending, T, idxT> {
   }
 };

-template <int capacity, bool greater, typename T, typename idxT>
+template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
 class WarpSort {
-public:
+ public:
   __device__ WarpSort(idxT k, T dummy)
       : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
     static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));

    for (int i = 0; i < max_arr_len_; ++i) {
       val_arr_[i] = dummy_;
+      idx_arr_[i] = 0;
     }
   }

   // load and merge k sorted values
   __device__ void load_sorted(T const* __restrict__ in,
-                              idxT const* __restrict__ in_idx,
-                              idxT start) {
+                              idxT const* __restrict__ in_idx, idxT start) {
     idxT idx = start + WARP_SIZE - 1 - lane_;
     for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
       if (idx < start + k_) {
         T t = in[idx];
-        if (is_better_than<T, greater>(t, val_arr_[i])) {
+        bool is_better;
+        if constexpr (is_stable) {
+          is_better =
+              is_better_than<T, greater, idxT>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
+        } else {
+          is_better = is_better_than<T, greater>(t, val_arr_[i]);
+        }
+        if (is_better) {
           val_arr_[i] = t;
           idx_arr_[i] = in_idx[idx];
         }
       }
     }

-    BitonicMerge<capacity, greater, T, idxT>::merge(val_arr_, idx_arr_);
+    BitonicMerge<capacity, greater, greater, T, idxT, is_stable>::merge(
+        val_arr_, idx_arr_);
   }

   __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
@@ -193,7 +275,7 @@ public:
     }
   }

-protected:
+ protected:
   static constexpr int max_arr_len_ = capacity / WARP_SIZE;

   T val_arr_[max_arr_len_];
@@ -205,11 +287,11 @@ protected:

 };  // end class WarpSort

-template <int capacity, bool greater, typename T, typename idxT>
-class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
-public:
+template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
+class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
+ public:
   __device__ WarpSelect(idxT k, T dummy)
-      : WarpSort<capacity, greater, T, idxT>(k, dummy),
+      : WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
         k_th_(dummy),
         k_th_lane_((k - 1) % WARP_SIZE) {
     extern __shared__ char smem_buf[];  // extern __shared__ T smem_buf[];
@@ -234,7 +316,13 @@ public:
   }

   __device__ void add(T val, idxT idx) {
-    bool do_add = is_better_than<T, greater>(val, k_th_);
+    bool do_add;
+    if constexpr (is_stable) {
+      do_add = is_better_than<T, greater, idxT>(val, k_th_, idx, k_th_idx_);
+    } else {
+      do_add = is_better_than<T, greater>(val, k_th_);
+    }
+
     uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
     if (mask == 0) {
       return;
@@ -271,37 +359,52 @@ public:
     __syncthreads();
   }

-private:
+ private:
   __device__ void set_k_th_() {
     k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
+    if constexpr (is_stable) {
+      k_th_idx_ =
+          __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
+    }
   }

   __device__ void merge_buf_(T val, idxT idx) {
-    BitonicSort<WARP_SIZE, greater, T, idxT>::sort(&val, &idx);
+    BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);

     T& old = val_arr_[max_arr_len_ - 1];
-    if (is_better_than<T, greater>(val, old)) {
+
+    bool is_better;
+    if constexpr (is_stable) {
+      is_better =
+          is_better_than<T, greater, idxT>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
+    } else {
+      is_better = is_better_than<T, greater>(val, old);
+    }
+
+    if (is_better) {
       old = val;
       idx_arr_[max_arr_len_ - 1] = idx;
     }

-    BitonicMerge<capacity, greater, T, idxT>::merge(val_arr_, idx_arr_);
+    BitonicMerge<capacity, greater, greater, T, idxT, is_stable>::merge(
+        val_arr_, idx_arr_);

     set_k_th_();
   }

-  using WarpSort<capacity, greater, T, idxT>::max_arr_len_;
-  using WarpSort<capacity, greater, T, idxT>::val_arr_;
-  using WarpSort<capacity, greater, T, idxT>::idx_arr_;
-  using WarpSort<capacity, greater, T, idxT>::lane_;
-  using WarpSort<capacity, greater, T, idxT>::k_;
-  using WarpSort<capacity, greater, T, idxT>::dummy_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;

   T* val_smem_;
   idxT* idx_smem_;
   int smem_buf_len_ = 0;

   T k_th_;
+  idxT k_th_idx_;
   int const k_th_lane_;
 };  // end class WarpSelect
 }  // namespace warp_topk
@@ -313,8 +416,8 @@ __device__ void topk_with_k2(T* output,
                              int32_t const lane_id,
                              int const num_experts_per_group) {
   // Get the top2 per thread
-  T largest = cuda::std::numeric_limits<T>::min();
-  T second_largest = cuda::std::numeric_limits<T>::min();
+  T largest = neg_inf<T>();
+  T second_largest = neg_inf<T>();

   if (num_experts_per_group > WARP_SIZE) {
     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
@@ -368,8 +471,14 @@ __global__ void topk_with_k2_kernel(T* output,
     cg::thread_block block = cg::this_thread_block();
     cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.wait;");
+#endif
     topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
   }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }

 template <typename T, typename IdxT>
@@ -385,6 +494,7 @@ __global__ void group_idx_and_topk_idx_kernel(
     int64_t const topk,
     int64_t const num_experts,
     int64_t const num_experts_per_group,
+    bool const renormalize,
     double routed_scaling_factor) {
   int32_t warp_id = threadIdx.x / WARP_SIZE;
   int32_t lane_id = threadIdx.x % WARP_SIZE;
@@ -403,19 +513,29 @@ __global__ void group_idx_and_topk_idx_kernel(

   extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to
                                       // store the target topk idx
-  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
+  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
   T* s_topk_value =
       reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
       warp_id * topk;
+  s_topk_idx += warp_id * topk;

-  T value = cuda::std::numeric_limits<T>::min();
-  T topk_group_value = cuda::std::numeric_limits<T>::min();
+  T value = neg_inf<T>();
+  T topk_group_value = neg_inf<T>();
   int32_t num_equalto_topkth_group;

-  if ((n_group > topk_group) && (case_id < num_tokens)) {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");  // I think all prolog can be put before
+                                         // acqbulk because it's ptr arithmetic
+#endif
+
+  if (case_id < num_tokens) {
     // calculate group_idx
     int32_t target_num_min = WARP_SIZE - n_group + topk_group;
-    if (lane_id < n_group) {
+    if (lane_id < n_group &&
+        (isfinite(cuda_cast<float, T>(
+            group_scores[lane_id]))))  // The check is necessary to avoid
+                                       // abnormal input
+    {
       value = group_scores[lane_id];
     }

@@ -426,22 +546,23 @@ __global__ void group_idx_and_topk_idx_kernel(
       __syncwarp();  // Ensure all threads have valid data before reduction
       topk_group_value = cg::reduce(tile, value, cg::greater<T>());
       if (value == topk_group_value) {
-        value = cuda::std::numeric_limits<T>::min();
+        value = neg_inf<T>();
       }
       pre_count_equal_to_top_value = count_equal_to_top_value;
       count_equal_to_top_value = __popc(__ballot_sync(
-          FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
+          FULL_WARP_MASK, (value == neg_inf<T>())));
     }
     num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
   }
   __syncthreads();

-  warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t>
-      queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
+  warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T,
+                        int32_t, /* is_stable */ true>
+      queue((int32_t)topk, neg_inf<T>());

   int count_equalto_topkth_group = 0;
-  bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits<T>::min());
-  if (case_id < num_tokens) {
+  bool if_proceed_next_topk = (topk_group_value != neg_inf<T>());
+  if (case_id < num_tokens && if_proceed_next_topk) {
     for (int i_group = 0; i_group < n_group; i_group++) {
       if ((group_scores[i_group] > topk_group_value) ||
           ((group_scores[i_group] == topk_group_value) &&
@@ -449,9 +570,11 @@ __global__ void group_idx_and_topk_idx_kernel(
         int32_t offset = i_group * num_experts_per_group;
         for (int32_t i = lane_id; i < align_num_experts_per_group;
              i += WARP_SIZE) {
-          T candidates = i < num_experts_per_group
-                             ? scores_with_bias[offset + i]
-                             : cuda::std::numeric_limits<T>::min();
+          T candidates =
+              (i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
+                                                 scores_with_bias[offset + i]))
+                  ? scores_with_bias[offset + i]
+                  : neg_inf<T>();
           queue.add(candidates, offset + i);
         }
         if (group_scores[i_group] == topk_group_value) {
@@ -469,7 +592,7 @@ __global__ void group_idx_and_topk_idx_kernel(
   // Load the valid score value
   // Calculate the summation
   float topk_sum = 1e-20;
-  if (case_id < num_tokens) {
+  if (case_id < num_tokens && if_proceed_next_topk) {
     for (int i = lane_id;
          i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
          i += WARP_SIZE) {
@@ -478,33 +601,45 @@ __global__ void group_idx_and_topk_idx_kernel(
       if (i < topk) {
         s_topk_value[i] = value;
       }
-      topk_sum += reduce(tile, value, cg::plus<float>());
+      topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
     }
   }

   __syncthreads();
-  if (case_id < num_tokens) {
+
+  if (case_id < num_tokens && if_proceed_next_topk) {
     for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
       scores[i] = 0;
     }
   }
-  __threadfence();
-  __syncthreads();
+  __syncwarp();

   if (case_id < num_tokens) {
-    for (int i = lane_id; i < topk; i += WARP_SIZE) {
-      float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
-      scores[s_topk_idx[i]] = value;
-      if (if_proceed_next_topk) {
+    if (if_proceed_next_topk) {
+      for (int i = lane_id; i < topk; i += WARP_SIZE) {
+        float value;
+        if (renormalize) {
+          value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
+                  routed_scaling_factor;
+        } else {
+          value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
+        }
+        scores[s_topk_idx[i]] = value;
         topk_indices[i] = s_topk_idx[i];
-        topk_values[i] = static_cast<T>(value);
+        topk_values[i] = cuda_cast<T, float>(value);
       }
-      else {
+    } else {
+      for (int i = lane_id; i < topk; i += WARP_SIZE) {
         topk_indices[i] = i;
-        topk_values[i] = static_cast<T>(1.0f / topk);
+        topk_values[i] = cuda_cast<T, float>(1.0f / topk);
       }
     }
+    // Note: when if_proceed_next_topk==false, choose the first 8 experts as the
+    // default result.
   }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }

 template <typename T, typename IdxT>
@@ -518,17 +653,24 @@ void invokeNoAuxTc(T* scores,
                    int64_t const n_group,
                    int64_t const topk_group,
                    int64_t const topk,
+                   bool const renormalize,
                    double const routed_scaling_factor,
                    cudaStream_t const stream) {
   int64_t num_cases = num_tokens * n_group;
   int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
-  topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
-      group_scores,
-      scores_with_bias,
-      num_tokens,
-      num_cases,
-      n_group,
-      num_experts / n_group);
+  auto* kernel_instance1 = &topk_with_k2_kernel<T>;
+  cudaLaunchConfig_t config;
+  config.gridDim = topk_with_k2_num_blocks;
+  config.blockDim = BLOCK_SIZE;
+  config.dynamicSmemBytes = 0;
+  config.stream = stream;
+  cudaLaunchAttribute attrs[1];
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = false;
+  config.numAttrs = 1;
+  config.attrs = attrs;
+  cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
+                     num_tokens, num_cases, n_group, num_experts / n_group);

   int64_t topk_with_k_group_num_blocks =
       (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
@@ -536,21 +678,19 @@ void invokeNoAuxTc(T* scores,
       warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
                                                            topk);

-  group_idx_and_topk_idx_kernel<<<topk_with_k_group_num_blocks,
-                                  BLOCK_SIZE,
-                                  dynamic_smem_in_bytes,
-                                  stream>>>(scores,
-                                            group_scores,
-                                            topk_values,
-                                            topk_indices,
-                                            scores_with_bias,
-                                            num_tokens,
-                                            n_group,
-                                            topk_group,
-                                            topk,
-                                            num_experts,
-                                            num_experts / n_group,
-                                            routed_scaling_factor);
+  auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
+  config.gridDim = topk_with_k_group_num_blocks;
+  config.blockDim = BLOCK_SIZE;
+  config.dynamicSmemBytes = dynamic_smem_in_bytes;
+  config.stream = stream;
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = false;
+  config.numAttrs = 1;
+  config.attrs = attrs;
+  cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
+                     topk_values, topk_indices, scores_with_bias, num_tokens,
+                     n_group, topk_group, topk, num_experts,
+                     num_experts / n_group, renormalize, routed_scaling_factor);
 }

 #define INSTANTIATE_NOAUX_TC(T, IdxT)                                       \
@@ -564,6 +704,7 @@ void invokeNoAuxTc(T* scores,
                                        int64_t const n_group,               \
                                        int64_t const topk_group,            \
                                        int64_t const topk,                  \
+                                       bool const renormalize,              \
                                        double const routed_scaling_factor,  \
                                        cudaStream_t const stream);

diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py
index 61f3fca94..910c5dd87 100644
--- a/fastdeploy/model_executor/layers/moe/ep.py
+++ b/fastdeploy/model_executor/layers/moe/ep.py
@@ -369,6 +369,7 @@ class EPRunner:
                 layer.top_k,
                 layer.routed_scaling_factor,
                 layer.gate_correction_bias,
+                getattr(layer, "renormalize", True),
             )
         else:
             topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
index 1bce4d6b7..eaa46448c 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
@@ -39,6 +39,7 @@ elif current_platform.is_iluvatar():
         moe_expert_reduce,
     )

+from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs


@@ -226,15 +227,14 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         """
         gate_out = gate(x.cast("float32"))
         if layer.topk_method == "noaux_tc":
-            from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
-
-            gate_out, _, _ = get_moe_scores(
+            gate_out, topk_weights, topk_idx = get_moe_scores(
                 gate_out,
                 layer.n_group,
                 layer.topk_group,
                 layer.top_k,
                 layer.routed_scaling_factor,
                 layer.gate_correction_bias,
+                getattr(layer, "renormalize", True),
             )

         (
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
index 142500234..c973f1901 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -512,6 +512,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 layer.top_k,
                 layer.routed_scaling_factor,
                 layer.gate_correction_bias,
+                getattr(layer, "renormalize", True),
             )
         else:
             topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py
index 566df0351..cc09dd3ad 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py
@@ -263,6 +263,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
                 layer.top_k,
                 layer.routed_scaling_factor,
                 layer.gate_correction_bias,
+                getattr(layer, "renormalize", True),
             )

             topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index 1dc88dbcc..282658cd8 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -263,8 +263,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
                 layer.top_k,
                 layer.routed_scaling_factor,
                 layer.gate_correction_bias,
+                getattr(layer, "renormalize", True),
             )
-            topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
         else:
             topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
                 gate_out,
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index 1d44aa713..ddd7a4aea 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -66,6 +66,7 @@ def get_moe_scores(
     top_k,
     routed_scaling_factor,
     e_score_correction_bias,
+    renormalize: bool = False,
 ) -> paddle.Tensor:
     """
     compute moe scores using e_score_correction_bias.
@@ -79,6 +80,7 @@ def get_moe_scores(
         n_group if n_group > 0 else 1,
         topk_group if topk_group > 0 else 1,
         top_k,
+        renormalize,
         routed_scaling_factor,
     )
     return scores, topk_values, topk_idx
@@ -93,6 +95,7 @@ class FusedMoE(nn.Layer):
         self,
         fd_config,
         reduce_results: bool = True,
+        renormalize: bool = False,
         moe_intermediate_size: int = -1,
         num_experts: int = -1,
         expert_id_offset: int = 0,
@@ -119,6 +122,7 @@ class FusedMoE(nn.Layer):
         self.fd_config = fd_config
         self.layer_idx = layer_idx
         self.reduce_results = reduce_results
+        self.renormalize = renormalize
         self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
         self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.ep_size = fd_config.parallel_config.expert_parallel_size
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index 1c79b381f..c2045cb82 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -121,6 +121,7 @@ class DeepSeekV3MoE(nn.Layer):
         super().__init__()

         self.tp_size = fd_config.parallel_config.tensor_parallel_size
+        self.norm_topk_prob = fd_config.model_config.norm_topk_prob

         weight_key_map = {
             "gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
@@ -150,6 +151,7 @@ class DeepSeekV3MoE(nn.Layer):
         self.experts = FusedMoE(
             fd_config=fd_config,
             reduce_results=False,
+            renormalize=self.norm_topk_prob,
             moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
             num_experts=fd_config.model_config.n_routed_experts,
             top_k=fd_config.model_config.num_experts_per_tok,
diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py
index a2a4c4bda..1f7ab8827 100644
--- a/fastdeploy/model_executor/models/glm4_moe.py
+++ b/fastdeploy/model_executor/models/glm4_moe.py
@@ -110,6 +110,8 @@ class Glm4Moe(nn.Layer):
         self.n_routed_experts: int = fd_config.model_config.n_routed_experts
         self.n_shared_experts: int = fd_config.model_config.n_shared_experts

+        self.norm_topk_prob = fd_config.model_config.norm_topk_prob
+
         weight_key_map = {
             "gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
             "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
@@ -134,6 +136,7 @@ class Glm4Moe(nn.Layer):
         self.experts = FusedMoE(
             fd_config,
             reduce_results=False,
+            renormalize=self.norm_topk_prob,
             moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
             num_experts=fd_config.model_config.n_routed_experts,
             top_k=fd_config.model_config.num_experts_per_tok,
diff --git a/tests/operators/test_noaux_tc.py b/tests/operators/test_noaux_tc.py
index 06e065673..b0193ac3c 100644
--- a/tests/operators/test_noaux_tc.py
+++ b/tests/operators/test_noaux_tc.py
@@ -2,74 +2,103 @@ import unittest

 import paddle

-from fastdeploy.model_executor.ops.gpu import noaux_tc
+from fastdeploy.model_executor.layers.moe.moe import get_moe_scores


 class TestMoeRouting(unittest.TestCase):
     def setUp(self):
-        self.num_tokens = 10
-        self.num_experts = 64
-        self.gating_output = paddle.rand([self.num_tokens, self.num_experts])
-        self.e_score_correction_bias = paddle.rand([self.num_experts])
-        self.n_group = 8
-        self.topk_group = 4
-        self.top_k = 8
-        self.routed_scaling_factor = 1.5
+        paddle.seed(2024)
+        print(paddle.device.cuda.get_device_properties())
+        print(paddle.__git_commit__)

-    def node_limit_routing(self, gate_probs):
-        """将所有专家分组, 只在topk_group个group内选择专家"""
-        assert len(gate_probs.shape) == 2
-        seq_length, n_experts = gate_probs.shape
+    def native_group_topk(
+        self,
+        gating_output: paddle.Tensor,
+        topk: int,
+        renormalize: bool,
+        num_expert_group: int,
+        topk_group: int,
+        routed_scaling_factor: float,
+        e_score_correction_bias: paddle.Tensor,
+    ):
+        original_scores = paddle.nn.functional.sigmoid(gating_output)
+        if len(e_score_correction_bias.shape) == 1:
+            e_score_correction_bias = e_score_correction_bias.unsqueeze(0)
+        scores = original_scores + e_score_correction_bias

-        group_scores = gate_probs.reshape([seq_length, 8, -1]).topk(2, axis=-1)[0].sum(axis=-1)
-        group_idx = paddle.topk(group_scores, k=4, axis=-1, sorted=True)[1]
-        group_mask = paddle.zeros_like(group_scores).put_along_axis(
-            group_idx, paddle.ones([], dtype="float32"), axis=-1
+        num_token, n_experts = scores.shape
+        group_scores = scores.reshape([num_token, num_expert_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
+        group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1]  # [n, top_k_group]
+        group_mask = paddle.zeros_like(group_scores)  # [n, n_group]
+        group_mask.put_along_axis_(group_idx, 1.0, axis=-1)  # [n, n_group]
+        score_mask = (
+            group_mask.unsqueeze(-1)
+            .expand([num_token, num_expert_group, n_experts // num_expert_group])
+            .reshape([num_token, -1])
         )
-        score_mask = group_mask.unsqueeze(-1).expand([seq_length, 8, n_experts // 8]).reshape([seq_length, -1])
-        gate_probs = gate_probs.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))
-        return gate_probs
+        tmp_scores = scores.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))

-    def ref_moe_routing(self):
-        scores = paddle.nn.functional.sigmoid(self.gating_output)
-        prob_for_choice = scores + self.e_score_correction_bias.unsqueeze(0)
-        prob_for_choice = self.node_limit_routing(prob_for_choice)
-        top_logits, topk_idx_ref = paddle.topk(prob_for_choice, self.top_k, axis=1)
+        topk_ids = paddle.topk(tmp_scores, topk, axis=1)[1]
+        topk_weights = paddle.take_along_axis(original_scores, topk_ids, axis=1)

-        token_num, top_k = topk_idx_ref.shape
-        _, num_expert = prob_for_choice.shape
-        topk_idx_expanded = paddle.unsqueeze(topk_idx_ref, axis=-1)
-        indices = paddle.concat(
-            [
-                paddle.arange(token_num, dtype="int64").unsqueeze(1).tile([1, top_k]).unsqueeze(-1),
-                topk_idx_expanded,
-            ],
-            axis=-1,
-        )
-        selected_gate_probs = paddle.gather_nd(scores, indices)
+        if renormalize:
+            topk_weights = topk_weights / paddle.sum(topk_weights, axis=1, keepdim=True)

-        selected_gate_probs_sum = paddle.sum(selected_gate_probs, axis=1, keepdim=True)
-        topk_weights_ref = selected_gate_probs / selected_gate_probs_sum
-        topk_weights_ref = topk_weights_ref * self.routed_scaling_factor
-        return topk_weights_ref, topk_idx_ref
+        if routed_scaling_factor != 1.0:
+            topk_weights = topk_weights * routed_scaling_factor

-    def test_moe_select(self):
-        scores = paddle.nn.functional.sigmoid(self.gating_output)
-        scores_with_bias = scores + self.e_score_correction_bias.unsqueeze(0)
+        return topk_weights, topk_ids

-        scores, topk_values, topk_idx = noaux_tc(
-            scores,
-            scores_with_bias,
-            self.n_group,
-            self.topk_group,
-            self.top_k,
-            self.routed_scaling_factor,
-        )
+    def test_group_topk(self):

-        ref_topk_values, ref_topk_idx = self.ref_moe_routing()
+        renormalize = True

-        paddle.allclose(topk_values, ref_topk_values)
-        paddle.allclose(topk_idx.cast(int), ref_topk_idx.cast(int))
+        test_cases = [
+            # (num_experts, n_group, topk_group, top_k, routed_scaling_factor)
+            (128, 1, 1, 8, 1.0),  # glm45-air
+            (256, 8, 4, 8, 2.5),  # deepseek
+        ]
+
+        for case_tuple in test_cases:
+            num_experts, n_group, topk_group, top_k, routed_scaling_factor = case_tuple
+            for num_tokens in [1, 32, 64, 128]:
+                gating_output = paddle.rand([num_tokens, num_experts])
+                e_score_correction_bias = paddle.rand([1, num_experts])
+
+                ref_topk_values, ref_topk_idx = self.native_group_topk(
+                    gating_output=gating_output,
+                    topk=top_k,
+                    renormalize=renormalize,
+                    num_expert_group=n_group,
+                    topk_group=topk_group,
+                    routed_scaling_factor=routed_scaling_factor,
+                    e_score_correction_bias=e_score_correction_bias,
+                )
+
+                new_score, topk_values, topk_idx = get_moe_scores(
+                    gating_output=gating_output,
+                    n_group=n_group,
+                    topk_group=topk_group,
+                    top_k=top_k,
+                    routed_scaling_factor=routed_scaling_factor,
+                    e_score_correction_bias=e_score_correction_bias,
+                    renormalize=renormalize,
+                )
+
+                equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item()
+                equal_topk_ids = paddle.allclose(
+                    topk_idx.cast("int32"), ref_topk_idx.cast("int32"), atol=0.0, rtol=0.0
+                ).item()
+                print(
+                    f"Test Case[{case_tuple}], num_tokens = {num_tokens}, equal_topk_value: {equal_topk_value}, equal_topk_ids: {equal_topk_ids}"
+                )
+                if not equal_topk_value:
+                    print(f"ref_topk_values = {ref_topk_values}")
+                    print(f"topk_values = {topk_values}")
+                if not equal_topk_ids:
+                    print(f"ref_topk_idx = {ref_topk_idx}")
+                    print(f"topk_idx = {topk_idx}")
+                assert equal_topk_value and equal_topk_ids


 if __name__ == "__main__":