optimize w4a8 decoding (#3050)

This commit is contained in:
Yuan Xiaolan
2025-07-28 22:20:13 +08:00
committed by GitHub
parent e80ea8a71b
commit 7d87aaace8
6 changed files with 253 additions and 36 deletions

View File

@@ -223,14 +223,11 @@ public:
static Status can_implement(Arguments const &args)
{
CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::can_implement()");
// printf("--1\n");
// Initialize static kernel and device properties, if necessary.
Status result = init_device_props();
// printf("--1-2\n");
if (result != Status::kSuccess) {
return result;
}
// printf("--2\n");
dim3 grid = get_grid_shape(args);
// printf("--grid:%d, %d, %d\n", grid.x, grid.y, grid.z);
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
@@ -238,7 +235,6 @@ public:
{
return Status::kErrorInvalidProblem;
}
// printf("--3\n");
return GemmKernel::can_implement(args);
}
@@ -285,18 +281,50 @@ public:
}
/// Returns the maximum number of active thread blocks per multiprocessor
static int maximum_active_blocks()
static int maximum_active_blocks(int smem_capacity = -1)
{
CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::maximum_active_blocks()");
// Initialize static device properties, if necessary
if (init_device_props() != Status::kSuccess) {
int smem_size = int(sizeof(typename GemmKernel_::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
cudaError_t result;
if (smem_size > (48 << 10)) {
result = cudaFuncSetAttribute(Kernel2<GemmKernel_>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error "
<< cudaGetErrorString(result));
return -1;
}
}
int max_active_blocks = -1;
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
Kernel2<GemmKernel_>,
GemmKernel_::kThreadCount,
smem_size);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_);
return sm_occupancy_;
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
@@ -341,8 +369,7 @@ public:
// Configure grid and block dimensions
dim3 block(GemmKernel::kThreadCount, 1, 1);
// dim3 grid = params_.get_grid_dims();
dim3 grid(216, 1, 1);
dim3 grid(params_.threadblock_count, 1, 1);
// Launch kernel
CUTLASS_TRACE_HOST(" "

View File

@@ -21,12 +21,12 @@ rm -rf up_gate_proj_7168_8192.log
rm -rf down_proj_8192_3584.log
num_experts=8
for tokens_per_expert in 12
for tokens_per_expert in 1 2 4 8 16 20 24 28 32 36 48 64 96 128 160 192 224 256 384 512 768 1024 2048 3072 4096 8192
do
wait
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 &
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 &
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 0 1 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 &
CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 0 1 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 &
done
wait
echo "#### finish ####"

View File

@@ -996,7 +996,6 @@ int main(int argc, char *argv[]) {
CutlassTileConfig::CtaShape64x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64,
CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
};
std::vector<SplitKStyle> all_split_k_style{SplitKStyle::NO_SPLIT_K};

View File

@@ -665,10 +665,139 @@ void moe_fast_hardamard_kernel(const T *x,
}
}
template <typename T, typename OutT, int kThreads, int kNBytes, int VecSize, int N,
int kNChunks, int kSmeSize, int kRounds, int kChunksPerSmemSize, bool UseDiagonalBlockMatrix = false>
__global__ __launch_bounds__(kThreads)
void masked_moe_fast_hardamard_kernel(const T *x,
const int64_t *recv_expert_count,
const T *shift,
const T *smooth,
const float* quant_scales,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
OutT *out) {
using vec_t = typename BytesToType<sizeof(T) * VecSize>::Type;
constexpr int kLogVecSize = cilog2(VecSize);
constexpr int kLogWarpSize = cilog2(32);
constexpr int kWarpSize = 32;
constexpr int kNWarps = kThreads / kWarpSize;
constexpr int kLogNWarps = cilog2(kNWarps);
constexpr int kLogNChunks = cilog2(kNChunks);
extern __shared__ char smem_[];
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_);
for (int token_id = blockIdx.x; token_id < token_num; token_id += gridDim.x) {
const auto token_idx_in_expert = token_id % num_max_tokens_per_expert;
const auto expert_id = token_id / num_max_tokens_per_expert;
if (token_idx_in_expert >= recv_expert_count[expert_id]) {
auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
auto num_iters_to_next_expert = (next_expert_start_idx - token_id - 1) / gridDim.x;
token_id += num_iters_to_next_expert * gridDim.x;
continue;
}
const T *x_now = x + token_id * dim;
OutT *out_now = out + token_id * dim;
T init_value = static_cast<T>(0.f);
T x_vals[kNChunks][VecSize] = {init_value};
load_input<kNChunks, VecSize, UseDiagonalBlockMatrix, T>(x_now, x_vals, dim);
#ifdef DEBUG_HARDAMARD
if (blockIdx.x == 0 && threadIdx.x == 0) {
for (int i = 0; i < 1; ++i) {
printf("chunk_id0: %d\n", i);
for (int j = 0; j < VecSize; ++j) {
printf("%f ", (float)x_vals[i][j]);
}
printf("\n");
}
}
__syncthreads();
#endif
hadamard_mult_thread<kLogVecSize, kNChunks>(x_vals);
#ifdef DEBUG_HARDAMARD
if (blockIdx.x == 0 && threadIdx.x == 0) {
for (int i = 0; i < 1; ++i) {
printf("chunk_id1: %d, kLogVecSize: %d\n", i, kLogVecSize);
for (int j = 0; j < VecSize; ++j) {
printf("%f ", (float)x_vals[i][j]);
}
printf("\n");
}
}
__syncthreads();
#endif
hadamard_mult_warp<kLogWarpSize, 0, kNChunks, VecSize>(x_vals);
#ifdef DEBUG_HARDAMARD
if (blockIdx.x == 0 && threadIdx.x == 0) {
for (int i = 0; i < 1; ++i) {
printf("chunk_id2: %d\n", i);
for (int j = 0; j < VecSize; ++j) {
printf("%f ", (float)x_vals[i][j]);
}
printf("\n");
}
}
__syncthreads();
#endif
if constexpr (kNWarps > 1) {
// 先让连续的NWARPS个线程拿到其余warps上的数据
exchange_smem_pre<kNChunks, kChunksPerSmemSize, VecSize, kWarpSize, kNWarps, true, vec_t>(x_vals, smem_exchange);
// 交叉计算
hadamard_mult_warp<kLogNWarps, 0, kNChunks, VecSize>(x_vals);
// 再换回来
exchange_smem_pre<kNChunks, kChunksPerSmemSize, VecSize, kWarpSize, kNWarps, false, vec_t>(x_vals, smem_exchange);
}
if constexpr (kNChunks > 1) {
if constexpr (kNChunks == 28) {
hadamard_mult_thread_28_transpose<T, VecSize>(x_vals);
} else if constexpr (kNChunks == 36) {
hadamard_mult_thread_36_transpose<T, VecSize>(x_vals);
} else {
constexpr int kLogNChunks = cilog2(kNChunks);
static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
hadamard_mult_thread_transpose<kLogNChunks, VecSize>(x_vals);
}
}
if (quant_scales) {
float quant_scale = quant_scales[expert_id];
if (shift) {
smooth_quant_store_output<kNChunks, VecSize, UseDiagonalBlockMatrix, T, OutT>(
out_now,
shift,
smooth,
x_vals,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
dim);
} else {
quant_store_output<kNChunks, VecSize, UseDiagonalBlockMatrix, T, OutT>(
out_now,
x_vals,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
dim);
}
} else {
store_output<kNChunks, VecSize, UseDiagonalBlockMatrix, T>(out_now, x_vals, dim);
}
}
}
template <typename T, typename OutT, int kLogN, int VecSize, int kNChunks, int kThreads, bool UseDiagonalBlockMatrix>
void MoeFastHardamardImplWrapper(const T *x,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const T *shift,
const T *smooth,
const float* quant_scales,
@@ -677,6 +806,8 @@ void MoeFastHardamardImplWrapper(const T *x,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
OutT* out,
cudaStream_t stream) {
using nv_type = typename nv_type_traits<T>::type;
@@ -696,33 +827,61 @@ void MoeFastHardamardImplWrapper(const T *x,
int sm_count;
int act_blocks_per_sm;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
auto kernel = moe_fast_hardamard_kernel<nv_type, out_type, kThreads, kNBytes, VecSize, N, kNChunks, kSmemSize, kRounds, kChunksPerSmemSize, UseDiagonalBlockMatrix>;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, kThreads, kSmemSize);
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
dim3 grid;
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
if constexpr (UseDiagonalBlockMatrix) {
grid.y = ceil(dim / (kThreads * VecSize));
if (used_in_ep_low_latency) {
auto masked_kernel = masked_moe_fast_hardamard_kernel<nv_type, out_type, kThreads, kNBytes, VecSize, N, kNChunks, kSmemSize, kRounds, kChunksPerSmemSize, UseDiagonalBlockMatrix>;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, masked_kernel, kThreads, kSmemSize);
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
dim3 grid;
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
if constexpr (UseDiagonalBlockMatrix) {
grid.y = ceil(dim / (kThreads * VecSize));
}
masked_kernel<<<grid, kThreads, kSmemSize, stream>>>(
reinterpret_cast<const nv_type*>(x),
recv_expert_count,
reinterpret_cast<const nv_type*>(shift),
reinterpret_cast<const nv_type*>(smooth),
quant_scales,
quant_round_type,
quant_max_bound,
quant_min_bound,
token_num,
dim,
num_max_tokens_per_expert,
reinterpret_cast<out_type*>(out)
);
} else {
auto kernel = moe_fast_hardamard_kernel<nv_type, out_type, kThreads, kNBytes, VecSize, N, kNChunks, kSmemSize, kRounds, kChunksPerSmemSize, UseDiagonalBlockMatrix>;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, kThreads, kSmemSize);
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
dim3 grid;
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
if constexpr (UseDiagonalBlockMatrix) {
grid.y = ceil(dim / (kThreads * VecSize));
}
kernel<<<grid, kThreads, kSmemSize, stream>>>(
reinterpret_cast<const nv_type*>(x),
expert_idx_per_token,
reinterpret_cast<const nv_type*>(shift),
reinterpret_cast<const nv_type*>(smooth),
quant_scales,
quant_round_type,
quant_max_bound,
quant_min_bound,
token_num,
dim,
reinterpret_cast<out_type*>(out)
);
}
kernel<<<grid, kThreads, kSmemSize, stream>>>(
reinterpret_cast<const nv_type*>(x),
expert_idx_per_token,
reinterpret_cast<const nv_type*>(shift),
reinterpret_cast<const nv_type*>(smooth),
quant_scales,
quant_round_type,
quant_max_bound,
quant_min_bound,
token_num,
dim,
reinterpret_cast<out_type*>(out)
);
}
template <typename T, typename OutT>
void MoeFastHardamardWrapper(const T *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const T *shift,
const T *smooth,
const float* quant_scales,
@@ -731,6 +890,8 @@ void MoeFastHardamardWrapper(const T *x_data,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
OutT* out,
cudaStream_t &stream) {
bool FLAGS_hardamard_use_diagonal_block_matrix = true;
@@ -748,6 +909,7 @@ void MoeFastHardamardWrapper(const T *x_data,
MoeFastHardamardImplWrapper<T, OutT, kLogN, VEC_SIZE, kNChunks, kThreads, true>(
x_data,
expert_idx_per_token,
recv_expert_count,
shift,
smooth,
quant_scales,
@@ -756,6 +918,8 @@ void MoeFastHardamardWrapper(const T *x_data,
quant_min_bound,
token_num,
dim,
num_max_tokens_per_expert,
used_in_ep_low_latency,
out,
stream);
})});
@@ -769,6 +933,7 @@ void MoeFastHardamardWrapper(const T *x_data,
MoeFastHardamardImplWrapper<T, OutT, kLogN, VecSize, kNChunks, kThreads, false>(
x_data,
expert_idx_per_token,
recv_expert_count,
shift,
smooth,
quant_scales,
@@ -777,6 +942,8 @@ void MoeFastHardamardWrapper(const T *x_data,
quant_min_bound,
token_num,
dim,
num_max_tokens_per_expert,
used_in_ep_low_latency,
out,
stream);
});
@@ -789,6 +956,7 @@ void MoeFastHardamardWrapper(const T *x_data,
MoeFastHardamardImplWrapper<T, OutT, kLogN, VecSize, kNChunks, kThreads, false>(
x_data,
expert_idx_per_token,
recv_expert_count,
shift,
smooth,
quant_scales,
@@ -797,6 +965,8 @@ void MoeFastHardamardWrapper(const T *x_data,
quant_min_bound,
token_num,
dim,
num_max_tokens_per_expert,
used_in_ep_low_latency,
out,
stream);
});
@@ -809,6 +979,7 @@ void MoeFastHardamardWrapper(const T *x_data,
MoeFastHardamardImplWrapper<T, OutT, kLogN, VecSize, kNChunks, kThreads, false>(
x_data,
expert_idx_per_token,
recv_expert_count,
shift,
smooth,
quant_scales,
@@ -817,6 +988,8 @@ void MoeFastHardamardWrapper(const T *x_data,
quant_min_bound,
token_num,
dim,
num_max_tokens_per_expert,
used_in_ep_low_latency,
out,
stream);
});
@@ -827,6 +1000,7 @@ void MoeFastHardamardWrapper(const T *x_data,
template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
const phi::dtype::float16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::float16 *shift,
const phi::dtype::float16 *smooth,
const float* quant_scales,
@@ -835,6 +1009,8 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
phi::dtype::float16 *out,
cudaStream_t &stream
);
@@ -842,6 +1018,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
const phi::dtype::float16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::float16 *shift,
const phi::dtype::float16 *smooth,
const float* quant_scales,
@@ -850,6 +1027,8 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
int8_t *out,
cudaStream_t &stream
);
@@ -857,6 +1036,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16>(
const phi::dtype::bfloat16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::bfloat16 *shift,
const phi::dtype::bfloat16 *smooth,
const float* quant_scales,
@@ -865,6 +1045,8 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
phi::dtype::bfloat16 *out,
cudaStream_t &stream
);
@@ -872,6 +1054,7 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16
template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
const phi::dtype::bfloat16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::bfloat16 *shift,
const phi::dtype::bfloat16 *smooth,
const float* quant_scales,
@@ -880,6 +1063,8 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
int8_t *out,
cudaStream_t &stream
);

View File

@@ -21,6 +21,7 @@
template <typename T, typename OutT>
void MoeFastHardamardWrapper(const T *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const T *shift,
const T *smooth,
const float* quant_scales,
@@ -29,5 +30,7 @@ void MoeFastHardamardWrapper(const T *x_data,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
OutT* out,
cudaStream_t &stream);

View File

@@ -240,6 +240,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
MoeFastHardamardWrapper<data_t, int8_t>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
down_proj_shift, // down_proj_shift->data<T>(),
down_proj_smooth, // down_proj_smooth->data<T>(),
down_proj_in_scale ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())->data<float>() : nullptr,
@@ -248,6 +249,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
-127.0,
expanded_active_expert_rows,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
reinterpret_cast<int8_t *>(int8_act_out->ptr()),
stream
);