Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[Metax] adapt DeepSeek (#4498)
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include "helper.h"
|
||||
#include <cuda/std/limits>
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
@@ -601,7 +602,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
if (i < topk) {
|
||||
s_topk_value[i] = value;
|
||||
}
|
||||
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
||||
topk_sum += cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -658,6 +659,11 @@ void invokeNoAuxTc(T* scores,
|
||||
cudaStream_t const stream) {
|
||||
int64_t num_cases = num_tokens * n_group;
|
||||
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||
|
||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
|
||||
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
|
||||
group_scores, scores_with_bias, num_tokens, num_cases, n_group, num_experts / n_group);
|
||||
#else
|
||||
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = topk_with_k2_num_blocks;
|
||||
@@ -671,6 +677,7 @@ void invokeNoAuxTc(T* scores,
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
|
||||
num_tokens, num_cases, n_group, num_experts / n_group);
|
||||
#endif
|
||||
|
||||
int64_t topk_with_k_group_num_blocks =
|
||||
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||
@@ -678,6 +685,12 @@ void invokeNoAuxTc(T* scores,
|
||||
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
|
||||
topk);
|
||||
|
||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
|
||||
group_idx_and_topk_idx_kernel<T, IdxT><<<topk_with_k_group_num_blocks, BLOCK_SIZE, dynamic_smem_in_bytes, stream>>>(
|
||||
scores, group_scores, topk_values, topk_indices, scores_with_bias,
|
||||
num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group,
|
||||
renormalize, routed_scaling_factor);
|
||||
#else
|
||||
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
|
||||
config.gridDim = topk_with_k_group_num_blocks;
|
||||
config.blockDim = BLOCK_SIZE;
|
||||
@@ -691,6 +704,7 @@ void invokeNoAuxTc(T* scores,
|
||||
topk_values, topk_indices, scores_with_bias, num_tokens,
|
||||
n_group, topk_group, topk, num_experts,
|
||||
num_experts / n_group, renormalize, routed_scaling_factor);
|
||||
#endif
|
||||
}
|
||||
|
||||
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
|
||||
|
||||
Reference in New Issue
Block a user