Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Feature] Support noaux for eplb (#5143)
* support noaux eplb
* noaux_eplb
* noaux_eplb
* noaux_eplb
@@ -647,6 +647,19 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
    bool renormalize,
    float routed_scaling_factor);

std::vector<paddle::Tensor> NoauxTcRedundant(
    paddle::Tensor& scores,
    paddle::Tensor& scores_with_bias,
    paddle::Tensor& expert_id_to_ep_rank_array,
    paddle::Tensor& expert_in_rank_num_list,
    paddle::Tensor& tokens_per_expert_stats_list,
    int n_group,
    int topk_group,
    int topk,
    bool renormalize,
    float routed_scaling_factor,
    int redundant_ep_rank_num_plus_one);

#ifdef ENABLE_FP8
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
    const paddle::Tensor& x,
@@ -1485,6 +1498,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

  m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute");

  m.def("noaux_tc_redundant",
        &NoauxTcRedundant,
        "noaux_tc_redundant for MoE compute");

#ifdef ENABLE_FP8
  m.def("cutlass_fp8_fp8_half_gemm_fused",
        &cutlass_fp8_fp8_half_gemm_func,
custom_ops/gpu_ops/noaux_tc_redundant.cu (new file, 103 lines)
@@ -0,0 +1,103 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <optional>

#include "helper.h"
#include "noauxtc_kernel.h"

std::vector<paddle::Tensor> NoauxTcRedundant(
    paddle::Tensor& scores,
    paddle::Tensor& scores_with_bias,
    paddle::Tensor& expert_id_to_ep_rank_array,
    paddle::Tensor& expert_in_rank_num_list,
    paddle::Tensor& tokens_per_expert_stats_list,
    int n_group,
    int topk_group,
    int topk,
    bool renormalize,
    float routed_scaling_factor,
    int redundant_ep_rank_num_plus_one) {
  auto input_shape = scores_with_bias.shape();
  PD_CHECK(input_shape.size() == 2);
  int64_t num_tokens = input_shape[0];
  int64_t num_experts = input_shape[1];
  auto input_type = scores_with_bias.dtype();
  auto place = scores_with_bias.place();
  auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
  auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
  auto topk_indices =
      paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place);
  auto stream = scores_with_bias.stream();

  invokeNoAuxTcRedundant<float, int64_t>(
      reinterpret_cast<float*>(scores.data<float>()),
      reinterpret_cast<float*>(group_scores.data<float>()),
      reinterpret_cast<float*>(topk_values.data<float>()),
      reinterpret_cast<int64_t*>(topk_indices.data<int64_t>()),
      reinterpret_cast<float*>(scores_with_bias.data<float>()),
      reinterpret_cast<int*>(expert_id_to_ep_rank_array.data<int>()),
      reinterpret_cast<int*>(expert_in_rank_num_list.data<int>()),
      reinterpret_cast<int*>(tokens_per_expert_stats_list.data<int>()),
      num_tokens,
      num_experts,
      n_group,
      topk_group,
      topk,
      renormalize,
      routed_scaling_factor,
      redundant_ep_rank_num_plus_one,
      stream);

  return {scores, topk_values, topk_indices};
}

std::vector<paddle::DataType> NoauxTcRedundantInferDtype(
    const paddle::DataType& scores_dtype,
    const paddle::DataType& scores_with_bias_dtype) {
  return {scores_dtype, scores_dtype, paddle::DataType::INT64};
}

std::vector<std::vector<int64_t>> NoauxTcRedundantInferShape(
    const std::vector<int64_t>& scores_shape,
    const std::vector<int64_t>&,
    const int topk) {
  auto num_tokens = scores_shape[0];
  auto topk_values_shape = std::vector<int64_t>{num_tokens, topk};
  auto topk_indices_shape = std::vector<int64_t>{num_tokens, topk};
  return {scores_shape, topk_values_shape, topk_indices_shape};
}

PD_BUILD_STATIC_OP(noaux_tc_redundant)
    .Inputs({"scores",
             "scores_with_bias",
             "expert_id_to_ep_rank_array",
             "expert_in_rank_num_list",
             "tokens_per_expert_stats_list"})
    .Outputs({"output_tensor",
              "topk_values",
              "topk_indices",
              "tokens_per_expert_stats_list_out"})
    .Attrs({"n_group: int",
            "topk_group: int",
            "topk:int",
            "renormalize: bool",
            "routed_scaling_factor: float",
            "redundant_ep_rank_num_plus_one:int"})
    .SetInplaceMap({{"tokens_per_expert_stats_list",
                     "tokens_per_expert_stats_list_out"}})
    .SetKernelFn(PD_KERNEL(NoauxTcRedundant))
    .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcRedundantInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcRedundantInferDtype));
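For orientation, here is a minimal sketch (not part of the commit) of how the op registered above can be driven from Python once it is built. It mirrors the positional call that get_moe_scores makes further down in this change, and the table shapes follow the new unit test; the identity redundancy table (one replica per expert) and the concrete sizes are placeholders.

import paddle
from fastdeploy.model_executor.ops.gpu import noaux_tc_redundant

num_tokens, num_experts, redundant_plus_one = 4, 256, 1
scores = paddle.rand([num_tokens, num_experts])            # sigmoid gate outputs
scores_with_bias = scores + paddle.rand([1, num_experts])  # plus e_score_correction_bias
# Row e lists the EP ranks that hold replicas of expert e (identity mapping here).
expert_id_to_ep_rank_array = paddle.arange(num_experts, dtype="int32").reshape([num_experts, redundant_plus_one])
expert_in_rank_num_list = paddle.ones([num_experts, 1], dtype="int32")        # replica count per expert
tokens_per_expert_stats_list = paddle.zeros([num_experts, 1], dtype="int32")  # updated in place

out_scores, topk_values, topk_indices, stats = noaux_tc_redundant(
    scores,
    scores_with_bias,
    expert_id_to_ep_rank_array,
    expert_in_rank_num_list,
    tokens_per_expert_stats_list,
    8,     # n_group
    4,     # topk_group
    8,     # topk
    True,  # renormalize
    2.5,   # routed_scaling_factor
    redundant_plus_one,
)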
@@ -420,6 +420,13 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
};  // end class WarpSelect
}  // namespace warp_topk

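// Small xorshift-style mixer (hence the xorwow_moe name); the redundant top-k
// kernel below seeds it with the token's case_id and uses it to pick one of
// the replica EP ranks of each selected expert.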
inline __device__ unsigned int xorwow_moe(unsigned int& state) {
  state ^= state >> 7;
  state ^= state << 9;
  state ^= state >> 13;
  return state;
}

template <typename T>
__device__ void topk_with_k2(T* output,
                             T const* input,
@@ -656,6 +663,195 @@ __global__ void group_idx_and_topk_idx_kernel(
#endif
}

template <typename T, typename IdxT>
__global__ void group_idx_and_topk_idx_redundant_kernel(
    T* scores,
    T const* group_scores,
    T* topk_values,
    IdxT* topk_indices,
    T* scores_with_bias,
    int32_t* expert_id_to_ep_rank_array,
    int32_t* expert_in_rank_num_list,
    int32_t* tokens_per_expert_stats_list,
    int64_t const num_tokens,
    int64_t const n_group,
    int64_t const topk_group,
    int64_t const topk,
    bool const renormalize,
    int64_t const num_experts,
    int64_t const num_experts_per_group,
    double routed_scaling_factor,
    int64_t const redundant_ep_rank_num_plus_one) {
  int32_t warp_id = threadIdx.x / WARP_SIZE;
  int32_t lane_id = threadIdx.x % WARP_SIZE;
  int32_t case_id =
      blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;  // one per token
  unsigned int state = case_id;
  scores_with_bias += case_id * num_experts;
  scores += case_id * num_experts;
  group_scores += case_id * n_group;
  topk_values += case_id * topk;
  topk_indices += case_id * topk;
  int32_t align_num_experts_per_group =
      warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);

  cg::thread_block block = cg::this_thread_block();
  cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

  extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to
                                      // store the target topk idx
  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
  T* s_topk_value =
      reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
      warp_id * topk;
  s_topk_idx += warp_id * topk;

  T value = neg_inf<T>();
  T topk_group_value = neg_inf<T>();
  int32_t num_equalto_topkth_group;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");  // I think all prolog can be put before
                                         // acqbulk because it's ptr arithmetic
#endif

  if (case_id < num_tokens) {
    // calculate group_idx
    int32_t target_num_min = WARP_SIZE - n_group + topk_group;
    if (lane_id < n_group &&
        (isfinite(cuda_cast<float, T>(
            group_scores[lane_id]))))  // The check is necessary to avoid
                                       // abnormal input
    {
      value = group_scores[lane_id];
    }

    int count_equal_to_top_value = WARP_SIZE - n_group;
    int pre_count_equal_to_top_value = 0;
    // Use a loop to find the largest topk_group value
    while (count_equal_to_top_value < target_num_min) {
      __syncwarp();  // Ensure all threads have valid data before reduction
      topk_group_value = cg::reduce(tile, value, cg::greater<T>());
      if (value == topk_group_value) {
        value = neg_inf<T>();
      }
      pre_count_equal_to_top_value = count_equal_to_top_value;
      count_equal_to_top_value =
          __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
    }
    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
  }
  __syncthreads();

  warp_topk::WarpSelect</*capacity*/ WARP_SIZE,
                        /*greater*/ true,
                        T,
                        int32_t,
                        /* is_stable */ true>
      queue((int32_t)topk, neg_inf<T>());

  int count_equalto_topkth_group = 0;
  bool if_proceed_next_topk = (topk_group_value != neg_inf<T>());
  if (case_id < num_tokens && if_proceed_next_topk) {
    for (int i_group = 0; i_group < n_group; i_group++) {
      if ((group_scores[i_group] > topk_group_value) ||
          ((group_scores[i_group] == topk_group_value) &&
           (count_equalto_topkth_group < num_equalto_topkth_group))) {
        int32_t offset = i_group * num_experts_per_group;
        for (int32_t i = lane_id; i < align_num_experts_per_group;
             i += WARP_SIZE) {
          T candidates =
              (i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
                                                 scores_with_bias[offset + i]))
                  ? scores_with_bias[offset + i]
                  : neg_inf<T>();
          queue.add(candidates, offset + i);
        }
        if (group_scores[i_group] == topk_group_value) {
          count_equalto_topkth_group++;
        }
      }
    }
    queue.done();
    __syncwarp();
    // Get the topk_idx
    queue.dumpIdx(s_topk_idx);
    __syncwarp();
  }

  // Load the valid score value
  // Calculate the summation
  float topk_sum = 1e-20;
  if (case_id < num_tokens && if_proceed_next_topk) {
    for (int i = lane_id;
         i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
         i += WARP_SIZE) {
      T value = i < topk ? scores[s_topk_idx[i]]
                         : 0.0f;  // Load the valid value of expert
      if (i < topk) {
        s_topk_value[i] = value;
      }
      topk_sum +=
          cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
    }
  }

  __syncthreads();
  // Note(ZKK): a little trick.
  if (case_id < num_tokens && if_proceed_next_topk) {
    for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
      scores[i] = 0;
    }
  }
  __syncwarp();

  if (case_id < num_tokens) {
    if (if_proceed_next_topk) {
      for (int i = lane_id; i < topk; i += WARP_SIZE) {
        float value;
        if (renormalize) {
          value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
                  routed_scaling_factor;
        } else {
          value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
        }
        scores[s_topk_idx[i]] = value;

        int expert_topk = s_topk_idx[i];
        int len = expert_in_rank_num_list[expert_topk];
        int select = (int)xorwow_moe(state) % len;
        // int select = 0;
        int selected_rank =
            expert_id_to_ep_rank_array[expert_topk *
                                           redundant_ep_rank_num_plus_one +
                                       select];
        atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1);
        topk_indices[i] = (IdxT)selected_rank;
        topk_values[i] = cuda_cast<T, float>(value);
      }
    } else {
      for (int i = lane_id; i < topk; i += WARP_SIZE) {
        int expert_topk = i;
        int len = expert_in_rank_num_list[expert_topk];
        int select = (int)xorwow_moe(state) % len;
        // int select = 0;
        int selected_rank =
            expert_id_to_ep_rank_array[expert_topk *
                                           redundant_ep_rank_num_plus_one +
                                       select];
        atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1);
        topk_indices[i] = (IdxT)selected_rank;
        topk_values[i] = cuda_cast<T, float>(1.0f / topk);
      }
    }
    // Note: when if_proceed_next_topk==false, choose the first topk experts
    // as the default result.
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores,
                   T* group_scores,
@@ -752,6 +948,111 @@ void invokeNoAuxTc(T* scores,
#endif
}

template <typename T, typename IdxT>
void invokeNoAuxTcRedundant(T* scores,
                            T* group_scores,
                            T* topk_values,
                            IdxT* topk_indices,
                            T* scores_with_bias,
                            int32_t* expert_id_to_ep_rank_array,
                            int32_t* expert_in_rank_num_list,
                            int32_t* tokens_per_expert_stats_list,
                            int64_t const num_tokens,
                            int64_t const num_experts,
                            int64_t const n_group,
                            int64_t const topk_group,
                            int64_t const topk,
                            bool const renormalize,
                            double const routed_scaling_factor,
                            int64_t const redundant_ep_rank_num_plus_one,
                            cudaStream_t const stream) {
  int64_t num_cases = num_tokens * n_group;
  int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;

#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
  topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
      group_scores,
      scores_with_bias,
      num_cases,
      n_group,
      num_experts / n_group);
#else
  auto* kernel_instance1 = &topk_with_k2_kernel<T>;
  cudaLaunchConfig_t config;
  config.gridDim = topk_with_k2_num_blocks;
  config.blockDim = BLOCK_SIZE;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = false;
  config.numAttrs = 1;
  config.attrs = attrs;
  cudaLaunchKernelEx(&config,
                     kernel_instance1,
                     group_scores,
                     scores_with_bias,
                     num_cases,
                     n_group,
                     num_experts / n_group);
#endif

  int64_t topk_with_k_group_num_blocks =
      (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
  size_t dynamic_smem_in_bytes =
      warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
                                                           topk);

#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
  group_idx_and_topk_idx_redundant_kernel<T>
      <<<topk_with_k_group_num_blocks,
         BLOCK_SIZE,
         dynamic_smem_in_bytes,
         stream>>>(scores,
                   group_scores,
                   topk_values,
                   topk_indices,
                   scores_with_bias,
                   expert_id_to_ep_rank_array,
                   expert_in_rank_num_list,
                   tokens_per_expert_stats_list,
                   num_tokens,
                   n_group,
                   topk_group,
                   topk,
                   renormalize,
                   num_experts,
                   num_experts / n_group,
                   routed_scaling_factor,
                   redundant_ep_rank_num_plus_one);
#else
  auto* kernel_instance2 = &group_idx_and_topk_idx_redundant_kernel<T, IdxT>;
  config.gridDim = topk_with_k_group_num_blocks;
  config.blockDim = BLOCK_SIZE;
  config.dynamicSmemBytes = dynamic_smem_in_bytes;
  config.stream = stream;
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = false;
  config.numAttrs = 1;
  config.attrs = attrs;
  cudaLaunchKernelEx(&config,
                     kernel_instance2,
                     scores,
                     group_scores,
                     topk_values,
                     topk_indices,
                     scores_with_bias,
                     expert_id_to_ep_rank_array,
                     expert_in_rank_num_list,
                     tokens_per_expert_stats_list,
                     num_tokens,
                     n_group,
                     topk_group,
                     topk,
                     renormalize,
                     num_experts,
                     num_experts / n_group,
                     routed_scaling_factor,
                     redundant_ep_rank_num_plus_one);
#endif
}

#define INSTANTIATE_NOAUX_TC(T, IdxT)                      \
  template void invokeNoAuxTc<T, IdxT>(T * scores,         \
                                       T * group_scores,   \
@@ -768,3 +1069,25 @@ void invokeNoAuxTc(T* scores,
                                       cudaStream_t const stream);

INSTANTIATE_NOAUX_TC(float, int32_t);

#define INSTANTIATE_NOAUX_TC_Redundant(T, IdxT)          \
  template void invokeNoAuxTcRedundant<T, IdxT>(         \
      T * scores,                                        \
      T * group_scores,                                  \
      T * topk_values,                                   \
      IdxT * topk_indices,                               \
      T * scores_with_bias,                              \
      int32_t * expert_id_to_ep_rank_array,              \
      int32_t * expert_in_rank_num_list,                 \
      int32_t * tokens_per_expert_stats_list,            \
      int64_t const num_tokens,                          \
      int64_t const num_experts,                         \
      int64_t const n_group,                             \
      int64_t const topk_group,                          \
      int64_t const topk,                                \
      bool const renormalize,                            \
      double const routed_scaling_factor,                \
      int64_t const redundant_ep_rank_num_plus_one,      \
      cudaStream_t const stream);

INSTANTIATE_NOAUX_TC_Redundant(float, int32_t);
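As a cross-check, the following is a plain NumPy sketch (mine, not from the repository) of what invokeNoAuxTcRedundant computes per token: group scores are the sum of the two largest biased scores in each group, experts are taken from the topk_group best groups by biased score, the routing weights come from the unbiased scores (optionally renormalized, then scaled), and each chosen expert id is finally remapped to one of its replica EP ranks while the per-expert token counter is bumped. Tie handling and the random replica pick of the CUDA kernel are simplified here.

import numpy as np

def noaux_tc_redundant_ref(scores, scores_with_bias, expert_id_to_ep_rank_array,
                           expert_in_rank_num_list, tokens_per_expert_stats_list,
                           n_group, topk_group, topk, renormalize, routed_scaling_factor):
    num_tokens, num_experts = scores.shape
    per_group = num_experts // n_group
    topk_values = np.zeros((num_tokens, topk), dtype=np.float32)
    topk_indices = np.zeros((num_tokens, topk), dtype=np.int64)
    for t in range(num_tokens):
        grouped = scores_with_bias[t].reshape(n_group, per_group)
        group_scores = np.sort(grouped, axis=1)[:, -2:].sum(axis=1)   # top-2 sum per group
        keep = np.argsort(-group_scores)[:topk_group]                 # best topk_group groups
        masked = np.full(num_experts, -np.inf, dtype=np.float32)
        for g in keep:
            masked[g * per_group:(g + 1) * per_group] = scores_with_bias[t, g * per_group:(g + 1) * per_group]
        experts = np.argsort(-masked)[:topk]                          # top-k experts by biased score
        weights = scores[t, experts].astype(np.float32)               # weights use the unbiased scores
        if renormalize:
            weights = weights / weights.sum()
        weights = weights * routed_scaling_factor
        for j, e in enumerate(experts):
            # The kernel picks xorwow_moe(state) % expert_in_rank_num_list[e]; use replica 0 here.
            rank = expert_id_to_ep_rank_array[e, 0]
            tokens_per_expert_stats_list[e] += 1
            topk_indices[t, j] = rank
            topk_values[t, j] = weights[j]
    return topk_values, topk_indices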
@@ -301,6 +301,7 @@ elif paddle.is_compiled_with_cuda():
        "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
        "gpu_ops/fused_rotary_position_encoding.cu",
        "gpu_ops/noaux_tc.cu",
        "gpu_ops/noaux_tc_redundant.cu",
        "gpu_ops/custom_all_reduce/all_reduce.cu",
        "gpu_ops/merge_prefill_decode_output.cu",
        "gpu_ops/limit_thinking_content_length_v1.cu",
@@ -614,6 +615,7 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
        "gpu_ops/share_external_data.cu",
        "gpu_ops/recover_decode_task.cu",
        "gpu_ops/noaux_tc.cu",
        "gpu_ops/noaux_tc_redundant.cu",
        "gpu_ops/fused_rotary_position_encoding.cu",
        "gpu_ops/text_image_gather_scatter.cu",
        "gpu_ops/text_image_index_out.cu",
@@ -431,17 +431,34 @@ class EPRunner:
                tokens_per_expert_stats_list,
            ) = layer.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(layer.layer_idx)

            topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
                gating_logits=gate_out,
                expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
                expert_in_rank_num_list=expert_in_rank_num_list,
                tokens_per_expert_stats_list=tokens_per_expert_stats_list,
                bias=layer.gate_correction_bias,
                moe_topk=self.top_k,
                apply_norm_weight=True,
                enable_softmax_top_k_fused=False,
                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
            )
            if layer.topk_method == "noaux_tc":
                from .moe import get_moe_scores

                score, topk_weights, topk_idx = get_moe_scores(
                    gate_out,
                    layer.n_group,
                    layer.topk_group,
                    layer.top_k,
                    layer.routed_scaling_factor,
                    layer.gate_correction_bias,
                    getattr(layer, "renormalize", True),
                    expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
                    expert_in_rank_num_list=expert_in_rank_num_list,
                    tokens_per_expert_stats_list=tokens_per_expert_stats_list,
                    redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
                )
            else:
                topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
                    gating_logits=gate_out,
                    expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
                    expert_in_rank_num_list=expert_in_rank_num_list,
                    tokens_per_expert_stats_list=tokens_per_expert_stats_list,
                    bias=layer.gate_correction_bias,
                    moe_topk=self.top_k,
                    apply_norm_weight=True,
                    enable_softmax_top_k_fused=False,
                    redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
                )
        else:
            if layer.topk_method == "noaux_tc":
                from fastdeploy.model_executor.layers.moe.moe import get_moe_scores

@@ -27,7 +27,7 @@ from fastdeploy.platforms import current_platform
from fastdeploy.worker.experts_manager import RedundantExpertManger

try:
    from fastdeploy.model_executor.ops.gpu import noaux_tc
    from fastdeploy.model_executor.ops.gpu import noaux_tc, noaux_tc_redundant
except:
    logger.warning("import noaux_tc Failed!")
    import numpy as np
@@ -74,6 +74,10 @@ def get_moe_scores(
    routed_scaling_factor,
    e_score_correction_bias,
    renormalize: bool = False,
    expert_id_to_ep_rank_array: paddle.Tensor = None,
    expert_in_rank_num_list: paddle.Tensor = None,
    tokens_per_expert_stats_list: paddle.Tensor = None,
    redundant_ep_rank_num_plus_one: int = 1,
) -> paddle.Tensor:
    """
    compute moe scores using e_score_correction_bias.
@@ -81,15 +85,30 @@ def get_moe_scores(
    scores = paddle.nn.functional.sigmoid(gating_output)
    assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
    scores_with_bias = scores + e_score_correction_bias
    scores, topk_values, topk_idx = noaux_tc(
        scores,
        scores_with_bias,
        n_group if n_group > 0 else 1,
        topk_group if topk_group > 0 else 1,
        top_k,
        renormalize,
        routed_scaling_factor,
    )
    if expert_id_to_ep_rank_array is None:
        scores, topk_values, topk_idx = noaux_tc(
            scores,
            scores_with_bias,
            n_group if n_group > 0 else 1,
            topk_group if topk_group > 0 else 1,
            top_k,
            renormalize,
            routed_scaling_factor,
        )
    else:
        scores, topk_values, topk_idx, _ = noaux_tc_redundant(
            scores,
            scores_with_bias,
            expert_id_to_ep_rank_array,
            expert_in_rank_num_list,
            tokens_per_expert_stats_list,
            n_group if n_group > 0 else 1,
            topk_group if topk_group > 0 else 1,
            top_k,
            renormalize,
            routed_scaling_factor,
            redundant_ep_rank_num_plus_one,
        )
    return scores, topk_values, topk_idx

@@ -196,6 +215,7 @@ class FusedMoE(nn.Layer):
        self.quant_method = get_moe_method()
        assert self.quant_method is not None, "self.quant_method should not be None"
        self.redundant_table_manger = redundant_table_manger
        self.is_rearrange = False
        if self.ep_size > 1:
            self.quant_method.init_ep(self)

@@ -438,7 +458,7 @@ class FusedMoE(nn.Layer):
            )
        ]
        ep_rank_to_expert_id_list = [i for i in range(self.num_experts)]
        if self.redundant_table_manger is not None:
        if self.redundant_table_manger is not None and is_rearrange is True:
            (
                ep_rank_to_expert_id_list,
                expert_id_to_ep_rank_array,
@@ -211,7 +211,7 @@ class Ernie4_5_MoE(nn.Layer):
        self.shared_experts.load_state_dict(state_dict)

    def update_state_dict(self, state_dict):
        self.fused_moe.load_state_dict(state_dict, True)
        self.experts.load_state_dict(state_dict, True)

    def forward(self, hidden_states: paddle.Tensor):
        out = self.experts(hidden_states, self.gate)

tests/operators/test_noaux_tc_redundant.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import unittest

import paddle

from fastdeploy.model_executor.layers.moe.moe import get_moe_scores


class TestMoeRouting(unittest.TestCase):
    def setUp(self):
        paddle.seed(2024)
        print(paddle.device.cuda.get_device_properties())
        print(paddle.__git_commit__)

    def native_group_topk(
        self,
        gating_output: paddle.Tensor,
        topk: int,
        renormalize: bool,
        num_expert_group: int,
        topk_group: int,
        routed_scaling_factor: float,
        e_score_correction_bias: paddle.Tensor,
    ):
        original_scores = paddle.nn.functional.sigmoid(gating_output)
        if len(e_score_correction_bias.shape) == 1:
            e_score_correction_bias = e_score_correction_bias.unsqueeze(0)
        scores = original_scores + e_score_correction_bias

        num_token, n_experts = scores.shape
        group_scores = scores.reshape([num_token, num_expert_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
        group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1]  # [n, top_k_group]
        group_mask = paddle.zeros_like(group_scores)  # [n, n_group]
        group_mask.put_along_axis_(group_idx, 1.0, axis=-1)  # [n, n_group]
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand([num_token, num_expert_group, n_experts // num_expert_group])
            .reshape([num_token, -1])
        )
        tmp_scores = scores.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))

        topk_ids = paddle.topk(tmp_scores, topk, axis=1)[1]
        topk_weights = paddle.take_along_axis(original_scores, topk_ids, axis=1)

        if renormalize:
            topk_weights = topk_weights / paddle.sum(topk_weights, axis=1, keepdim=True)

        if routed_scaling_factor != 1.0:
            topk_weights = topk_weights * routed_scaling_factor

        return topk_weights, topk_ids

    def test_group_topk(self):

        renormalize = True

        test_cases = [
            # (num_experts, n_group, topk_group, top_k, routed_scaling_factor)
            (128, 1, 1, 8, 1.0),  # glm45-air
            (256, 8, 4, 8, 2.5),  # deepseek
        ]

        for case_tuple in test_cases:
            num_experts, n_group, topk_group, top_k, routed_scaling_factor = case_tuple
            for num_tokens in [1, 32, 64, 128]:
                gating_output = paddle.rand([num_tokens, num_experts])
                e_score_correction_bias = paddle.rand([1, num_experts])
                expert_id_to_ep_rank_array = paddle.arange(num_experts, dtype="int32").reshape([num_experts, 1])
                expert_in_rank_num_list = paddle.ones([num_experts, 1], dtype="int32")
                tokens_per_expert_stats_list = paddle.arange(num_experts, dtype="int32").reshape([num_experts, 1])

                ref_topk_values, ref_topk_idx = self.native_group_topk(
                    gating_output=gating_output,
                    topk=top_k,
                    renormalize=renormalize,
                    num_expert_group=n_group,
                    topk_group=topk_group,
                    routed_scaling_factor=routed_scaling_factor,
                    e_score_correction_bias=e_score_correction_bias,
                )

                new_score, topk_values, topk_idx = get_moe_scores(
                    gating_output=gating_output,
                    n_group=n_group,
                    topk_group=topk_group,
                    top_k=top_k,
                    routed_scaling_factor=routed_scaling_factor,
                    e_score_correction_bias=e_score_correction_bias,
                    renormalize=renormalize,
                    expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
                    expert_in_rank_num_list=expert_in_rank_num_list,
                    tokens_per_expert_stats_list=tokens_per_expert_stats_list,
                )

                equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item()
                equal_topk_ids = paddle.allclose(
                    topk_idx.cast("int32"), ref_topk_idx.cast("int32"), atol=0.0, rtol=0.0
                ).item()
                print(
                    f"Test Case[{case_tuple}], num_tokens = {num_tokens}, equal_topk_value: {equal_topk_value}, equal_topk_ids: {equal_topk_ids}"
                )
                if not equal_topk_value:
                    print(f"ref_topk_values = {ref_topk_values}")
                    print(f"topk_values = {topk_values}")
                if not equal_topk_ids:
                    print(f"ref_topk_idx = {ref_topk_idx}")
                    print(f"topk_idx = {topk_idx}")
                assert equal_topk_value and equal_topk_ids


if __name__ == "__main__":
    unittest.main()
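Because the file ends with unittest.main(), the new case can presumably be run directly on a CUDA build of FastDeploy with python tests/operators/test_noaux_tc_redundant.py; with the identity redundancy table it constructs (one replica per expert), the redundant path is expected to reproduce the plain noaux_tc routing results.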