[Feature] Support noaux for eplb (#5143)
* support noaux eplb
@@ -420,6 +420,13 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
};  // end class WarpSelect

}  // namespace warp_topk

// xorshift-style PRNG step: cheap per-thread randomness, used below to pick
// one of the redundant replicas of a selected expert.
inline __device__ unsigned int xorwow_moe(unsigned int& state) {
  state ^= state >> 7;
  state ^= state << 9;
  state ^= state >> 13;
  return state;
}
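
// Illustrative note (not part of this commit): despite its name, the update
// above is a plain xorshift step (Marsaglia's xorwow would also advance a
// Weyl counter). That is sufficient here, where it only has to spread tokens
// across expert replicas rather than be statistically strong. Hypothetical
// usage, assuming an expert with len = 3 replicas:
//   unsigned int state = case_id;               // per-token seed
//   int select = (int)(xorwow_moe(state) % 3);  // replica index in [0, 3)
// One caveat: a zero seed is a fixed point of xorshift, so the token with
// case_id == 0 always selects replica 0.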

template <typename T>
__device__ void topk_with_k2(T* output,
                             T const* input,
@@ -656,6 +663,195 @@ __global__ void group_idx_and_topk_idx_kernel(
#endif
}

template <typename T, typename IdxT>
__global__ void group_idx_and_topk_idx_redundant_kernel(
    T* scores,
    T const* group_scores,
    T* topk_values,
    IdxT* topk_indices,
    T* scores_with_bias,
    int32_t* expert_id_to_ep_rank_array,
    int32_t* expert_in_rank_num_list,
    int32_t* tokens_per_expert_stats_list,
    int64_t const num_tokens,
    int64_t const n_group,
    int64_t const topk_group,
    int64_t const topk,
    bool const renormalize,
    int64_t const num_experts,
    int64_t const num_experts_per_group,
    double routed_scaling_factor,
    int64_t const redundant_ep_rank_num_plus_one) {
  int32_t warp_id = threadIdx.x / WARP_SIZE;
  int32_t lane_id = threadIdx.x % WARP_SIZE;
  int32_t case_id =
      blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;  // one warp per token
  unsigned int state = case_id;  // per-token seed for replica selection
  scores_with_bias += case_id * num_experts;
  scores += case_id * num_experts;
  group_scores += case_id * n_group;
  topk_values += case_id * topk;
  topk_indices += case_id * topk;
  int32_t align_num_experts_per_group =
      warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);

  cg::thread_block block = cg::this_thread_block();
  cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

  // NOTE: reuse the shared memory here to store the target topk idx:
  // NUM_WARPS_PER_BLOCK * topk int32 indices, followed by the same number of
  // values of type T.
  extern __shared__ char smem_buf[];
  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
  T* s_topk_value =
      reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
      warp_id * topk;
  s_topk_idx += warp_id * topk;

  T value = neg_inf<T>();
  T topk_group_value = neg_inf<T>();
  int32_t num_equalto_topkth_group;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // The prolog above can safely run before this dependency wait: it is only
  // pointer arithmetic.
  asm volatile("griddepcontrol.wait;");
#endif

  if (case_id < num_tokens) {
    // calculate group_idx
    int32_t target_num_min = WARP_SIZE - n_group + topk_group;
    // The finiteness check is necessary to avoid abnormal input
    if (lane_id < n_group &&
        (isfinite(cuda_cast<float, T>(group_scores[lane_id])))) {
      value = group_scores[lane_id];
    }

    int count_equal_to_top_value = WARP_SIZE - n_group;
    int pre_count_equal_to_top_value = 0;
    // Use a loop to find the largest topk_group group scores
    while (count_equal_to_top_value < target_num_min) {
      __syncwarp();  // Ensure all threads have valid data before reduction
      topk_group_value = cg::reduce(tile, value, cg::greater<T>());
      if (value == topk_group_value) {
        value = neg_inf<T>();
      }
      pre_count_equal_to_top_value = count_equal_to_top_value;
      count_equal_to_top_value =
          __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
    }
    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
  }
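  // How the loop above converges: lanes >= n_group start at -inf, so the
  // knocked-out count begins at WARP_SIZE - n_group. Each iteration finds the
  // current maximum group score, resets every lane holding it to -inf, and
  // recounts via warp ballot; the loop exits once at least topk_group group
  // lanes have been knocked out. topk_group_value is then the score of the
  // topk_group-th best group, and num_equalto_topkth_group is how many groups
  // tied at that threshold still fit into the topk_group budget, which is
  // used below to break ties.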
  __syncthreads();

  warp_topk::WarpSelect</*capacity*/ WARP_SIZE,
                        /*greater*/ true,
                        T,
                        int32_t,
                        /*is_stable*/ true>
      queue((int32_t)topk, neg_inf<T>());

  int count_equalto_topkth_group = 0;
  bool if_proceed_next_topk = (topk_group_value != neg_inf<T>());
  if (case_id < num_tokens && if_proceed_next_topk) {
    for (int i_group = 0; i_group < n_group; i_group++) {
      if ((group_scores[i_group] > topk_group_value) ||
          ((group_scores[i_group] == topk_group_value) &&
           (count_equalto_topkth_group < num_equalto_topkth_group))) {
        int32_t offset = i_group * num_experts_per_group;
        for (int32_t i = lane_id; i < align_num_experts_per_group;
             i += WARP_SIZE) {
          T candidates =
              (i < num_experts_per_group) &&
                      isfinite(
                          cuda_cast<float, T>(scores_with_bias[offset + i]))
                  ? scores_with_bias[offset + i]
                  : neg_inf<T>();
          queue.add(candidates, offset + i);
        }
        if (group_scores[i_group] == topk_group_value) {
          count_equalto_topkth_group++;
        }
      }
    }
    queue.done();
    __syncwarp();
    // Get the topk_idx
    queue.dumpIdx(s_topk_idx);
    __syncwarp();
  }

  // Load the valid score values and accumulate their sum
  float topk_sum = 1e-20f;  // epsilon guards the division below
  if (case_id < num_tokens && if_proceed_next_topk) {
    for (int i = lane_id;
         i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
         i += WARP_SIZE) {
      T value = i < topk ? scores[s_topk_idx[i]]
                         : cuda_cast<T, float>(0.0f);  // valid expert score
      if (i < topk) {
        s_topk_value[i] = value;
      }
      topk_sum +=
          cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
    }
  }

  __syncthreads();
  // Note(ZKK): a little trick: zero the whole score row first, then scatter
  // only the selected experts' weights back below.
  if (case_id < num_tokens && if_proceed_next_topk) {
    for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
      scores[i] = 0;
    }
  }
  __syncwarp();

  if (case_id < num_tokens) {
    if (if_proceed_next_topk) {
      for (int i = lane_id; i < topk; i += WARP_SIZE) {
        float value;
        if (renormalize) {
          value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
                  routed_scaling_factor;
        } else {
          value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
        }
        scores[s_topk_idx[i]] = value;

        // Map the logical expert id to one of its redundant EP ranks, chosen
        // pseudo-randomly to spread load across replicas. The modulo is taken
        // in unsigned arithmetic so the replica index is never negative.
        int expert_topk = s_topk_idx[i];
        int len = expert_in_rank_num_list[expert_topk];
        int select = (int)(xorwow_moe(state) % len);
        int selected_rank =
            expert_id_to_ep_rank_array[expert_topk *
                                           redundant_ep_rank_num_plus_one +
                                       select];
        atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1);
        topk_indices[i] = (IdxT)selected_rank;
        topk_values[i] = cuda_cast<T, float>(value);
      }
    } else {
      // When if_proceed_next_topk == false, choose the first `topk` experts
      // with uniform weights as the default result.
      for (int i = lane_id; i < topk; i += WARP_SIZE) {
        int expert_topk = i;
        int len = expert_in_rank_num_list[expert_topk];
        int select = (int)(xorwow_moe(state) % len);
        int selected_rank =
            expert_id_to_ep_rank_array[expert_topk *
                                           redundant_ep_rank_num_plus_one +
                                       select];
        atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1);
        topk_indices[i] = (IdxT)selected_rank;
        topk_values[i] = cuda_cast<T, float>(1.0f / topk);
      }
    }
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
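
// Illustrative host-side sketch (not part of this commit; assumes <vector>
// and <cstdint>): one way the redundancy tables indexed above could be built.
// Row e of expert_id_to_ep_rank_array lists the EP ranks serving replicas of
// logical expert e, padded to a fixed row stride of
// redundant_ep_rank_num_plus_one; expert_in_rank_num_list[e] holds the number
// of valid entries in that row. The helper name is hypothetical.
inline void build_expert_rank_table_host(
    const std::vector<std::vector<int32_t>>& replicas,  // [expert] -> ranks
    int64_t redundant_ep_rank_num_plus_one,
    std::vector<int32_t>& expert_id_to_ep_rank_array,
    std::vector<int32_t>& expert_in_rank_num_list) {
  const size_t num_experts = replicas.size();
  expert_id_to_ep_rank_array.assign(
      num_experts * redundant_ep_rank_num_plus_one, -1);  // -1 = unused slot
  expert_in_rank_num_list.resize(num_experts);
  for (size_t e = 0; e < num_experts; ++e) {
    expert_in_rank_num_list[e] = (int32_t)replicas[e].size();
    for (size_t r = 0; r < replicas[e].size(); ++r) {
      expert_id_to_ep_rank_array[e * redundant_ep_rank_num_plus_one + r] =
          replicas[e][r];
    }
  }
}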

template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores,
                   T* group_scores,
@@ -752,6 +948,111 @@ void invokeNoAuxTc(T* scores,
#endif
}

template <typename T, typename IdxT>
void invokeNoAuxTcRedundant(T* scores,
                            T* group_scores,
                            T* topk_values,
                            IdxT* topk_indices,
                            T* scores_with_bias,
                            int32_t* expert_id_to_ep_rank_array,
                            int32_t* expert_in_rank_num_list,
                            int32_t* tokens_per_expert_stats_list,
                            int64_t const num_tokens,
                            int64_t const num_experts,
                            int64_t const n_group,
                            int64_t const topk_group,
                            int64_t const topk,
                            bool const renormalize,
                            double const routed_scaling_factor,
                            int64_t const redundant_ep_rank_num_plus_one,
                            cudaStream_t const stream) {
  int64_t num_cases = num_tokens * n_group;
  // One warp per (token, group) case.
  int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;

#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
  topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
      group_scores,
      scores_with_bias,
      num_cases,
      n_group,
      num_experts / n_group);
#else
  auto* kernel_instance1 = &topk_with_k2_kernel<T>;
  cudaLaunchConfig_t config;
  config.gridDim = topk_with_k2_num_blocks;
  config.blockDim = BLOCK_SIZE;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = false;
  config.numAttrs = 1;
  config.attrs = attrs;
  cudaLaunchKernelEx(&config,
                     kernel_instance1,
                     group_scores,
                     scores_with_bias,
                     num_cases,
                     n_group,
                     num_experts / n_group);
#endif
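  // Note: the launch attribute pairs with the griddepcontrol.wait /
  // griddepcontrol.launch_dependents PTX in the kernels (programmatic
  // dependent launch on SM90+). With programmaticStreamSerializationAllowed
  // set to false the launches serialize on the stream as usual, so the
  // griddepcontrol instructions should be effectively inert; setting it to
  // true would let a dependent kernel's prolog overlap the previous kernel's
  // epilog.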

  int64_t topk_with_k_group_num_blocks =
      (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
  size_t dynamic_smem_in_bytes =
      warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
                                                           topk);

#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
  group_idx_and_topk_idx_redundant_kernel<T>
      <<<topk_with_k_group_num_blocks,
         BLOCK_SIZE,
         dynamic_smem_in_bytes,
         stream>>>(scores,
                   group_scores,
                   topk_values,
                   topk_indices,
                   scores_with_bias,
                   expert_id_to_ep_rank_array,
                   expert_in_rank_num_list,
                   tokens_per_expert_stats_list,
                   num_tokens,
                   n_group,
                   topk_group,
                   topk,
                   renormalize,
                   num_experts,
                   num_experts / n_group,
                   routed_scaling_factor,
                   redundant_ep_rank_num_plus_one);
#else
  auto* kernel_instance2 = &group_idx_and_topk_idx_redundant_kernel<T, IdxT>;
  config.gridDim = topk_with_k_group_num_blocks;
  config.blockDim = BLOCK_SIZE;
  config.dynamicSmemBytes = dynamic_smem_in_bytes;
  config.stream = stream;
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = false;
  config.numAttrs = 1;
  config.attrs = attrs;
  cudaLaunchKernelEx(&config,
                     kernel_instance2,
                     scores,
                     group_scores,
                     topk_values,
                     topk_indices,
                     scores_with_bias,
                     expert_id_to_ep_rank_array,
                     expert_in_rank_num_list,
                     tokens_per_expert_stats_list,
                     num_tokens,
                     n_group,
                     topk_group,
                     topk,
                     renormalize,
                     num_experts,
                     num_experts / n_group,
                     routed_scaling_factor,
                     redundant_ep_rank_num_plus_one);
#endif
}
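
// Usage sketch (hypothetical shapes, not from this commit): routing 128
// tokens over 256 experts arranged in 8 groups, keeping 4 groups and 8
// experts per token, with at most 2 redundant replicas per expert (hence
// redundant_ep_rank_num_plus_one = 3). All d_* pointers are assumed device
// buffers of the annotated shapes.
//
//   invokeNoAuxTcRedundant<float, int32_t>(
//       d_scores,                        // [128, 256]
//       d_group_scores,                  // [128, 8]
//       d_topk_values,                   // [128, 8]
//       d_topk_indices,                  // [128, 8], EP ranks, not expert ids
//       d_scores_with_bias,              // [128, 256]
//       d_expert_id_to_ep_rank_array,    // [256, 3]
//       d_expert_in_rank_num_list,       // [256]
//       d_tokens_per_expert_stats_list,  // [256]
//       /*num_tokens=*/128,
//       /*num_experts=*/256,
//       /*n_group=*/8,
//       /*topk_group=*/4,
//       /*topk=*/8,
//       /*renormalize=*/true,
//       /*routed_scaling_factor=*/2.5,
//       /*redundant_ep_rank_num_plus_one=*/3,
//       stream);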

#define INSTANTIATE_NOAUX_TC(T, IdxT)                      \
  template void invokeNoAuxTc<T, IdxT>(T * scores,         \
                                       T * group_scores,   \
@@ -768,3 +1069,25 @@ void invokeNoAuxTc(T* scores,
                                       cudaStream_t const stream);

INSTANTIATE_NOAUX_TC(float, int32_t);

#define INSTANTIATE_NOAUX_TC_Redundant(T, IdxT)        \
  template void invokeNoAuxTcRedundant<T, IdxT>(       \
      T * scores,                                      \
      T * group_scores,                                \
      T * topk_values,                                 \
      IdxT * topk_indices,                             \
      T * scores_with_bias,                            \
      int32_t * expert_id_to_ep_rank_array,            \
      int32_t * expert_in_rank_num_list,               \
      int32_t * tokens_per_expert_stats_list,          \
      int64_t const num_tokens,                        \
      int64_t const num_experts,                       \
      int64_t const n_group,                           \
      int64_t const topk_group,                        \
      int64_t const topk,                              \
      bool const renormalize,                          \
      double const routed_scaling_factor,              \
      int64_t const redundant_ep_rank_num_plus_one,    \
      cudaStream_t const stream);

INSTANTIATE_NOAUX_TC_Redundant(float, int32_t);
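
// A small host-side reference (illustrative only; the helper name is
// hypothetical) for the weighting applied per selected expert in the kernels
// above: with renormalize == true the score is divided by the sum of the topk
// scores (the sum is seeded with a 1e-20f epsilon to guard the division)
// before scaling, otherwise only routed_scaling_factor is applied.
inline float routed_weight_ref(float score, float topk_sum, bool renormalize,
                               double routed_scaling_factor) {
  float v = renormalize ? score / topk_sum : score;
  return v * (float)routed_scaling_factor;
}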