mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
optimize w4a8 decoding (#3050)
This commit is contained in:
@@ -665,10 +665,139 @@ void moe_fast_hardamard_kernel(const T *x,
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Hadamard-transform kernel for a "masked" per-expert token layout.
 *
 * The input is addressed as `token_num` rows of `dim` elements. Row `r`
 * belongs to expert `r / num_max_tokens_per_expert`, at slot
 * `r % num_max_tokens_per_expert` inside that expert. A row is processed only
 * when its slot index is below `recv_expert_count[expert]`; every other row is
 * skipped and nothing is written to `out` for it.
 *
 * Launch layout: each block starts at row `blockIdx.x` and advances by
 * `gridDim.x` rows per iteration. The dynamic shared-memory buffer is handed
 * to `exchange_smem_pre` only when the block has more than one warp.
 * NOTE(review): any use of `blockIdx.y` (diagonal-block mode) happens inside
 * the load/store helpers, which are not visible here -- confirm there.
 *
 * @param x                          input rows, `dim` elements each
 * @param recv_expert_count          per-expert count of valid rows
 * @param shift                      optional; when non-null (and quant_scales
 *                                   is non-null) the smooth-quant store path
 *                                   is used
 * @param smooth                     forwarded to the smooth-quant store path
 * @param quant_scales               optional per-expert scale, indexed by
 *                                   expert id; when null the result is stored
 *                                   through `store_output` without quantization
 * @param quant_round_type           forwarded to the quantizing store helpers
 * @param quant_max_bound            forwarded to the quantizing store helpers
 * @param quant_min_bound            forwarded to the quantizing store helpers
 * @param token_num                  total number of rows (valid or not)
 * @param dim                        number of elements per row
 * @param num_max_tokens_per_expert  number of row slots reserved per expert
 * @param out                        output rows, same row indexing as `x`
 */
template <typename T, typename OutT, int kThreads, int kNBytes, int VecSize, int N,
          int kNChunks, int kSmeSize, int kRounds, int kChunksPerSmemSize, bool UseDiagonalBlockMatrix = false>
__global__ __launch_bounds__(kThreads)
void masked_moe_fast_hardamard_kernel(const T *x,
                                      const int64_t *recv_expert_count,
                                      const T *shift,
                                      const T *smooth,
                                      const float* quant_scales,
                                      const int quant_round_type,
                                      const float quant_max_bound,
                                      const float quant_min_bound,
                                      const int64_t token_num,
                                      const int64_t dim,
                                      const int num_max_tokens_per_expert,
                                      OutT *out) {
  using vec_t = typename BytesToType<sizeof(T) * VecSize>::Type;
  constexpr int kLogVecSize = cilog2(VecSize);
  constexpr int kLogWarpSize = cilog2(32);
  constexpr int kWarpSize = 32;
  constexpr int kNWarps = kThreads / kWarpSize;
  constexpr int kLogNWarps = cilog2(kNWarps);

  extern __shared__ char smem_[];
  vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_);

  // Row indices are kept in 64 bits: `token_num` is int64_t, and a 32-bit
  // counter (or the implicit int/unsigned mixing with gridDim.x) would wrap
  // for large row counts.
  const int64_t row_stride = static_cast<int64_t>(gridDim.x);

  for (int64_t token_id = blockIdx.x; token_id < token_num; token_id += row_stride) {
    const int64_t token_idx_in_expert = token_id % num_max_tokens_per_expert;
    const int64_t expert_id = token_id / num_max_tokens_per_expert;
    if (token_idx_in_expert >= recv_expert_count[expert_id]) {
      // Every remaining slot of this expert is invalid as well. Jump ahead by
      // whole strides so that, after the loop increment, this block lands on
      // its first row at or beyond the next expert's first slot.
      const int64_t next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
      const int64_t num_iters_to_next_expert = (next_expert_start_idx - token_id - 1) / row_stride;
      token_id += num_iters_to_next_expert * row_stride;
      continue;
    }
    const T *x_now = x + token_id * dim;
    OutT *out_now = out + token_id * dim;
    T init_value = static_cast<T>(0.f);
    T x_vals[kNChunks][VecSize] = {init_value};

    load_input<kNChunks, VecSize, UseDiagonalBlockMatrix, T>(x_now, x_vals, dim);
#ifdef DEBUG_HARDAMARD
    if (blockIdx.x == 0 && threadIdx.x == 0) {
      for (int i = 0; i < 1; ++i) {
        printf("chunk_id0: %d\n", i);
        for (int j = 0; j < VecSize; ++j) {
          printf("%f ", (float)x_vals[i][j]);
        }
        printf("\n");
      }
    }
    __syncthreads();
#endif

    // Stage 1: transform across the VecSize values held by each thread.
    hadamard_mult_thread<kLogVecSize, kNChunks>(x_vals);
#ifdef DEBUG_HARDAMARD
    if (blockIdx.x == 0 && threadIdx.x == 0) {
      for (int i = 0; i < 1; ++i) {
        printf("chunk_id1: %d, kLogVecSize: %d\n", i, kLogVecSize);
        for (int j = 0; j < VecSize; ++j) {
          printf("%f ", (float)x_vals[i][j]);
        }
        printf("\n");
      }
    }
    __syncthreads();
#endif

    // Stage 2: transform across the 32 lanes of a warp.
    hadamard_mult_warp<kLogWarpSize, 0, kNChunks, VecSize>(x_vals);
#ifdef DEBUG_HARDAMARD
    if (blockIdx.x == 0 && threadIdx.x == 0) {
      for (int i = 0; i < 1; ++i) {
        printf("chunk_id2: %d\n", i);
        for (int j = 0; j < VecSize; ++j) {
          printf("%f ", (float)x_vals[i][j]);
        }
        printf("\n");
      }
    }
    __syncthreads();
#endif

    // Stage 3: transform across warps, only when the block has several.
    if constexpr (kNWarps > 1) {
      // First let kNWarps consecutive threads pick up the data held by the
      // other warps.
      exchange_smem_pre<kNChunks, kChunksPerSmemSize, VecSize, kWarpSize, kNWarps, true, vec_t>(x_vals, smem_exchange);
      // Cross-warp butterfly.
      hadamard_mult_warp<kLogNWarps, 0, kNChunks, VecSize>(x_vals);
      // Then swap the data back to its original owners.
      exchange_smem_pre<kNChunks, kChunksPerSmemSize, VecSize, kWarpSize, kNWarps, false, vec_t>(x_vals, smem_exchange);
    }

    // Stage 4: transform across the chunks held by each thread. 28 and 36
    // have dedicated routines; any other chunk count must be a power of two.
    if constexpr (kNChunks > 1) {
      if constexpr (kNChunks == 28) {
        hadamard_mult_thread_28_transpose<T, VecSize>(x_vals);
      } else if constexpr (kNChunks == 36) {
        hadamard_mult_thread_36_transpose<T, VecSize>(x_vals);
      } else {
        constexpr int kLogNChunks = cilog2(kNChunks);
        static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
        hadamard_mult_thread_transpose<kLogNChunks, VecSize>(x_vals);
      }
    }

    // Store: quantize with this expert's scale when scales are provided,
    // otherwise write the transformed values through store_output.
    if (quant_scales) {
      float quant_scale = quant_scales[expert_id];
      if (shift) {
        smooth_quant_store_output<kNChunks, VecSize, UseDiagonalBlockMatrix, T, OutT>(
            out_now,
            shift,
            smooth,
            x_vals,
            quant_scale,
            quant_round_type,
            quant_max_bound,
            quant_min_bound,
            dim);
      } else {
        quant_store_output<kNChunks, VecSize, UseDiagonalBlockMatrix, T, OutT>(
            out_now,
            x_vals,
            quant_scale,
            quant_round_type,
            quant_max_bound,
            quant_min_bound,
            dim);
      }
    } else {
      store_output<kNChunks, VecSize, UseDiagonalBlockMatrix, T>(out_now, x_vals, dim);
    }
  }
}
|
||||
|
||||
|
||||
template <typename T, typename OutT, int kLogN, int VecSize, int kNChunks, int kThreads, bool UseDiagonalBlockMatrix>
|
||||
void MoeFastHardamardImplWrapper(const T *x,
|
||||
const int64_t *expert_idx_per_token,
|
||||
const int64_t *recv_expert_count,
|
||||
const T *shift,
|
||||
const T *smooth,
|
||||
const float* quant_scales,
|
||||
@@ -677,6 +806,8 @@ void MoeFastHardamardImplWrapper(const T *x,
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
OutT* out,
|
||||
cudaStream_t stream) {
|
||||
using nv_type = typename nv_type_traits<T>::type;
|
||||
@@ -696,33 +827,61 @@ void MoeFastHardamardImplWrapper(const T *x,
|
||||
int sm_count;
|
||||
int act_blocks_per_sm;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||
auto kernel = moe_fast_hardamard_kernel<nv_type, out_type, kThreads, kNBytes, VecSize, N, kNChunks, kSmemSize, kRounds, kChunksPerSmemSize, UseDiagonalBlockMatrix>;
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, kernel, kThreads, kSmemSize);
|
||||
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
|
||||
dim3 grid;
|
||||
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
|
||||
if constexpr (UseDiagonalBlockMatrix) {
|
||||
grid.y = ceil(dim / (kThreads * VecSize));
|
||||
|
||||
if (used_in_ep_low_latency) {
|
||||
auto masked_kernel = masked_moe_fast_hardamard_kernel<nv_type, out_type, kThreads, kNBytes, VecSize, N, kNChunks, kSmemSize, kRounds, kChunksPerSmemSize, UseDiagonalBlockMatrix>;
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, masked_kernel, kThreads, kSmemSize);
|
||||
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
|
||||
dim3 grid;
|
||||
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
|
||||
if constexpr (UseDiagonalBlockMatrix) {
|
||||
grid.y = ceil(dim / (kThreads * VecSize));
|
||||
}
|
||||
masked_kernel<<<grid, kThreads, kSmemSize, stream>>>(
|
||||
reinterpret_cast<const nv_type*>(x),
|
||||
recv_expert_count,
|
||||
reinterpret_cast<const nv_type*>(shift),
|
||||
reinterpret_cast<const nv_type*>(smooth),
|
||||
quant_scales,
|
||||
quant_round_type,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
token_num,
|
||||
dim,
|
||||
num_max_tokens_per_expert,
|
||||
reinterpret_cast<out_type*>(out)
|
||||
);
|
||||
} else {
|
||||
auto kernel = moe_fast_hardamard_kernel<nv_type, out_type, kThreads, kNBytes, VecSize, N, kNChunks, kSmemSize, kRounds, kChunksPerSmemSize, UseDiagonalBlockMatrix>;
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, kernel, kThreads, kSmemSize);
|
||||
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
|
||||
dim3 grid;
|
||||
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
|
||||
if constexpr (UseDiagonalBlockMatrix) {
|
||||
grid.y = ceil(dim / (kThreads * VecSize));
|
||||
}
|
||||
kernel<<<grid, kThreads, kSmemSize, stream>>>(
|
||||
reinterpret_cast<const nv_type*>(x),
|
||||
expert_idx_per_token,
|
||||
reinterpret_cast<const nv_type*>(shift),
|
||||
reinterpret_cast<const nv_type*>(smooth),
|
||||
quant_scales,
|
||||
quant_round_type,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
token_num,
|
||||
dim,
|
||||
reinterpret_cast<out_type*>(out)
|
||||
);
|
||||
}
|
||||
kernel<<<grid, kThreads, kSmemSize, stream>>>(
|
||||
reinterpret_cast<const nv_type*>(x),
|
||||
expert_idx_per_token,
|
||||
reinterpret_cast<const nv_type*>(shift),
|
||||
reinterpret_cast<const nv_type*>(smooth),
|
||||
quant_scales,
|
||||
quant_round_type,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
token_num,
|
||||
dim,
|
||||
reinterpret_cast<out_type*>(out)
|
||||
);
|
||||
}
|
||||
|
||||
template <typename T, typename OutT>
|
||||
void MoeFastHardamardWrapper(const T *x_data,
|
||||
const int64_t *expert_idx_per_token,
|
||||
const int64_t *recv_expert_count,
|
||||
const T *shift,
|
||||
const T *smooth,
|
||||
const float* quant_scales,
|
||||
@@ -731,6 +890,8 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
OutT* out,
|
||||
cudaStream_t &stream) {
|
||||
bool FLAGS_hardamard_use_diagonal_block_matrix = true;
|
||||
@@ -748,6 +909,7 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
MoeFastHardamardImplWrapper<T, OutT, kLogN, VEC_SIZE, kNChunks, kThreads, true>(
|
||||
x_data,
|
||||
expert_idx_per_token,
|
||||
recv_expert_count,
|
||||
shift,
|
||||
smooth,
|
||||
quant_scales,
|
||||
@@ -756,6 +918,8 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
quant_min_bound,
|
||||
token_num,
|
||||
dim,
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
out,
|
||||
stream);
|
||||
})});
|
||||
@@ -769,6 +933,7 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
MoeFastHardamardImplWrapper<T, OutT, kLogN, VecSize, kNChunks, kThreads, false>(
|
||||
x_data,
|
||||
expert_idx_per_token,
|
||||
recv_expert_count,
|
||||
shift,
|
||||
smooth,
|
||||
quant_scales,
|
||||
@@ -777,6 +942,8 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
quant_min_bound,
|
||||
token_num,
|
||||
dim,
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
out,
|
||||
stream);
|
||||
});
|
||||
@@ -789,6 +956,7 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
MoeFastHardamardImplWrapper<T, OutT, kLogN, VecSize, kNChunks, kThreads, false>(
|
||||
x_data,
|
||||
expert_idx_per_token,
|
||||
recv_expert_count,
|
||||
shift,
|
||||
smooth,
|
||||
quant_scales,
|
||||
@@ -797,6 +965,8 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
quant_min_bound,
|
||||
token_num,
|
||||
dim,
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
out,
|
||||
stream);
|
||||
});
|
||||
@@ -809,6 +979,7 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
MoeFastHardamardImplWrapper<T, OutT, kLogN, VecSize, kNChunks, kThreads, false>(
|
||||
x_data,
|
||||
expert_idx_per_token,
|
||||
recv_expert_count,
|
||||
shift,
|
||||
smooth,
|
||||
quant_scales,
|
||||
@@ -817,6 +988,8 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
quant_min_bound,
|
||||
token_num,
|
||||
dim,
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
out,
|
||||
stream);
|
||||
});
|
||||
@@ -827,6 +1000,7 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
|
||||
const phi::dtype::float16 *x_data,
|
||||
const int64_t *expert_idx_per_token,
|
||||
const int64_t *recv_expert_count,
|
||||
const phi::dtype::float16 *shift,
|
||||
const phi::dtype::float16 *smooth,
|
||||
const float* quant_scales,
|
||||
@@ -835,6 +1009,8 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
phi::dtype::float16 *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
@@ -842,6 +1018,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
|
||||
template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
|
||||
const phi::dtype::float16 *x_data,
|
||||
const int64_t *expert_idx_per_token,
|
||||
const int64_t *recv_expert_count,
|
||||
const phi::dtype::float16 *shift,
|
||||
const phi::dtype::float16 *smooth,
|
||||
const float* quant_scales,
|
||||
@@ -850,6 +1027,8 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
int8_t *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
@@ -857,6 +1036,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
|
||||
template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16>(
|
||||
const phi::dtype::bfloat16 *x_data,
|
||||
const int64_t *expert_idx_per_token,
|
||||
const int64_t *recv_expert_count,
|
||||
const phi::dtype::bfloat16 *shift,
|
||||
const phi::dtype::bfloat16 *smooth,
|
||||
const float* quant_scales,
|
||||
@@ -865,6 +1045,8 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
phi::dtype::bfloat16 *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
@@ -872,6 +1054,7 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16
|
||||
template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
|
||||
const phi::dtype::bfloat16 *x_data,
|
||||
const int64_t *expert_idx_per_token,
|
||||
const int64_t *recv_expert_count,
|
||||
const phi::dtype::bfloat16 *shift,
|
||||
const phi::dtype::bfloat16 *smooth,
|
||||
const float* quant_scales,
|
||||
@@ -880,6 +1063,8 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
int8_t *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user