Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-31 20:02:53 +08:00)
	optimize w4a8 decoding (#3050)
@@ -223,14 +223,11 @@ public:
   static Status can_implement(Arguments const &args)
   {
     CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::can_implement()");
-    // printf("--1\n");
     // Initialize static kernel and device properties, if necessary.
     Status result = init_device_props();
-    // printf("--1-2\n");
     if (result != Status::kSuccess) {
       return result;
     }
-    // printf("--2\n");
     dim3 grid = get_grid_shape(args);
     // printf("--grid:%d, %d, %d\n", grid.x, grid.y, grid.z);
     if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
@@ -238,7 +235,6 @@ public:
     {
       return Status::kErrorInvalidProblem;
     }
-    // printf("--3\n");
     return GemmKernel::can_implement(args);
   }
 
@@ -285,18 +281,50 @@ public:
   }
 
 
 
   /// Returns the maximum number of active thread blocks per multiprocessor
-  static int maximum_active_blocks()
+  static int maximum_active_blocks(int smem_capacity = -1)
   {
     CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::maximum_active_blocks()");
 
-    // Initialize static device properties, if necessary
-    if (init_device_props() != Status::kSuccess) {
+    int smem_size = int(sizeof(typename GemmKernel_::SharedStorage));
+
+    CUTLASS_TRACE_HOST("  smem_size: " << smem_size << " bytes");
+
+    cudaError_t result;
+    if (smem_size > (48 << 10)) {
+      result = cudaFuncSetAttribute(Kernel2<GemmKernel_>,
+                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                    smem_size);
+
+      if (result != cudaSuccess) {
+        // Call cudaGetLastError() to clear the error bit
+        result = cudaGetLastError();
+        CUTLASS_TRACE_HOST(
+          "  cudaFuncSetAttribute() returned error "
+          << cudaGetErrorString(result));
+        return -1;
+      }
+    }
+
+    int max_active_blocks = -1;
+    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &max_active_blocks,
+        Kernel2<GemmKernel_>,
+        GemmKernel_::kThreadCount,
+        smem_size);
+
+    if (result != cudaSuccess) {
+      // Call cudaGetLastError() to clear the error bit
+      result = cudaGetLastError();
+      CUTLASS_TRACE_HOST(
+        "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
+        << cudaGetErrorString(result));
       return -1;
     }
 
-    CUTLASS_TRACE_HOST("  max_active_blocks: " << sm_occupancy_);
-    return sm_occupancy_;
+    CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
+    return max_active_blocks;
   }
 
 
@@ -341,8 +369,7 @@ public:
 
     // Configure grid and block dimensions
     dim3 block(GemmKernel::kThreadCount, 1, 1);
-    // dim3 grid = params_.get_grid_dims();
-        dim3 grid(216, 1, 1);
+    dim3 grid(params_.threadblock_count, 1, 1);
 
     // Launch kernel
     CUTLASS_TRACE_HOST("  "
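A note on the two changes above: maximum_active_blocks() now queries occupancy directly for Kernel2<GemmKernel_>, opting the kernel into more than 48 KB of dynamic shared memory when it needs it, instead of returning the cached sm_occupancy_, and run() sizes the grid from params_.threadblock_count rather than the hard-coded dim3 grid(216, 1, 1). The new smem_capacity parameter is unused here, presumably kept to mirror the upstream CUTLASS GemmUniversalBase signature. The fragment below is a minimal, self-contained sketch of the same occupancy-query pattern; toy_kernel and max_active_blocks_for_toy are stand-in names, not part of the patch.

#include <cuda_runtime.h>

// Stand-in kernel; the patch applies the same pattern to Kernel2<GemmKernel_>.
__global__ void toy_kernel(float *out) {
  extern __shared__ float smem[];
  smem[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  out[blockIdx.x * blockDim.x + threadIdx.x] = smem[threadIdx.x];
}

// Blocks per SM for a given block size and dynamic shared memory size, or -1 on error.
int max_active_blocks_for_toy(int threads_per_block, int smem_size) {
  // Kernels that want more than 48 KB of dynamic shared memory must opt in first.
  if (smem_size > (48 << 10)) {
    if (cudaFuncSetAttribute(toy_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             smem_size) != cudaSuccess) {
      cudaGetLastError();  // clear the sticky error bit, as the patch does
      return -1;
    }
  }
  int blocks = -1;
  if (cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &blocks, toy_kernel, threads_per_block, smem_size) != cudaSuccess) {
    cudaGetLastError();
    return -1;
  }
  return blocks;
}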
@@ -21,12 +21,12 @@ rm -rf up_gate_proj_7168_8192.log
 rm -rf down_proj_8192_3584.log
 num_experts=8
 
-for tokens_per_expert in 12
+for tokens_per_expert in 1 2 4 8 16 20 24 28 32 36 48 64 96 128 160 192 224 256 384 512 768 1024 2048 3072 4096 8192
 
 do
 wait
-CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 &
-# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 &
+CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 0 1 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 &
+CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 0 1 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 &
 done
 wait
 echo "#### finish ####"
@@ -996,7 +996,6 @@ int main(int argc, char *argv[]) {
         CutlassTileConfig::CtaShape64x256x64_WarpShape64x64x64,
         CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
         CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64,
-        CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
     };
     std::vector<SplitKStyle> all_split_k_style{SplitKStyle::NO_SPLIT_K};
 
@@ -665,10 +665,139 @@ void moe_fast_hardamard_kernel(const T *x,
   }
 }
 
+template <typename T, typename OutT, int kThreads, int kNBytes, int VecSize, int N,
+          int kNChunks, int kSmeSize, int kRounds, int kChunksPerSmemSize, bool UseDiagonalBlockMatrix = false>
+__global__ __launch_bounds__(kThreads)
+void masked_moe_fast_hardamard_kernel(const T *x,
+                               const int64_t *recv_expert_count,
+                               const T *shift,
+                               const T *smooth,
+                               const float* quant_scales,
+                               const int quant_round_type,
+                               const float quant_max_bound,
+                               const float quant_min_bound,
+                               const int64_t token_num,
+                               const int64_t dim,
+                               const int num_max_tokens_per_expert,
+                               OutT *out) {
+  using vec_t = typename BytesToType<sizeof(T) * VecSize>::Type;
+  constexpr int kLogVecSize = cilog2(VecSize);
+  constexpr int kLogWarpSize = cilog2(32);
+  constexpr int kWarpSize = 32;
+  constexpr int kNWarps = kThreads / kWarpSize;
+  constexpr int kLogNWarps = cilog2(kNWarps);
+  constexpr int kLogNChunks = cilog2(kNChunks);
+
+  extern __shared__ char smem_[];
+  vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_);
+
+  for (int token_id = blockIdx.x; token_id < token_num; token_id += gridDim.x) {
+    const auto token_idx_in_expert = token_id % num_max_tokens_per_expert;
+    const auto expert_id = token_id / num_max_tokens_per_expert;
+    if (token_idx_in_expert >= recv_expert_count[expert_id]) {
+        auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
+        auto num_iters_to_next_expert = (next_expert_start_idx - token_id - 1) / gridDim.x;
+        token_id += num_iters_to_next_expert * gridDim.x;
+        continue;
+    }
+    const T *x_now = x + token_id * dim;
+    OutT *out_now = out + token_id * dim;
+    T init_value = static_cast<T>(0.f);
+    T x_vals[kNChunks][VecSize] = {init_value};
+
+    load_input<kNChunks, VecSize, UseDiagonalBlockMatrix, T>(x_now, x_vals, dim);
+#ifdef DEBUG_HARDAMARD
+    if (blockIdx.x == 0 && threadIdx.x == 0) {
+      for (int i = 0; i < 1; ++i) {
+          printf("chunk_id0: %d\n", i);
+          for (int j = 0; j < VecSize; ++j) {
+              printf("%f ", (float)x_vals[i][j]);
+          }
+          printf("\n");
+      }
+    }
+    __syncthreads();
+#endif
+
+    hadamard_mult_thread<kLogVecSize, kNChunks>(x_vals);
+#ifdef DEBUG_HARDAMARD
+    if (blockIdx.x == 0 && threadIdx.x == 0) {
+      for (int i = 0; i < 1; ++i) {
+          printf("chunk_id1: %d, kLogVecSize: %d\n", i, kLogVecSize);
+          for (int j = 0; j < VecSize; ++j) {
+              printf("%f ", (float)x_vals[i][j]);
+          }
+          printf("\n");
+      }
+    }
+    __syncthreads();
+#endif
+    hadamard_mult_warp<kLogWarpSize, 0, kNChunks, VecSize>(x_vals);
+#ifdef DEBUG_HARDAMARD
+    if (blockIdx.x == 0 && threadIdx.x == 0) {
+      for (int i = 0; i < 1; ++i) {
+          printf("chunk_id2: %d\n", i);
+          for (int j = 0; j < VecSize; ++j) {
+              printf("%f ", (float)x_vals[i][j]);
+          }
+          printf("\n");
+      }
+    }
+    __syncthreads();
+#endif
+    if constexpr (kNWarps > 1) {
+        // First let each group of consecutive kNWarps threads gather the data held by the other warps.
+        exchange_smem_pre<kNChunks, kChunksPerSmemSize, VecSize, kWarpSize, kNWarps, true, vec_t>(x_vals, smem_exchange);
+        // Cross-warp Hadamard stage.
+        hadamard_mult_warp<kLogNWarps, 0, kNChunks, VecSize>(x_vals);
+        // Then exchange back to the original layout.
+        exchange_smem_pre<kNChunks, kChunksPerSmemSize, VecSize, kWarpSize, kNWarps, false, vec_t>(x_vals, smem_exchange);
+    }
+    if constexpr (kNChunks > 1) {
+      if constexpr (kNChunks == 28) {
+        hadamard_mult_thread_28_transpose<T, VecSize>(x_vals);
+      } else if constexpr (kNChunks == 36) {
+        hadamard_mult_thread_36_transpose<T, VecSize>(x_vals);
+      } else {
+        constexpr int kLogNChunks = cilog2(kNChunks);
+        static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
+        hadamard_mult_thread_transpose<kLogNChunks, VecSize>(x_vals);
+      }
+    }
+    if (quant_scales) {
+      float quant_scale = quant_scales[expert_id];
+      if (shift) {
+        smooth_quant_store_output<kNChunks, VecSize, UseDiagonalBlockMatrix, T, OutT>(
+          out_now,
+          shift,
+          smooth,
+          x_vals,
+          quant_scale,
+          quant_round_type,
+          quant_max_bound,
+          quant_min_bound,
+          dim);
+      } else {
+        quant_store_output<kNChunks, VecSize, UseDiagonalBlockMatrix, T, OutT>(
+          out_now,
+          x_vals,
+          quant_scale,
+          quant_round_type,
+          quant_max_bound,
+          quant_min_bound,
+          dim);
+      }
+    } else {
+      store_output<kNChunks, VecSize, UseDiagonalBlockMatrix, T>(out_now, x_vals, dim);
+    }
+  }
+}
+
 
 template <typename T, typename OutT, int kLogN, int VecSize, int kNChunks, int kThreads, bool UseDiagonalBlockMatrix>
 void MoeFastHardamardImplWrapper(const T *x,
                               const int64_t *expert_idx_per_token,
+                              const int64_t *recv_expert_count,
                               const T *shift,
                               const T *smooth,
                               const float* quant_scales,
@@ -677,6 +806,8 @@ void MoeFastHardamardImplWrapper(const T *x,
                               const float quant_min_bound,
                               const int64_t token_num,
                               const int64_t dim,
+                              const int num_max_tokens_per_expert,
+                              bool used_in_ep_low_latency,
                               OutT* out,
                              cudaStream_t stream) {
   using nv_type = typename nv_type_traits<T>::type;
@@ -696,33 +827,61 @@ void MoeFastHardamardImplWrapper(const T *x,
   int sm_count;
   int act_blocks_per_sm;
   cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
-  auto kernel = moe_fast_hardamard_kernel<nv_type, out_type, kThreads, kNBytes, VecSize, N, kNChunks, kSmemSize, kRounds, kChunksPerSmemSize, UseDiagonalBlockMatrix>;
-  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-      &act_blocks_per_sm, kernel, kThreads, kSmemSize);
-  const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
-  dim3 grid;
-  grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
-  if constexpr (UseDiagonalBlockMatrix) {
-    grid.y = ceil(dim / (kThreads * VecSize));
+  if (used_in_ep_low_latency) {
+    auto masked_kernel = masked_moe_fast_hardamard_kernel<nv_type, out_type, kThreads, kNBytes, VecSize, N, kNChunks, kSmemSize, kRounds, kChunksPerSmemSize, UseDiagonalBlockMatrix>;
+    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &act_blocks_per_sm, masked_kernel, kThreads, kSmemSize);
+    const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
+    dim3 grid;
+    grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
+    if constexpr (UseDiagonalBlockMatrix) {
+      grid.y = ceil(dim / (kThreads * VecSize));
+    }
+    masked_kernel<<<grid, kThreads, kSmemSize, stream>>>(
+      reinterpret_cast<const nv_type*>(x),
+      recv_expert_count,
+      reinterpret_cast<const nv_type*>(shift),
+      reinterpret_cast<const nv_type*>(smooth),
+      quant_scales,
+      quant_round_type,
+      quant_max_bound,
+      quant_min_bound,
+      token_num,
+      dim,
+      num_max_tokens_per_expert,
+      reinterpret_cast<out_type*>(out)
+    );
+  } else {
+    auto kernel = moe_fast_hardamard_kernel<nv_type, out_type, kThreads, kNBytes, VecSize, N, kNChunks, kSmemSize, kRounds, kChunksPerSmemSize, UseDiagonalBlockMatrix>;
+    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &act_blocks_per_sm, kernel, kThreads, kSmemSize);
+    const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
+    dim3 grid;
+    grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
+    if constexpr (UseDiagonalBlockMatrix) {
+      grid.y = ceil(dim / (kThreads * VecSize));
+    }
+    kernel<<<grid, kThreads, kSmemSize, stream>>>(
+      reinterpret_cast<const nv_type*>(x),
+      expert_idx_per_token,
+      reinterpret_cast<const nv_type*>(shift),
+      reinterpret_cast<const nv_type*>(smooth),
+      quant_scales,
+      quant_round_type,
+      quant_max_bound,
+      quant_min_bound,
+      token_num,
+      dim,
+      reinterpret_cast<out_type*>(out)
+    );
   }
-  kernel<<<grid, kThreads, kSmemSize, stream>>>(
-    reinterpret_cast<const nv_type*>(x),
-    expert_idx_per_token,
-    reinterpret_cast<const nv_type*>(shift),
-    reinterpret_cast<const nv_type*>(smooth),
-    quant_scales,
-    quant_round_type,
-    quant_max_bound,
-    quant_min_bound,
-    token_num,
-    dim,
-    reinterpret_cast<out_type*>(out)
-  );
 }
 
 template <typename T, typename OutT>
 void MoeFastHardamardWrapper(const T *x_data,
                           const int64_t *expert_idx_per_token,
+                          const int64_t *recv_expert_count,
                           const T *shift,
                           const T *smooth,
                           const float* quant_scales,
@@ -731,6 +890,8 @@ void MoeFastHardamardWrapper(const T *x_data,
                           const float quant_min_bound,
                           const int64_t token_num,
                           const int64_t dim,
+                          const int num_max_tokens_per_expert,
+                          bool used_in_ep_low_latency,
                           OutT* out,
                           cudaStream_t &stream) {
   bool FLAGS_hardamard_use_diagonal_block_matrix = true;
@@ -748,6 +909,7 @@ void MoeFastHardamardWrapper(const T *x_data,
         MoeFastHardamardImplWrapper<T, OutT, kLogN, VEC_SIZE, kNChunks, kThreads, true>(
           x_data,
           expert_idx_per_token,
+          recv_expert_count,
           shift,
           smooth,
           quant_scales,
@@ -756,6 +918,8 @@ void MoeFastHardamardWrapper(const T *x_data,
           quant_min_bound,
           token_num,
           dim,
+          num_max_tokens_per_expert,
+          used_in_ep_low_latency,
           out,
           stream);
       })});
@@ -769,6 +933,7 @@ void MoeFastHardamardWrapper(const T *x_data,
         MoeFastHardamardImplWrapper<T, OutT, kLogN, VecSize, kNChunks, kThreads, false>(
           x_data,
           expert_idx_per_token,
+          recv_expert_count,
           shift,
           smooth,
           quant_scales,
@@ -777,6 +942,8 @@ void MoeFastHardamardWrapper(const T *x_data,
           quant_min_bound,
           token_num,
           dim,
+          num_max_tokens_per_expert,
+          used_in_ep_low_latency,
           out,
           stream);
       });
@@ -789,6 +956,7 @@ void MoeFastHardamardWrapper(const T *x_data,
         MoeFastHardamardImplWrapper<T, OutT, kLogN, VecSize, kNChunks, kThreads, false>(
           x_data,
           expert_idx_per_token,
+          recv_expert_count,
           shift,
           smooth,
           quant_scales,
@@ -797,6 +965,8 @@ void MoeFastHardamardWrapper(const T *x_data,
           quant_min_bound,
           token_num,
           dim,
+          num_max_tokens_per_expert,
+          used_in_ep_low_latency,
           out,
           stream);
       });
@@ -809,6 +979,7 @@ void MoeFastHardamardWrapper(const T *x_data,
         MoeFastHardamardImplWrapper<T, OutT, kLogN, VecSize, kNChunks, kThreads, false>(
           x_data,
           expert_idx_per_token,
+          recv_expert_count,
           shift,
           smooth,
           quant_scales,
@@ -817,6 +988,8 @@ void MoeFastHardamardWrapper(const T *x_data,
           quant_min_bound,
           token_num,
           dim,
+          num_max_tokens_per_expert,
+          used_in_ep_low_latency,
           out,
           stream);
       });
@@ -827,6 +1000,7 @@
 template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
   const phi::dtype::float16 *x_data,
   const int64_t *expert_idx_per_token,
+  const int64_t *recv_expert_count,
   const phi::dtype::float16 *shift,
   const phi::dtype::float16 *smooth,
   const float* quant_scales,
@@ -835,6 +1009,8 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
   const float quant_min_bound,
   const int64_t token_num,
   const int64_t dim,
+  const int num_max_tokens_per_expert,
+  bool used_in_ep_low_latency,
   phi::dtype::float16 *out,
   cudaStream_t &stream
 );
@@ -842,6 +1018,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
 template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
   const phi::dtype::float16 *x_data,
   const int64_t *expert_idx_per_token,
+  const int64_t *recv_expert_count,
   const phi::dtype::float16 *shift,
   const phi::dtype::float16 *smooth,
   const float* quant_scales,
@@ -850,6 +1027,8 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
   const float quant_min_bound,
   const int64_t token_num,
   const int64_t dim,
+  const int num_max_tokens_per_expert,
+  bool used_in_ep_low_latency,
   int8_t *out,
   cudaStream_t &stream
 );
@@ -857,6 +1036,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
 template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16>(
   const phi::dtype::bfloat16 *x_data,
   const int64_t *expert_idx_per_token,
+  const int64_t *recv_expert_count,
   const phi::dtype::bfloat16 *shift,
   const phi::dtype::bfloat16 *smooth,
   const float* quant_scales,
@@ -865,6 +1045,8 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16
   const float quant_min_bound,
   const int64_t token_num,
   const int64_t dim,
+  const int num_max_tokens_per_expert,
+  bool used_in_ep_low_latency,
   phi::dtype::bfloat16 *out,
   cudaStream_t &stream
 );
@@ -872,6 +1054,7 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16
 template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
   const phi::dtype::bfloat16 *x_data,
   const int64_t *expert_idx_per_token,
+  const int64_t *recv_expert_count,
   const phi::dtype::bfloat16 *shift,
   const phi::dtype::bfloat16 *smooth,
   const float* quant_scales,
@@ -880,6 +1063,8 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
   const float quant_min_bound,
   const int64_t token_num,
   const int64_t dim,
+  const int num_max_tokens_per_expert,
+  bool used_in_ep_low_latency,
   int8_t *out,
   cudaStream_t &stream
 );
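The core of the new masked path is the index arithmetic at the top of masked_moe_fast_hardamard_kernel: in the EP low-latency layout tokens are stored expert-major and padded up to num_max_tokens_per_expert, and a slot whose index within its expert is at or beyond recv_expert_count[expert_id] holds no real token. Rather than testing every padded slot, the grid-stride loop fast-forwards to the last padded slot this block would visit, so the next stride lands in the following expert's range. The host-side sketch below replays that arithmetic so it can be checked in isolation; the expert counts and grid size are made up for illustration.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int num_max_tokens_per_expert = 8;                   // padded slots per expert
  const std::vector<int64_t> recv_expert_count = {3, 0, 5};  // valid tokens per expert
  const int64_t token_num =
      static_cast<int64_t>(recv_expert_count.size()) * num_max_tokens_per_expert;
  const int grid_dim = 4;  // stands in for gridDim.x

  for (int block = 0; block < grid_dim; ++block) {           // each CUDA block, serially
    for (int64_t token_id = block; token_id < token_num; token_id += grid_dim) {
      const auto token_idx_in_expert = token_id % num_max_tokens_per_expert;
      const auto expert_id = token_id / num_max_tokens_per_expert;
      if (token_idx_in_expert >= recv_expert_count[expert_id]) {
        // Same skip as the kernel: jump to the last padded slot of this expert that
        // this block would visit, so the loop's += grid_dim moves on to the next expert.
        const auto next_expert_start = (expert_id + 1) * num_max_tokens_per_expert;
        const auto iters_to_skip = (next_expert_start - token_id - 1) / grid_dim;
        token_id += iters_to_skip * grid_dim;
        continue;
      }
      printf("block %d handles token %lld (expert %lld, slot %lld)\n",
             block, (long long)token_id, (long long)expert_id,
             (long long)token_idx_in_expert);
    }
  }
  return 0;
}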
@@ -21,6 +21,7 @@
 template <typename T, typename OutT>
 void MoeFastHardamardWrapper(const T *x_data,
                             const int64_t *expert_idx_per_token,
+                            const int64_t *recv_expert_count,
                             const T *shift,
                             const T *smooth,
                             const float* quant_scales,
@@ -29,5 +30,7 @@ void MoeFastHardamardWrapper(const T *x_data,
                             const float quant_min_bound,
                             const int64_t token_num,
                             const int64_t dim,
+                            const int num_max_tokens_per_expert,
+                            bool used_in_ep_low_latency,
                             OutT* out,
                             cudaStream_t &stream);
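For context on the three transform stages the new kernel reuses (a per-thread stage, hadamard_mult_thread; an intra-warp stage, hadamard_mult_warp; and, when kNWarps > 1, a cross-warp stage that exchanges data through shared memory, multiplies, and exchanges back), the fragment below shows a generic warp-level Walsh-Hadamard butterfly built on XOR shuffles. It only illustrates the idea behind the intra-warp stage; the patch's helpers operate on per-thread vectors across several chunks and are not reproduced here.

#include <cuda_runtime.h>

// Unnormalized 32-point Walsh-Hadamard transform across one warp (one float per lane).
__device__ float warp_hadamard32(float v) {
  const unsigned full_mask = 0xffffffffu;
  const int lane = threadIdx.x & 31;
#pragma unroll
  for (int stride = 1; stride < 32; stride <<= 1) {
    const float other = __shfl_xor_sync(full_mask, v, stride);
    // Partner lanes (lane ^ stride) form a butterfly: the lower lane keeps a + b,
    // the upper lane keeps a - b, which from its own point of view is other - v.
    v = (lane & stride) ? (other - v) : (v + other);
  }
  return v;  // scale by 1/sqrt(32) if an orthonormal transform is needed
}

__global__ void warp_hadamard_demo(const float *in, float *out) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  out[i] = warp_hadamard32(in[i]);
}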
@@ -240,6 +240,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         MoeFastHardamardWrapper<data_t, int8_t>(
             act_out_tensor.data<data_t>(),
             expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
+            const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
             down_proj_shift, // down_proj_shift->data<T>(),
             down_proj_smooth, // down_proj_smooth->data<T>(),
             down_proj_in_scale ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())->data<float>() : nullptr,
@@ -248,6 +249,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
             -127.0,
             expanded_active_expert_rows,
             inter_size / 2,
+            num_max_tokens_per_expert,
+            used_in_ep_low_latency,
             reinterpret_cast<int8_t *>(int8_act_out->ptr()),
             stream
         );
Authored by Yuan Xiaolan