[Metax] adapt DeepSeek (#4498)

This commit is contained in:
xiaozude
2025-10-24 10:14:53 +08:00
committed by GitHub
parent 8718fa34b2
commit f7069b8057
19 changed files with 1538 additions and 324 deletions

View File

@@ -18,6 +18,7 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "helper.h"
#include <cuda/std/limits>
namespace cg = cooperative_groups;
@@ -601,7 +602,7 @@ __global__ void group_idx_and_topk_idx_kernel(
if (i < topk) {
s_topk_value[i] = value;
}
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
topk_sum += cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
}
}
@@ -658,6 +659,11 @@ void invokeNoAuxTc(T* scores,
cudaStream_t const stream) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
group_scores, scores_with_bias, num_tokens, num_cases, n_group, num_experts / n_group);
#else
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
cudaLaunchConfig_t config;
config.gridDim = topk_with_k2_num_blocks;
@@ -671,6 +677,7 @@ void invokeNoAuxTc(T* scores,
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
num_tokens, num_cases, n_group, num_experts / n_group);
#endif
int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
@@ -678,6 +685,12 @@ void invokeNoAuxTc(T* scores,
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk);
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
group_idx_and_topk_idx_kernel<T, IdxT><<<topk_with_k_group_num_blocks, BLOCK_SIZE, dynamic_smem_in_bytes, stream>>>(
scores, group_scores, topk_values, topk_indices, scores_with_bias,
num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group,
renormalize, routed_scaling_factor);
#else
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
config.gridDim = topk_with_k_group_num_blocks;
config.blockDim = BLOCK_SIZE;
@@ -691,6 +704,7 @@ void invokeNoAuxTc(T* scores,
topk_values, topk_indices, scores_with_bias, num_tokens,
n_group, topk_group, topk, num_experts,
num_experts / n_group, renormalize, routed_scaling_factor);
#endif
}
#define INSTANTIATE_NOAUX_TC(T, IdxT) \