[noauxtc_kernel] remove useless code (#4643)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

* remove num_tokens

* remove num_tokens

* false

* final commit
This commit is contained in:
周周周
2025-10-30 18:59:04 +08:00
committed by GitHub
parent ec7746bd55
commit 0089287534
2 changed files with 84 additions and 49 deletions

View File

@@ -17,8 +17,8 @@
#pragma once
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "helper.h"
#include <cuda/std/limits>
#include "helper.h"
namespace cg = cooperative_groups;
@@ -64,7 +64,9 @@ __forceinline__ __device__ bool is_better_than(T val, T baseline) {
}
template <bool greater, typename T, typename idxT>
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
__forceinline__ __device__ bool is_better_than(T val,
T baseline,
idxT index,
idxT baseline_index) {
bool res = (val > baseline && greater) || (val < baseline && !greater);
if (val == baseline) {
@@ -82,7 +84,11 @@ int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
}
template <int size, bool ascending, bool reverse, typename T, typename idxT,
template <int size,
bool ascending,
bool reverse,
typename T,
typename idxT,
bool is_stable>
struct BitonicMerge {
// input should be a bitonic sequence, and sort it to be a monotonic sequence
@@ -99,8 +105,8 @@ struct BitonicMerge {
T& other_val = val_arr[other_i];
bool is_better;
if constexpr (is_stable) {
is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
idx_arr[other_i]);
is_better = is_better_than<ascending>(
val, other_val, idx_arr[i], idx_arr[other_i]);
} else {
is_better = is_better_than<ascending>(val, other_val);
}
@@ -182,7 +188,10 @@ struct BitonicSort<32, ascending, T, idxT, is_stable> {
}
};
template <bool ascending, bool reverse, typename T, typename idxT,
template <bool ascending,
bool reverse,
typename T,
typename idxT,
bool is_stable>
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
__device__ static void merge(T* __restrict__ val_arr,
@@ -234,7 +243,8 @@ class WarpSort {
// load and merge k sorted values
__device__ void load_sorted(T const* __restrict__ in,
idxT const* __restrict__ in_idx, idxT start) {
idxT const* __restrict__ in_idx,
idxT start) {
idxT idx = start + WARP_SIZE - 1 - lane_;
for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
if (idx < start + k_) {
@@ -456,8 +466,7 @@ __device__ void topk_with_k2(T* output,
template <typename T>
__global__ void topk_with_k2_kernel(T* output,
T* input,
int64_t const num_tokens,
const T* input,
int64_t const num_cases,
int64_t const n_group,
int64_t const num_experts_per_group) {
@@ -484,11 +493,11 @@ __global__ void topk_with_k2_kernel(T* output,
template <typename T, typename IdxT>
__global__ void group_idx_and_topk_idx_kernel(
T* scores,
const T* scores,
T const* group_scores,
T* topk_values,
IdxT* topk_indices,
T* scores_with_bias,
const T* scores_with_bias,
int64_t const num_tokens,
int64_t const n_group,
int64_t const topk_group,
@@ -550,14 +559,17 @@ __global__ void group_idx_and_topk_idx_kernel(
value = neg_inf<T>();
}
pre_count_equal_to_top_value = count_equal_to_top_value;
count_equal_to_top_value = __popc(__ballot_sync(
FULL_WARP_MASK, (value == neg_inf<T>())));
count_equal_to_top_value =
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
}
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
}
__syncthreads();
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
warp_topk::WarpSelect</*capability*/ WARP_SIZE,
/*greater*/ true,
T,
int32_t,
/* is_stable */ true>
queue((int32_t)topk, neg_inf<T>());
@@ -602,19 +614,13 @@ __global__ void group_idx_and_topk_idx_kernel(
if (i < topk) {
s_topk_value[i] = value;
}
topk_sum += cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
topk_sum +=
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
}
}
__syncthreads();
if (case_id < num_tokens && if_proceed_next_topk) {
for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
scores[i] = 0;
}
}
__syncwarp();
if (case_id < num_tokens) {
if (if_proceed_next_topk) {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
@@ -625,7 +631,6 @@ __global__ void group_idx_and_topk_idx_kernel(
} else {
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
}
scores[s_topk_idx[i]] = value;
topk_indices[i] = s_topk_idx[i];
topk_values[i] = cuda_cast<T, float>(value);
}
@@ -662,7 +667,11 @@ void invokeNoAuxTc(T* scores,
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
group_scores, scores_with_bias, num_tokens, num_cases, n_group, num_experts / n_group);
group_scores,
scores_with_bias,
num_cases,
n_group,
num_experts / n_group);
#else
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
cudaLaunchConfig_t config;
@@ -675,8 +684,13 @@ void invokeNoAuxTc(T* scores,
attrs[0].val.programmaticStreamSerializationAllowed = false;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
num_tokens, num_cases, n_group, num_experts / n_group);
cudaLaunchKernelEx(&config,
kernel_instance1,
group_scores,
scores_with_bias,
num_cases,
n_group,
num_experts / n_group);
#endif
int64_t topk_with_k_group_num_blocks =
@@ -686,10 +700,22 @@ void invokeNoAuxTc(T* scores,
topk);
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
group_idx_and_topk_idx_kernel<T, IdxT><<<topk_with_k_group_num_blocks, BLOCK_SIZE, dynamic_smem_in_bytes, stream>>>(
scores, group_scores, topk_values, topk_indices, scores_with_bias,
num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group,
renormalize, routed_scaling_factor);
group_idx_and_topk_idx_kernel<T, IdxT><<<topk_with_k_group_num_blocks,
BLOCK_SIZE,
dynamic_smem_in_bytes,
stream>>>(scores,
group_scores,
topk_values,
topk_indices,
scores_with_bias,
num_tokens,
n_group,
topk_group,
topk,
num_experts,
num_experts / n_group,
renormalize,
routed_scaling_factor);
#else
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
config.gridDim = topk_with_k_group_num_blocks;
@@ -700,26 +726,37 @@ void invokeNoAuxTc(T* scores,
attrs[0].val.programmaticStreamSerializationAllowed = false;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
topk_values, topk_indices, scores_with_bias, num_tokens,
n_group, topk_group, topk, num_experts,
num_experts / n_group, renormalize, routed_scaling_factor);
cudaLaunchKernelEx(&config,
kernel_instance2,
scores,
group_scores,
topk_values,
topk_indices,
scores_with_bias,
num_tokens,
n_group,
topk_group,
topk,
num_experts,
num_experts / n_group,
renormalize,
routed_scaling_factor);
#endif
}
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
template void invokeNoAuxTc<T, IdxT>(T * scores, \
T * group_scores, \
T* topk_values, \
IdxT* topk_indices, \
T * scores_with_bias, \
int64_t const num_tokens, \
int64_t const num_experts, \
int64_t const n_group, \
int64_t const topk_group, \
int64_t const topk, \
bool const renormalize, \
double const routed_scaling_factor, \
cudaStream_t const stream);
T * group_scores, \
T * topk_values, \
IdxT * topk_indices, \
T * scores_with_bias, \
int64_t const num_tokens, \
int64_t const num_experts, \
int64_t const n_group, \
int64_t const topk_group, \
int64_t const topk, \
bool const renormalize, \
double const routed_scaling_factor, \
cudaStream_t const stream);
INSTANTIATE_NOAUX_TC(float, int32_t);

View File

@@ -256,7 +256,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
if topk_method == "noaux_tc":
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
gate_out, _, _ = get_moe_scores(
_, topk_weights, topk_ids = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
@@ -265,8 +265,6 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,