diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_gemm_grouped.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_gemm_grouped.h
index f871cb1d8..1301cc351 100644
--- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_gemm_grouped.h
+++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_gemm_grouped.h
@@ -223,14 +223,11 @@ public:
   static Status can_implement(Arguments const &args)
   {
     CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::can_implement()");
-    // printf("--1\n");
     // Initialize static kernel and device properties, if necessary.
     Status result = init_device_props();
-    // printf("--1-2\n");
     if (result != Status::kSuccess) {
       return result;
     }
-    // printf("--2\n");
     dim3 grid = get_grid_shape(args);
     // printf("--grid:%d, %d, %d\n", grid.x, grid.y, grid.z);
     if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
@@ -238,7 +235,6 @@ public:
     {
       return Status::kErrorInvalidProblem;
     }
-    // printf("--3\n");
     return GemmKernel::can_implement(args);
   }
 
@@ -285,18 +281,50 @@ public:
   }
 
+
   /// Returns the maximum number of active thread blocks per multiprocessor
-  static int maximum_active_blocks()
+  static int maximum_active_blocks(int smem_capacity = -1)
   {
     CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::maximum_active_blocks()");
 
-    // Initialize static device properties, if necessary
-    if (init_device_props() != Status::kSuccess) {
+    int smem_size = int(sizeof(typename GemmKernel_::SharedStorage));
+
+    CUTLASS_TRACE_HOST("  smem_size: " << smem_size << " bytes");
+
+    cudaError_t result;
+    if (smem_size > (48 << 10)) {
+      result = cudaFuncSetAttribute(Kernel2<GemmKernel_>,
+                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                    smem_size);
+
+      if (result != cudaSuccess) {
+        // Call cudaGetLastError() to clear the error bit
+        result = cudaGetLastError();
+        CUTLASS_TRACE_HOST(
+            "  cudaFuncSetAttribute() returned error "
+            << cudaGetErrorString(result));
+        return -1;
+      }
+    }
+
+    int max_active_blocks = -1;
+    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &max_active_blocks,
+        Kernel2<GemmKernel_>,
+        GemmKernel_::kThreadCount,
+        smem_size);
+
+    if (result != cudaSuccess) {
+      // Call cudaGetLastError() to clear the error bit
+      result = cudaGetLastError();
+      CUTLASS_TRACE_HOST(
+          "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
+          << cudaGetErrorString(result));
       return -1;
     }
 
-    CUTLASS_TRACE_HOST("  max_active_blocks: " << sm_occupancy_);
-    return sm_occupancy_;
+    CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
+    return max_active_blocks;
   }
 
@@ -341,8 +369,7 @@ public:
     // Configure grid and block dimensions
     dim3 block(GemmKernel::kThreadCount, 1, 1);
-    // dim3 grid = params_.get_grid_dims();
-    dim3 grid(216, 1, 1);
+    dim3 grid(params_.threadblock_count, 1, 1);
 
     // Launch kernel
     CUTLASS_TRACE_HOST("  "
diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh
index eb3be5fa5..f26aff8b8 100644
--- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh
+++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh
@@ -21,12 +21,12 @@
 rm -rf up_gate_proj_7168_8192.log
 rm -rf down_proj_8192_3584.log
 num_experts=8
-for tokens_per_expert in 12
+for tokens_per_expert in 1 2 4 8 16 20 24 28 32 36 48 64 96 128 160 192 224 256 384 512 768 1024 2048 3072 4096 8192
 do
 wait
-CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 &
-# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 &
+CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 0 1 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 &
+CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 0 1 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 &
 done
 wait
 echo "#### finish ####"
diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_test.cu b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_test.cu
index 4cdc7f0b3..76e0195af 100644
--- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_test.cu
+++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_test.cu
@@ -996,7 +996,6 @@ int main(int argc, char *argv[]) {
       CutlassTileConfig::CtaShape64x256x64_WarpShape64x64x64,
       CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
       CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64,
-      CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
   };
   std::vector<SplitKStyle> all_split_k_style{SplitKStyle::NO_SPLIT_K};
diff --git a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu
index 0bda82f29..7bf46f0f4 100644
--- a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu
+++ b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu
@@ -665,10 +665,139 @@ void moe_fast_hardamard_kernel(const T *x,
   }
 }
 
+template
+__global__ __launch_bounds__(kThreads)
+void masked_moe_fast_hardamard_kernel(const T *x,
+                                      const int64_t *recv_expert_count,
+                                      const T *shift,
+                                      const T *smooth,
+                                      const float* quant_scales,
+                                      const int quant_round_type,
+                                      const float quant_max_bound,
+                                      const float quant_min_bound,
+                                      const int64_t token_num,
+                                      const int64_t dim,
+                                      const int num_max_tokens_per_expert,
+                                      OutT *out) {
+  using vec_t = typename BytesToType<sizeof(T) * VecSize>::Type;
+  constexpr int kLogVecSize = cilog2(VecSize);
+  constexpr int kLogWarpSize = cilog2(32);
+  constexpr int kWarpSize = 32;
+  constexpr int kNWarps = kThreads / kWarpSize;
+  constexpr int kLogNWarps = cilog2(kNWarps);
+  constexpr int kLogNChunks = cilog2(kNChunks);
+
+  extern __shared__ char smem_[];
+  vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_);
+
+  for (int token_id = blockIdx.x; token_id < token_num; token_id += gridDim.x) {
+    const auto token_idx_in_expert = token_id % num_max_tokens_per_expert;
+    const auto expert_id = token_id / num_max_tokens_per_expert;
+    if (token_idx_in_expert >= recv_expert_count[expert_id]) {
+      auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
+      auto num_iters_to_next_expert = (next_expert_start_idx - token_id - 1) / gridDim.x;
+      token_id += num_iters_to_next_expert * gridDim.x;
+      continue;
+    }
+    const T *x_now = x + token_id * dim;
+    OutT *out_now = out + token_id * dim;
+    T init_value = static_cast<T>(0.f);
+    T x_vals[kNChunks][VecSize] = {init_value};
+
+    load_input(x_now, x_vals, dim);
+#ifdef DEBUG_HARDAMARD
+    if (blockIdx.x == 0 && threadIdx.x == 0) {
+      for (int i = 0; i < 1; ++i) {
+        printf("chunk_id0: %d\n", i);
+        for (int j = 0; j < VecSize; ++j) {
+          printf("%f ", (float)x_vals[i][j]);
+        }
+        printf("\n");
+      }
+    }
+    __syncthreads();
+#endif
+
+    hadamard_mult_thread(x_vals);
+#ifdef DEBUG_HARDAMARD
+    if (blockIdx.x == 0 && threadIdx.x == 0) {
+      for (int i = 0; i < 1; ++i) {
+        printf("chunk_id1: %d, kLogVecSize: %d\n", i, kLogVecSize);
+        for (int j = 0; j < VecSize; ++j) {
+          printf("%f ", (float)x_vals[i][j]);
+        }
+        printf("\n");
+      }
+    }
+    __syncthreads();
+#endif
+    hadamard_mult_warp(x_vals);
+#ifdef DEBUG_HARDAMARD
+    if (blockIdx.x == 0 && threadIdx.x == 0) {
+      for (int i = 0; i < 1; ++i) {
+        printf("chunk_id2: %d\n", i);
+        for (int j = 0; j < VecSize; ++j) {
+          printf("%f ", (float)x_vals[i][j]);
+        }
+        printf("\n");
+      }
+    }
+    __syncthreads();
+#endif
+    if constexpr (kNWarps > 1) {
+      // First let each group of NWARPS consecutive threads fetch the data held by the other warps
+      exchange_smem_pre(x_vals, smem_exchange);
+      // Compute across the exchanged warps
+      hadamard_mult_warp(x_vals);
+      // Then exchange the data back
+      exchange_smem_pre(x_vals, smem_exchange);
+    }
+    if constexpr (kNChunks > 1) {
+      if constexpr (kNChunks == 28) {
+        hadamard_mult_thread_28_transpose(x_vals);
+      } else if constexpr (kNChunks == 36) {
+        hadamard_mult_thread_36_transpose(x_vals);
+      } else {
+        constexpr int kLogNChunks = cilog2(kNChunks);
+        static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
+        hadamard_mult_thread_transpose(x_vals);
+      }
+    }
+    if (quant_scales) {
+      float quant_scale = quant_scales[expert_id];
+      if (shift) {
+        smooth_quant_store_output(
+            out_now,
+            shift,
+            smooth,
+            x_vals,
+            quant_scale,
+            quant_round_type,
+            quant_max_bound,
+            quant_min_bound,
+            dim);
+      } else {
+        quant_store_output(
+            out_now,
+            x_vals,
+            quant_scale,
+            quant_round_type,
+            quant_max_bound,
+            quant_min_bound,
+            dim);
+      }
+    } else {
+      store_output(out_now, x_vals, dim);
+    }
+  }
+}
+
 template
 void MoeFastHardamardImplWrapper(const T *x,
                                  const int64_t *expert_idx_per_token,
+                                 const int64_t *recv_expert_count,
                                  const T *shift,
                                  const T *smooth,
                                  const float* quant_scales,
@@ -677,6 +806,8 @@ void MoeFastHardamardImplWrapper(const T *x,
                                  const float quant_min_bound,
                                  const int64_t token_num,
                                  const int64_t dim,
+                                 const int num_max_tokens_per_expert,
+                                 bool used_in_ep_low_latency,
                                  OutT* out,
                                  cudaStream_t stream) {
   using nv_type = typename nv_type_traits<T>::type;
@@ -696,33 +827,61 @@ void MoeFastHardamardImplWrapper(const T *x,
   int sm_count;
   int act_blocks_per_sm;
   cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
-  auto kernel = moe_fast_hardamard_kernel;
-  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-      &act_blocks_per_sm, kernel, kThreads, kSmemSize);
-  const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
-  dim3 grid;
-  grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
-  if constexpr (UseDiagonalBlockMatrix) {
-    grid.y = ceil(dim / (kThreads * VecSize));
+
+  if (used_in_ep_low_latency) {
+    auto masked_kernel = masked_moe_fast_hardamard_kernel;
+    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &act_blocks_per_sm, masked_kernel, kThreads, kSmemSize);
+    const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
+    dim3 grid;
+    grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
+    if constexpr (UseDiagonalBlockMatrix) {
+      grid.y = ceil(dim / (kThreads * VecSize));
+    }
+    masked_kernel<<<grid, kThreads, kSmemSize, stream>>>(
+        reinterpret_cast(x),
+        recv_expert_count,
+        reinterpret_cast(shift),
+        reinterpret_cast(smooth),
+        quant_scales,
+        quant_round_type,
+        quant_max_bound,
+        quant_min_bound,
+        token_num,
+        dim,
+        num_max_tokens_per_expert,
+        reinterpret_cast(out)
+    );
+  } else {
+    auto kernel = moe_fast_hardamard_kernel;
+    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &act_blocks_per_sm, kernel, kThreads, kSmemSize);
+    const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
+    dim3 grid;
+    grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
+    if constexpr (UseDiagonalBlockMatrix) {
+      grid.y = ceil(dim / (kThreads * VecSize));
+    }
+    kernel<<<grid, kThreads, kSmemSize, stream>>>(
+        reinterpret_cast(x),
+        expert_idx_per_token,
+        reinterpret_cast(shift),
+        reinterpret_cast(smooth),
+        quant_scales,
+        quant_round_type,
+        quant_max_bound,
+        quant_min_bound,
+        token_num,
+        dim,
+        reinterpret_cast(out)
+    );
   }
-  kernel<<<grid, kThreads, kSmemSize, stream>>>(
-      reinterpret_cast(x),
-      expert_idx_per_token,
-      reinterpret_cast(shift),
-      reinterpret_cast(smooth),
-      quant_scales,
-      quant_round_type,
-      quant_max_bound,
-      quant_min_bound,
-      token_num,
-      dim,
-      reinterpret_cast(out)
-  );
 }
 
 template <typename T, typename OutT>
 void MoeFastHardamardWrapper(const T *x_data,
                              const int64_t *expert_idx_per_token,
+                             const int64_t *recv_expert_count,
                              const T *shift,
                              const T *smooth,
                              const float* quant_scales,
@@ -731,6 +890,8 @@ void MoeFastHardamardWrapper(const T *x_data,
                              const float quant_min_bound,
                              const int64_t token_num,
                              const int64_t dim,
+                             const int num_max_tokens_per_expert,
+                             bool used_in_ep_low_latency,
                              OutT* out,
                              cudaStream_t &stream) {
   bool FLAGS_hardamard_use_diagonal_block_matrix = true;
@@ -748,6 +909,7 @@ void MoeFastHardamardWrapper(const T *x_data,
         MoeFastHardamardImplWrapper(
             x_data,
             expert_idx_per_token,
+            recv_expert_count,
             shift,
             smooth,
             quant_scales,
@@ -756,6 +918,8 @@ void MoeFastHardamardWrapper(const T *x_data,
             quant_min_bound,
             token_num,
             dim,
+            num_max_tokens_per_expert,
+            used_in_ep_low_latency,
             out,
             stream);
       })});
@@ -769,6 +933,7 @@ void MoeFastHardamardWrapper(const T *x_data,
         MoeFastHardamardImplWrapper(
             x_data,
             expert_idx_per_token,
+            recv_expert_count,
             shift,
             smooth,
             quant_scales,
@@ -777,6 +942,8 @@ void MoeFastHardamardWrapper(const T *x_data,
             quant_min_bound,
             token_num,
             dim,
+            num_max_tokens_per_expert,
+            used_in_ep_low_latency,
             out,
             stream);
       });
@@ -789,6 +956,7 @@ void MoeFastHardamardWrapper(const T *x_data,
         MoeFastHardamardImplWrapper(
             x_data,
             expert_idx_per_token,
+            recv_expert_count,
             shift,
             smooth,
             quant_scales,
@@ -797,6 +965,8 @@ void MoeFastHardamardWrapper(const T *x_data,
             quant_min_bound,
             token_num,
             dim,
+            num_max_tokens_per_expert,
+            used_in_ep_low_latency,
             out,
             stream);
       });
@@ -809,6 +979,7 @@ void MoeFastHardamardWrapper(const T *x_data,
         MoeFastHardamardImplWrapper(
             x_data,
             expert_idx_per_token,
+            recv_expert_count,
             shift,
             smooth,
             quant_scales,
@@ -817,6 +988,8 @@ void MoeFastHardamardWrapper(const T *x_data,
             quant_min_bound,
             token_num,
             dim,
+            num_max_tokens_per_expert,
+            used_in_ep_low_latency,
             out,
             stream);
       });
@@ -827,6 +1000,7 @@ void MoeFastHardamardWrapper(const T *x_data,
 template void MoeFastHardamardWrapper(
     const phi::dtype::float16 *x_data,
     const int64_t *expert_idx_per_token,
+    const int64_t *recv_expert_count,
     const phi::dtype::float16 *shift,
     const phi::dtype::float16 *smooth,
     const float* quant_scales,
@@ -835,6 +1009,8 @@ template void MoeFastHardamardWrapper(
     const float quant_min_bound,
     const int64_t token_num,
     const int64_t dim,
+    const int num_max_tokens_per_expert,
+    bool used_in_ep_low_latency,
    phi::dtype::float16 *out,
    cudaStream_t &stream
 );
@@ -842,6 +1018,7 @@ template void MoeFastHardamardWrapper(
     const phi::dtype::float16 *x_data,
     const int64_t *expert_idx_per_token,
+    const int64_t *recv_expert_count,
     const phi::dtype::float16 *shift,
     const phi::dtype::float16 *smooth,
     const float* quant_scales,
@@ -850,6 +1027,8 @@ template void MoeFastHardamardWrapper(
     const float quant_min_bound,
     const int64_t token_num,
     const int64_t dim,
+    const int num_max_tokens_per_expert,
+    bool used_in_ep_low_latency,
    int8_t *out,
    cudaStream_t &stream
 );
@@ -857,6 +1036,7 @@ template void MoeFastHardamardWrapper(
     const phi::dtype::bfloat16 *x_data,
     const int64_t *expert_idx_per_token,
+    const int64_t *recv_expert_count,
     const phi::dtype::bfloat16 *shift,
     const phi::dtype::bfloat16 *smooth,
     const float* quant_scales,
@@ -865,6 +1045,8 @@ template void MoeFastHardamardWrapper(
     const phi::dtype::bfloat16 *x_data,
     const int64_t *expert_idx_per_token,
+    const int64_t *recv_expert_count,
     const phi::dtype::bfloat16 *shift,
     const phi::dtype::bfloat16 *smooth,
     const float* quant_scales,
@@ -880,6 +1063,8 @@ template void MoeFastHardamardWrapper(
     const float quant_min_bound,
     const int64_t token_num,
     const int64_t dim,
+    const int num_max_tokens_per_expert,
+    bool used_in_ep_low_latency,
    int8_t *out,
    cudaStream_t &stream
 );
diff --git a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.h b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.h
index 77af5b7a1..64c5c20ad 100644
--- a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.h
+++ b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.h
@@ -21,6 +21,7 @@
 template <typename T, typename OutT>
 void MoeFastHardamardWrapper(const T *x_data,
                              const int64_t *expert_idx_per_token,
+                             const int64_t *recv_expert_count,
                              const T *shift,
                              const T *smooth,
                              const float* quant_scales,
@@ -29,5 +30,7 @@ void MoeFastHardamardWrapper(const T *x_data,
                              const float quant_min_bound,
                              const int64_t token_num,
                              const int64_t dim,
+                             const int num_max_tokens_per_expert,
+                             bool used_in_ep_low_latency,
                              OutT* out,
                              cudaStream_t &stream);
diff --git a/custom_ops/gpu_ops/moe/moe_ffn.cu b/custom_ops/gpu_ops/moe/moe_ffn.cu
index 1d453466d..f9aadb494 100644
--- a/custom_ops/gpu_ops/moe/moe_ffn.cu
+++ b/custom_ops/gpu_ops/moe/moe_ffn.cu
@@ -240,6 +240,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
       MoeFastHardamardWrapper(
           act_out_tensor.data(),
           expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr,
+          const_cast(tokens_expert_prefix_sum.data()),
           down_proj_shift, // down_proj_shift->data(),
           down_proj_smooth, // down_proj_smooth->data(),
           down_proj_in_scale ? const_cast(down_proj_in_scale.get_ptr())->data() : nullptr,
@@ -248,6 +249,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
           -127.0,
           expanded_active_expert_rows,
           inter_size / 2,
+          num_max_tokens_per_expert,
+          used_in_ep_low_latency,
           reinterpret_cast(int8_act_out->ptr()),
           stream
       );
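
Note (not part of the patch itself): in the EP low-latency path that the new masked_moe_fast_hardamard_kernel targets, every expert owns a fixed window of num_max_tokens_per_expert slots and only the first recv_expert_count[expert_id] slots hold real tokens. The host-side C++ sketch below, with made-up sizes chosen only for illustration, reproduces that slot arithmetic and the skip-ahead used in the kernel's grid-stride loop.

```cpp
// Minimal host-side sketch of the masked slot indexing: slots are laid out as
// expert_id * num_max_tokens_per_expert + token_idx_in_expert, and padded slots
// past recv_expert_count[expert_id] are skipped by jumping the loop variable
// forward instead of testing every padded slot. All sizes here are hypothetical.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int num_experts = 3;
  const int num_max_tokens_per_expert = 8;                    // fixed per-expert window
  const int64_t token_num = num_experts * num_max_tokens_per_expert;
  const std::vector<int64_t> recv_expert_count = {5, 0, 3};   // valid tokens per expert
  const int grid_dim_x = 4;                                   // stand-in for gridDim.x

  for (int block = 0; block < grid_dim_x; ++block) {
    // Mirrors the kernel's grid-stride loop over slots, one "block" at a time.
    for (int token_id = block; token_id < token_num; token_id += grid_dim_x) {
      const auto token_idx_in_expert = token_id % num_max_tokens_per_expert;
      const auto expert_id = token_id / num_max_tokens_per_expert;
      if (token_idx_in_expert >= recv_expert_count[expert_id]) {
        // Skip the remaining padded slots of this expert in one step.
        auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
        auto num_iters_to_next_expert = (next_expert_start_idx - token_id - 1) / grid_dim_x;
        token_id += num_iters_to_next_expert * grid_dim_x;
        continue;
      }
      printf("block %d processes slot %d (expert %ld, token %ld)\n",
             block, token_id, (long)expert_id, (long)token_idx_in_expert);
    }
  }
  return 0;
}
```

The skip advances token_id to the last padded slot this block would have visited before the next expert's window, so the loop's own stride lands on the first reachable slot of the next expert without iterating over every padded slot in between.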