diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 027a33dc0..fd4b28714 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -564,6 +564,7 @@ std::vector NoauxTc( int n_group, int topk_group, int topk, + bool renormalize, float routed_scaling_factor); #ifdef ENABLE_FP8 diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h index 8256d43cd..97afbb74e 100644 --- a/custom_ops/gpu_ops/helper.h +++ b/custom_ops/gpu_ops/helper.h @@ -151,6 +151,34 @@ inline int GetGPUComputeCapability(int id) { #endif +#ifndef FP8_E4M3_MAX +#define FP8_E4M3_MAX 448.0 +#endif + +#ifndef DISPATCH_FLOAT_FP6_DTYPE +#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...) \ + switch (pd_dtype) { \ + case phi::DataType::FLOAT32: { \ + using c_type = float; \ + __VA_ARGS__ \ + break; \ + } \ + case phi::DataType::BFLOAT16: { \ + using c_type = phi::dtype::bfloat16; \ + __VA_ARGS__ \ + break; \ + } \ + case phi::DataType::FLOAT16: { \ + using c_type = phi::dtype::float16; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); \ + } \ + } +#endif + inline constexpr uint32_t next_pow_2(uint32_t const num) { if (num <= 1) return num; @@ -563,3 +591,28 @@ inline int GetSMVersion() { return sm_version; } + +__device__ __forceinline__ float warpReduceMax(float value) { + value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16)); + value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8)); + value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 4)); + value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 2)); + value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 1)); + return value; +} + +__device__ __forceinline__ float blockReduceMax(float value) { + static __shared__ float warpLevelMaxs[WARP_SIZE]; + const int laneId = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + + value = warpReduceMax(value); + + if (laneId == 0) warpLevelMaxs[warpId] = value; + __syncthreads(); + + value = (threadIdx.x < blockDim.x / WARP_SIZE) ? 
warpLevelMaxs[laneId] : 0; + if (warpId == 0) value = warpReduceMax(value); + + return value; +} diff --git a/custom_ops/gpu_ops/noaux_tc.cu b/custom_ops/gpu_ops/noaux_tc.cu index 19a9e380f..7b239b8cf 100644 --- a/custom_ops/gpu_ops/noaux_tc.cu +++ b/custom_ops/gpu_ops/noaux_tc.cu @@ -26,6 +26,7 @@ std::vector NoauxTc(paddle::Tensor& scores, int n_group, int topk_group, int topk, + bool renormalize, float routed_scaling_factor) { auto input_shape = scores_with_bias.shape(); PD_CHECK(input_shape.size() == 2); @@ -48,6 +49,7 @@ std::vector NoauxTc(paddle::Tensor& scores, n_group, topk_group, topk, + renormalize, routed_scaling_factor, stream); @@ -76,6 +78,7 @@ PD_BUILD_STATIC_OP(noaux_tc) .Attrs({"n_group: int", "topk_group: int", "topk:int", + "renormalize: bool", "routed_scaling_factor: float"}) .SetKernelFn(PD_KERNEL(NoauxTc)) .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape)) diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h index e8a3f4508..392dbfe3b 100644 --- a/custom_ops/gpu_ops/noauxtc_kernel.h +++ b/custom_ops/gpu_ops/noauxtc_kernel.h @@ -25,6 +25,23 @@ constexpr unsigned FULL_WARP_MASK = 0xffffffff; constexpr int32_t BLOCK_SIZE = 512; constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; +template +__device__ inline T_OUT cuda_cast(T_IN val) { + return val; +} + +template <> +__device__ inline float cuda_cast(__nv_bfloat16 val) { + return __bfloat162float(val); +} + +template +__device__ inline T neg_inf() { + // cuda::std::numeric_limits::infinity() returns `0` for [T=bf16 or fp16] + // so we need to cast from fp32 + return cuda_cast(-cuda::std::numeric_limits::infinity()); +} + namespace warp_topk { template @@ -41,10 +58,21 @@ constexpr __host__ __device__ bool isPowerOf2(T v) { } template -__device__ bool is_better_than(T val, T baseline) { +__forceinline__ __device__ bool is_better_than(T val, T baseline) { return (val > baseline && greater) || (val < baseline && !greater); } +template +__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index, + idxT baseline_index) { + bool res = (val > baseline && greater) || (val < baseline && !greater); + if (val == baseline) { + res = (index < baseline_index && greater) || + (index < baseline_index && !greater); + } + return res; +} + template int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) { int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k; @@ -53,7 +81,8 @@ int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) { round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT)); } -template +template struct BitonicMerge { // input should be a bitonic sequence, and sort it to be a monotonic sequence __device__ static void merge(T* __restrict__ val_arr, @@ -67,7 +96,15 @@ struct BitonicMerge { int const other_i = i + stride; T& val = val_arr[i]; T& other_val = val_arr[other_i]; - if ((val > other_val && ascending) || (val < other_val && !ascending)) { + bool is_better; + if constexpr (is_stable) { + is_better = is_better_than(val, other_val, idx_arr[i], + idx_arr[other_i]); + } else { + is_better = is_better_than(val, other_val); + } + + if (is_better) { T tmp = val; val = other_val; other_val = tmp; @@ -78,13 +115,14 @@ struct BitonicMerge { } } - BitonicMerge::merge(val_arr, idx_arr); - BitonicMerge::merge(val_arr + arr_len / 2, - idx_arr + arr_len / 2); + BitonicMerge::merge( + val_arr, idx_arr); + BitonicMerge::merge( + val_arr + arr_len / 2, idx_arr + arr_len / 2); } }; -template +template struct BitonicSort { __device__ static 
void sort(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) { @@ -92,15 +130,16 @@ struct BitonicSort { static_assert(size >= 2 * WARP_SIZE); constexpr int arr_len = size / WARP_SIZE; - BitonicSort::sort(val_arr, idx_arr); - BitonicSort::sort(val_arr + arr_len / 2, - idx_arr + arr_len / 2); - BitonicMerge::merge(val_arr, idx_arr); + BitonicSort::sort(val_arr, idx_arr); + BitonicSort::sort( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + BitonicMerge::merge( + val_arr, idx_arr); } }; -template -struct BitonicSort<32, ascending, T, idxT> { +template +struct BitonicSort<32, ascending, T, idxT, is_stable> { __device__ static void sort(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) { int const lane = threadIdx.x % WARP_SIZE; @@ -114,19 +153,37 @@ struct BitonicSort<32, ascending, T, idxT> { T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride); idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride); - if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) { + + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) != + (reverse != is_second); + } else { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) != + (reverse != is_second); + } + } else { + is_better = (*val_arr != other && + (*val_arr > other) != (reverse != is_second)); + } + if (is_better) { *val_arr = other; *idx_arr = other_idx; } } } - BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr); + BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr, + idx_arr); } }; -template -struct BitonicMerge<32, ascending, T, idxT> { +template +struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> { __device__ static void merge(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) { int const lane = threadIdx.x % WARP_SIZE; @@ -136,7 +193,24 @@ struct BitonicMerge<32, ascending, T, idxT> { T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride); idxT& idx = *idx_arr; idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride); - if (val != other && ((val > other) == (ascending != is_second))) { + + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) == + (reverse != is_second); // for min + } else { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) == + (reverse != is_second); // for max + } + } else { + is_better = + (val != other && ((val > other) == (ascending != is_second))); + } + + if (is_better) { val = other; idx = other_idx; } @@ -144,34 +218,42 @@ struct BitonicMerge<32, ascending, T, idxT> { } }; -template +template class WarpSort { -public: + public: __device__ WarpSort(idxT k, T dummy) : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) { static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity)); for (int i = 0; i < max_arr_len_; ++i) { val_arr_[i] = dummy_; + idx_arr_[i] = 0; } } // load and merge k sorted values __device__ void load_sorted(T const* __restrict__ in, - idxT const* __restrict__ in_idx, - idxT start) { + idxT const* __restrict__ in_idx, idxT start) { idxT idx = start + WARP_SIZE - 1 - lane_; for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) { if (idx < start + k_) { T t = in[idx]; - if (is_better_than(t, val_arr_[i])) { + bool is_better; + if constexpr (is_stable) { + is_better = + is_better_than(t, val_arr_[i], in_idx[idx], 
idx_arr_[i]); + } else { + is_better = is_better_than(t, val_arr_[i]); + } + if (is_better) { val_arr_[i] = t; idx_arr_[i] = in_idx[idx]; } } } - BitonicMerge::merge(val_arr_, idx_arr_); + BitonicMerge::merge( + val_arr_, idx_arr_); } __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const { @@ -193,7 +275,7 @@ public: } } -protected: + protected: static constexpr int max_arr_len_ = capacity / WARP_SIZE; T val_arr_[max_arr_len_]; @@ -205,11 +287,11 @@ protected: }; // end class WarpSort -template -class WarpSelect : public WarpSort { -public: +template +class WarpSelect : public WarpSort { + public: __device__ WarpSelect(idxT k, T dummy) - : WarpSort(k, dummy), + : WarpSort(k, dummy), k_th_(dummy), k_th_lane_((k - 1) % WARP_SIZE) { extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[]; @@ -234,7 +316,13 @@ public: } __device__ void add(T val, idxT idx) { - bool do_add = is_better_than(val, k_th_); + bool do_add; + if constexpr (is_stable) { + do_add = is_better_than(val, k_th_, idx, k_th_idx_); + } else { + do_add = is_better_than(val, k_th_); + } + uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add); if (mask == 0) { return; @@ -271,37 +359,52 @@ public: __syncthreads(); } -private: + private: __device__ void set_k_th_() { k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_); + if constexpr (is_stable) { + k_th_idx_ = + __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_); + } } __device__ void merge_buf_(T val, idxT idx) { - BitonicSort::sort(&val, &idx); + BitonicSort::sort(&val, &idx); T& old = val_arr_[max_arr_len_ - 1]; - if (is_better_than(val, old)) { + + bool is_better; + if constexpr (is_stable) { + is_better = + is_better_than(val, old, idx, idx_arr_[max_arr_len_ - 1]); + } else { + is_better = is_better_than(val, old); + } + + if (is_better) { old = val; idx_arr_[max_arr_len_ - 1] = idx; } - BitonicMerge::merge(val_arr_, idx_arr_); + BitonicMerge::merge( + val_arr_, idx_arr_); set_k_th_(); } - using WarpSort::max_arr_len_; - using WarpSort::val_arr_; - using WarpSort::idx_arr_; - using WarpSort::lane_; - using WarpSort::k_; - using WarpSort::dummy_; + using WarpSort::max_arr_len_; + using WarpSort::val_arr_; + using WarpSort::idx_arr_; + using WarpSort::lane_; + using WarpSort::k_; + using WarpSort::dummy_; T* val_smem_; idxT* idx_smem_; int smem_buf_len_ = 0; T k_th_; + idxT k_th_idx_; int const k_th_lane_; }; // end class WarpSelect } // namespace warp_topk @@ -313,8 +416,8 @@ __device__ void topk_with_k2(T* output, int32_t const lane_id, int const num_experts_per_group) { // Get the top2 per thread - T largest = cuda::std::numeric_limits::min(); - T second_largest = cuda::std::numeric_limits::min(); + T largest = neg_inf(); + T second_largest = neg_inf(); if (num_experts_per_group > WARP_SIZE) { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { @@ -368,8 +471,14 @@ __global__ void topk_with_k2_kernel(T* output, cg::thread_block block = cg::this_thread_block(); cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif topk_with_k2(output, input, tile, lane_id, num_experts_per_group); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif } template @@ -385,6 +494,7 @@ __global__ void group_idx_and_topk_idx_kernel( int64_t const topk, int64_t const num_experts, int64_t const num_experts_per_group, + bool const 
renormalize, double routed_scaling_factor) { int32_t warp_id = threadIdx.x / WARP_SIZE; int32_t lane_id = threadIdx.x % WARP_SIZE; @@ -403,19 +513,29 @@ __global__ void group_idx_and_topk_idx_kernel( extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to // store the target topk idx - int32_t* s_topk_idx = reinterpret_cast(smem_buf) + warp_id * topk; + int32_t* s_topk_idx = reinterpret_cast(smem_buf); T* s_topk_value = reinterpret_cast(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + warp_id * topk; + s_topk_idx += warp_id * topk; - T value = cuda::std::numeric_limits::min(); - T topk_group_value = cuda::std::numeric_limits::min(); + T value = neg_inf(); + T topk_group_value = neg_inf(); int32_t num_equalto_topkth_group; - if ((n_group > topk_group) && (case_id < num_tokens)) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before + // acqbulk because it's ptr arithmetic +#endif + + if (case_id < num_tokens) { // calculate group_idx int32_t target_num_min = WARP_SIZE - n_group + topk_group; - if (lane_id < n_group) { + if (lane_id < n_group && + (isfinite(cuda_cast( + group_scores[lane_id])))) // The check is necessary to avoid + // abnormal input + { value = group_scores[lane_id]; } @@ -426,22 +546,23 @@ __global__ void group_idx_and_topk_idx_kernel( __syncwarp(); // Ensure all threads have valid data before reduction topk_group_value = cg::reduce(tile, value, cg::greater()); if (value == topk_group_value) { - value = cuda::std::numeric_limits::min(); + value = neg_inf(); } pre_count_equal_to_top_value = count_equal_to_top_value; count_equal_to_top_value = __popc(__ballot_sync( - FULL_WARP_MASK, (value == cuda::std::numeric_limits::min()))); + FULL_WARP_MASK, (value == neg_inf()))); } num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; } __syncthreads(); - warp_topk::WarpSelect - queue((int32_t)topk, cuda::std::numeric_limits::min()); + warp_topk::WarpSelect + queue((int32_t)topk, neg_inf()); int count_equalto_topkth_group = 0; - bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits::min()); - if (case_id < num_tokens) { + bool if_proceed_next_topk = (topk_group_value != neg_inf()); + if (case_id < num_tokens && if_proceed_next_topk) { for (int i_group = 0; i_group < n_group; i_group++) { if ((group_scores[i_group] > topk_group_value) || ((group_scores[i_group] == topk_group_value) && @@ -449,9 +570,11 @@ __global__ void group_idx_and_topk_idx_kernel( int32_t offset = i_group * num_experts_per_group; for (int32_t i = lane_id; i < align_num_experts_per_group; i += WARP_SIZE) { - T candidates = i < num_experts_per_group - ? scores_with_bias[offset + i] - : cuda::std::numeric_limits::min(); + T candidates = + (i < num_experts_per_group) && isfinite(cuda_cast( + scores_with_bias[offset + i])) + ? 
scores_with_bias[offset + i] + : neg_inf(); queue.add(candidates, offset + i); } if (group_scores[i_group] == topk_group_value) { @@ -469,7 +592,7 @@ __global__ void group_idx_and_topk_idx_kernel( // Load the valid score value // Calculate the summation float topk_sum = 1e-20; - if (case_id < num_tokens) { + if (case_id < num_tokens && if_proceed_next_topk) { for (int i = lane_id; i < warp_topk::round_up_to_multiple_of(topk); i += WARP_SIZE) { @@ -478,33 +601,45 @@ __global__ void group_idx_and_topk_idx_kernel( if (i < topk) { s_topk_value[i] = value; } - topk_sum += reduce(tile, value, cg::plus()); + topk_sum += reduce(tile, cuda_cast(value), cg::plus()); } } __syncthreads(); - if (case_id < num_tokens) { + + if (case_id < num_tokens && if_proceed_next_topk) { for (int i = lane_id; i < num_experts; i += WARP_SIZE) { scores[i] = 0; } } - __threadfence(); - __syncthreads(); + __syncwarp(); if (case_id < num_tokens) { - for (int i = lane_id; i < topk; i += WARP_SIZE) { - float value = s_topk_value[i] / topk_sum * routed_scaling_factor; - scores[s_topk_idx[i]] = value; - if (if_proceed_next_topk) { + if (if_proceed_next_topk) { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + float value; + if (renormalize) { + value = cuda_cast(s_topk_value[i]) / topk_sum * + routed_scaling_factor; + } else { + value = cuda_cast(s_topk_value[i]) * routed_scaling_factor; + } + scores[s_topk_idx[i]] = value; topk_indices[i] = s_topk_idx[i]; - topk_values[i] = static_cast(value); + topk_values[i] = cuda_cast(value); } - else { + } else { + for (int i = lane_id; i < topk; i += WARP_SIZE) { topk_indices[i] = i; - topk_values[i] = static_cast(1.0f / topk); + topk_values[i] = cuda_cast(1.0f / topk); } } + // Note: when if_proceed_next_topk==false, choose the first 8 experts as the + // default result. 
} +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif } template @@ -518,17 +653,24 @@ void invokeNoAuxTc(T* scores, int64_t const n_group, int64_t const topk_group, int64_t const topk, + bool const renormalize, double const routed_scaling_factor, cudaStream_t const stream) { int64_t num_cases = num_tokens * n_group; int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; - topk_with_k2_kernel<<>>( - group_scores, - scores_with_bias, - num_tokens, - num_cases, - n_group, - num_experts / n_group); + auto* kernel_instance1 = &topk_with_k2_kernel; + cudaLaunchConfig_t config; + config.gridDim = topk_with_k2_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = false; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias, + num_tokens, num_cases, n_group, num_experts / n_group); int64_t topk_with_k_group_num_blocks = (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; @@ -536,21 +678,19 @@ void invokeNoAuxTc(T* scores, warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, topk); - group_idx_and_topk_idx_kernel<<>>(scores, - group_scores, - topk_values, - topk_indices, - scores_with_bias, - num_tokens, - n_group, - topk_group, - topk, - num_experts, - num_experts / n_group, - routed_scaling_factor); + auto* kernel_instance2 = &group_idx_and_topk_idx_kernel; + config.gridDim = topk_with_k_group_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = dynamic_smem_in_bytes; + config.stream = stream; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = false; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, scores_with_bias, num_tokens, + n_group, topk_group, topk, num_experts, + num_experts / n_group, renormalize, routed_scaling_factor); } #define INSTANTIATE_NOAUX_TC(T, IdxT) \ @@ -564,6 +704,7 @@ void invokeNoAuxTc(T* scores, int64_t const n_group, \ int64_t const topk_group, \ int64_t const topk, \ + bool const renormalize, \ double const routed_scaling_factor, \ cudaStream_t const stream); diff --git a/custom_ops/gpu_ops/quantization/common.cu b/custom_ops/gpu_ops/quantization/common.cu index 7d8388f99..c0e8f48ee 100644 --- a/custom_ops/gpu_ops/quantization/common.cu +++ b/custom_ops/gpu_ops/quantization/common.cu @@ -3,6 +3,158 @@ #include "quantization/common.cuh" +// adapted from: https://github.com/sgl-project/sglang/blob/v0.5.2rc2/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu + +// --------------------------------------------------------------------------- +// 1. Warp‑local, no shared memory +// • One warp handles one token. +// • Eight tokens per 256‑thread CTA. 
+// --------------------------------------------------------------------------- +template +__global__ void per_token_quant_fp8_kernel( + const T* __restrict__ input, + DST_DTYPE* __restrict__ output_q, + float* __restrict__ output_s, + const float scale_ub, + const int64_t hidden_size, + const int64_t num_tokens) { + const int warp_id = threadIdx.x / WARP_SIZE; // 0‑7 (8 warps) + const int lane_id = threadIdx.x & (WARP_SIZE - 1); // 0‑31 + const int token_id = blockIdx.x * kTokensPerCTA + warp_id; + if (token_id >= num_tokens) return; + + // Global tensors for this token + const T* token_input = input + token_id * hidden_size; + DST_DTYPE* token_output = output_q + token_id * hidden_size; + float* token_scale = output_s + token_id; + + // + // Pass-1: Perform a warp reduce to find the max_value of a token's hidden_size + // + float max_value = 0.f; + using vec_t = AlignedVector; + const int32_t num_vec_elems = hidden_size / kVecSize; + + for (int32_t i = lane_id; i < num_vec_elems; i += WARP_SIZE) { + vec_t input_vec; + Load(token_input + i * kVecSize, &input_vec); + +#pragma unroll + for (uint32_t j = 0; j < kVecSize; ++j) { + max_value = fmaxf(max_value, fabsf(static_cast(input_vec[j]))); + } + } + + float warp_max = warpReduceMax(max_value); + if (scale_ub > 0){ + warp_max = fminf(warp_max, scale_ub); + } + float scale; + scale = warp_max / FP8_E4M3_MAX; + // Broadcast scale + if (lane_id == 0) { + token_scale[0] = scale; + } + float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale; + + // + // Pass-2: quantize and write back + // + for (int i = lane_id; i < num_vec_elems; i += WARP_SIZE) { + vec_t input_vec; + Load(token_input + i * kVecSize, &input_vec); + DST_DTYPE output_arr[kVecSize]; +#pragma unroll + for (uint32_t j = 0; j < kVecSize; ++j) { + float val = static_cast(input_vec[j]) * scale_inv; + val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX); + output_arr[j] = static_cast(val); + } + if constexpr (kVecSize == 16) { + *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr; + } else { + // Use element-wise copy for vector size 8 to ensure correctness + for (int k = 0; k < kVecSize; ++k) { + token_output[i * kVecSize + k] = output_arr[k]; + } + } + } +} + +// --------------------------------------------------------------------------- +// 2. 
Baseline kernel (1 token / CTA, CUB block reduce) +// --------------------------------------------------------------------------- +template +__global__ void per_token_quant_fp8_small_batch_kernel( + const T* __restrict__ input, + DST_DTYPE* __restrict__ output_q, + float* __restrict__ output_s, + const float scale_ub, + const int64_t hidden_size, + const int64_t num_tokens) { + const int token_idx = blockIdx.x; + if (token_idx >= num_tokens) return; + + const int tid = threadIdx.x; + const int block_dim = blockDim.x; + + const T* token_input = input + token_idx * hidden_size; + DST_DTYPE* token_output = output_q + token_idx * hidden_size; + + float max_value = 0.0f; + + // Use template parameter for vector size + using vec_t = AlignedVector; + const int32_t num_vec_elems = hidden_size / kVecSize; + + // Find max using vectorized loads + for (int32_t i = tid; i < num_vec_elems; i += block_dim) { + vec_t input_vec; + Load(token_input + i * kVecSize, &input_vec); + +#pragma unroll + for (uint32_t j = 0; j < kVecSize; ++j) { + float val = static_cast(input_vec[j]); + max_value = fmaxf(max_value, fabsf(val)); + } + } + + max_value = blockReduceMax(max_value); + if (scale_ub > 0){ + max_value = fminf(max_value, scale_ub); + } + __shared__ float scale; + if (tid == 0) { + scale = max_value / FP8_E4M3_MAX; + output_s[token_idx] = scale; + } + __syncthreads(); + + const float scale_inv = 1.0f / scale; + + // Quantize using vectorized loads + for (int32_t i = tid; i < num_vec_elems; i += block_dim) { + vec_t input_vec; + Load(token_input + i * kVecSize, &input_vec); + + DST_DTYPE output_arr[kVecSize]; +#pragma unroll + for (uint32_t j = 0; j < kVecSize; ++j) { + float val = fmaxf(fminf(static_cast(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX); + output_arr[j] = static_cast(val); + } + + if constexpr (kVecSize == 16) { + *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr; + } else { + // Use element-wise copy for vector size 8 to ensure correctness + for (int k = 0; k < kVecSize; ++k) { + token_output[i * kVecSize + k] = output_arr[k]; + } + } + } +} + namespace fastdeploy { template @@ -179,39 +331,78 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out, // [..., d] auto rank = input.dims().size(); int const hidden_size = input.dims()[rank - 1]; int const num_tokens = input.numel() / hidden_size; + cudaStream_t stream = input.stream(); + + if (hidden_size % 8 == 0){ + int device = 0; + cudaGetDevice(&device); + int sm_count = 0; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device); + const int TOKENS_PER_CTA = 8; + const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA); + const bool use_vec16 = (hidden_size % 16 == 0); + DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, { + if (use_warp_kernel) { + // -------- warp‑local --------------------------------------------------- + constexpr int THREADS = TOKENS_PER_CTA * WARP_SIZE; // 256 + dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA); + dim3 block(THREADS); + + if (use_vec16) { + per_token_quant_fp8_kernel<<>>( + reinterpret_cast(input.data()), + reinterpret_cast<__nv_fp8_e4m3*>(out.data()), + reinterpret_cast(scales.data()), + scale_ub, + hidden_size, + num_tokens); + } else { + per_token_quant_fp8_kernel<<>>( + reinterpret_cast(input.data()), + reinterpret_cast<__nv_fp8_e4m3*>(out.data()), + reinterpret_cast(scales.data()), + scale_ub, + hidden_size, + num_tokens); + } + } else { + // -------- baseline ----------------------------------------------------- + constexpr 
int THREADS = 256; + dim3 grid(num_tokens); + dim3 block(THREADS); + + if (use_vec16) { + per_token_quant_fp8_small_batch_kernel<<>>( + reinterpret_cast(input.data()), + reinterpret_cast<__nv_fp8_e4m3*>(out.data()), + reinterpret_cast(scales.data()), + scale_ub, + hidden_size, + num_tokens); + } else { + per_token_quant_fp8_small_batch_kernel<<>>( + reinterpret_cast(input.data()), + reinterpret_cast<__nv_fp8_e4m3*>(out.data()), + reinterpret_cast(scales.data()), + scale_ub, + hidden_size, + num_tokens); + } + } + }); + return; + } + dim3 const grid(num_tokens); dim3 const block(std::min(hidden_size, 1024)); - cudaStream_t stream = input.stream(); + DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, { + fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel + <<>>(out.data(), scales.data(), + input.data(), scale_ub, + hidden_size); + }); - switch (input.dtype()) { - case paddle::DataType::FLOAT32: { - using scalar_t = float; - fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel - <<>>(out.data(), scales.data(), - input.data(), scale_ub, - hidden_size); - break; - } - case paddle::DataType::FLOAT16: { - using scalar_t = phi::dtype::float16; - fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel - <<>>(out.data(), scales.data(), - input.data(), scale_ub, - hidden_size); - break; - } - case paddle::DataType::BFLOAT16: { - using scalar_t = phi::dtype::bfloat16; - fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel - <<>>(out.data(), scales.data(), - input.data(), scale_ub, - hidden_size); - break; - } - default: - PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); - } } PD_BUILD_STATIC_OP(static_scaled_fp8_quant) diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index f1fb9fbf8..df716b00e 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -300,6 +300,7 @@ class EPRunner: layer.top_k, layer.routed_scaling_factor, layer.gate_correction_bias, + getattr(layer, "renormalize", True), ) else: topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index aab5960ee..8fd3aaddf 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -227,13 +227,14 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): """ gate_out = gate(x.cast("float32")) if layer.topk_method == "noaux_tc": - gate_out, _, _ = get_moe_scores( + gate_out, topk_weights, topk_idx = get_moe_scores( gate_out, layer.n_group, layer.topk_group, layer.top_k, layer.routed_scaling_factor, layer.gate_correction_bias, + getattr(layer, "renormalize", True), ) ( diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 386dbe75d..2ac33abbc 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -490,6 +490,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): layer.top_k, layer.routed_scaling_factor, layer.gate_correction_bias, + getattr(layer, "renormalize", True), ) else: topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py 
b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py index 4346063b7..0223c5029 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py @@ -262,6 +262,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase): layer.top_k, layer.routed_scaling_factor, layer.gate_correction_bias, + getattr(layer, "renormalize", True), ) topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index bf3baaa91..684549aeb 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -32,6 +32,7 @@ try: except ImportError: pass from fastdeploy.model_executor.layers.moe.moe import get_moe_scores +from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant class TritonWeightOnlyMoEMethod(QuantMethodBase): @@ -258,8 +259,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): layer.top_k, layer.routed_scaling_factor, layer.gate_correction_bias, + getattr(layer, "renormalize", True), ) - topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False) else: topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, @@ -327,6 +328,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): compute_type_enum=1, use_fp8_w8a8=False, use_int8_w8a16=True, + per_channel_quant=False, even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0, ) @@ -379,6 +381,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): compute_type_enum=1, use_fp8_w8a8=False, use_int8_w8a16=True, + per_channel_quant=False, even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0, ) @@ -390,6 +393,377 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): return out +class Wfp8Afp8MoEMethod(QuantMethodBase): + """ + Use Triton Group Gemm to compute Fused wfp8afp8 Quant MoE. + """ + + def __init__(self, quant_config): + """ + Triton Group Gemm to compute Fused MoE. + """ + self.quant_config = quant_config + self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] + self.added_scale_attrs = [ + "up_gate_proj_weight_scale", + "down_proj_weight_scale", + ] + + def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None: + """process_prequanted_weights""" + + raise NotImplementedError + + def create_weights(self, layer: nn.Layer, **extra_weight_attrs): + """ + Triton MoE create weight process. 
+ """ + self.up_gate_proj_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + layer.hidden_size, + ] + self.down_proj_weight_shape = [ + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size, + ] + self.up_gate_proj_scale_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + 1, + ] + self.down_proj_scale_shape = [ + layer.num_local_experts, + layer.hidden_size, + 1, + ] + if self.quant_config.is_checkpoint_bf16: + layer.up_gate_proj_weight = layer.create_parameter( + shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2], + dtype=layer.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + layer.down_proj_weight = layer.create_parameter( + shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size], + dtype=layer.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) + set_weight_attrs( + layer.up_gate_proj_weight, + { + **extra_weight_attrs, + "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True), + }, + ) + set_weight_attrs( + layer.down_proj_weight, + { + **extra_weight_attrs, + "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False), + }, + ) + else: + self.weight_dtype = paddle.float8_e4m3fn + up_gate_proj_weight_name = self.added_weight_attrs[0] + down_proj_weight_name = self.added_weight_attrs[1] + up_gate_proj_scale_name = self.added_scale_attrs[0] + down_proj_scale_name = self.added_scale_attrs[1] + setattr( + layer, + up_gate_proj_weight_name, + layer.create_parameter( + shape=self.up_gate_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_weight_name, + layer.create_parameter( + shape=self.down_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale + setattr( + layer, + up_gate_proj_scale_name, + layer.create_parameter( + shape=self.up_gate_proj_scale_shape, + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_scale_name, + layer.create_parameter( + shape=self.down_proj_scale_shape, + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + def process_weights_after_loading(self, layer): + """ """ + if not self.quant_config.is_checkpoint_bf16: + return + weight_id_map = {"gate_up": 0, "down": 1} + if ( + hasattr(layer.up_gate_proj_weight, "tensor_track") + and layer.up_gate_proj_weight.tensor_track is not None + and layer.up_gate_proj_weight.tensor_track.is_fully_copied() + ): + weight_type = "gate_up" + layer.up_gate_proj_weight.tensor_track = None + else: + weight_type = "down" + layer.down_proj_weight.tensor_track = None + + # weight + weight_name = self.added_weight_attrs[weight_id_map[weight_type]] + weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape + weight_dtype = paddle.float8_e4m3fn + # scale + scale_name = self.added_scale_attrs[weight_id_map[weight_type]] + scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape + scale_dtype = "float32" + + # 2.crate tmp tensor + + weight = paddle.empty(shape=weight_shape, dtype=weight_dtype) + scale = paddle.empty(shape=scale_shape, dtype=scale_dtype) + + # 3.quantize weight + from fastdeploy.model_executor.layers.utils import per_token_cast_to_fp8 + + 
for expert_id in range(layer.num_experts): + weight_quant, scale[expert_id] = per_token_cast_to_fp8( + getattr(layer, weight_name)[expert_id].transpose([1, 0]).contiguous(), + ) + weight[expert_id].copy_(weight_quant, False) + getattr(layer, weight_name).value().get_tensor()._clear() + + # create weight + setattr( + layer, + weight_name, + layer.create_parameter( + shape=weight_shape, + dtype=weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # create scale + setattr( + layer, + scale_name, + layer.create_parameter( + shape=scale_shape, + dtype=scale_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + getattr(layer, weight_name).copy_(weight, False) + getattr(layer, scale_name).copy_(scale, False) + + def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights): + """ + check layer is valid for this method + """ + assert up_gate_proj_weights[0].shape == [ + layer.moe_intermediate_size * 2, + layer.hidden_size, + ] + assert down_proj_weights[0].shape == [ + layer.hidden_size, + layer.moe_intermediate_size, + ] + + def apply( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + ) -> paddle.Tensor: + """ + Triton compute Fused MoE. + """ + gate_out = gate(x.cast("float32")) + token_num = x.shape[0] + top_k = layer.top_k + num_local_experts = layer.num_local_experts + moe_intermediate_size = layer.moe_intermediate_size + hidden_size = layer.hidden_size + E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape + + if layer.topk_method == "noaux_tc": + gate_out, topk_weights, topk_ids = get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + getattr(layer, "renormalize", True), + ) + else: + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + layer.top_k, + True, # apply_norm_weight + False, + ) + + config = { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4, + } + if token_num <= E: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + } + + sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func( + topk_ids, num_local_experts, config["BLOCK_SIZE_M"] + ) + max_possible_num_post_padded = sorted_token_ids.shape[0] + grid = ( + ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) + * ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), + ) + + up_gate_proj_out = paddle.empty( + [token_num * top_k, moe_intermediate_size * 2], + dtype=x.dtype, + ) + + from .triton_moe_kernels import fused_moe_kernel_paddle + + x_q, x_scale = scaled_fp8_quant(x, use_per_token_if_dynamic=True) + + fused_moe_kernel_paddle[grid]( + x_q, + layer.up_gate_proj_weight, + up_gate_proj_out, + x_scale, + layer.up_gate_proj_weight_scale, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + max_possible_num_post_padded, + token_num * top_k, + N=moe_intermediate_size * 2, + K=hidden_size, + stride_am=x_q.strides[0], + stride_ak=x_q.strides[1], + stride_be=layer.up_gate_proj_weight.strides[0], + stride_bk=layer.up_gate_proj_weight.strides[2], + stride_bn=layer.up_gate_proj_weight.strides[1], + stride_cm=up_gate_proj_out.strides[0], + stride_cn=up_gate_proj_out.strides[1], + # + stride_asm=x_scale.strides[0], + stride_ask=x_scale.strides[1], + 
stride_bse=layer.up_gate_proj_weight_scale.strides[0], + stride_bsk=layer.up_gate_proj_weight_scale.strides[2], + stride_bsn=layer.up_gate_proj_weight_scale.strides[1], + group_n=-1, + group_k=-1, + # Meta-parameters + BLOCK_SIZE_M=config["BLOCK_SIZE_M"], + BLOCK_SIZE_N=config["BLOCK_SIZE_N"], + BLOCK_SIZE_K=config["BLOCK_SIZE_K"], + GROUP_SIZE_M=config["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=False, + top_k=top_k, + compute_type_enum=1, + use_fp8_w8a8=True, + use_int8_w8a16=False, + per_channel_quant=True, + even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0, + ) + + down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out) + + down_proj_out = paddle.empty( + (token_num * top_k, hidden_size), + dtype=x.dtype, + ) + + grid = ( + ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) + * ceil_div(hidden_size, config["BLOCK_SIZE_N"]), + ) + + x_q, x_scale = scaled_fp8_quant(down_proj_input, use_per_token_if_dynamic=True) + + fused_moe_kernel_paddle[grid]( + x_q, + layer.down_proj_weight, + down_proj_out, + x_scale, + layer.down_proj_weight_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + max_possible_num_post_padded, + token_num * top_k, + N=hidden_size, + K=moe_intermediate_size, + stride_am=x_q.strides[0], + stride_ak=x_scale.strides[1], + stride_be=layer.down_proj_weight.strides[0], + stride_bk=layer.down_proj_weight.strides[2], + stride_bn=layer.down_proj_weight.strides[1], + stride_cm=down_proj_out.strides[0], + stride_cn=down_proj_out.strides[1], + stride_asm=x_scale.strides[0], + stride_ask=x_scale.strides[1], + stride_bse=layer.down_proj_weight_scale.strides[0], + stride_bsk=layer.down_proj_weight_scale.strides[2], + stride_bsn=layer.down_proj_weight_scale.strides[1], + group_n=-1, + group_k=-1, + # Meta-parameters + BLOCK_SIZE_M=config["BLOCK_SIZE_M"], + BLOCK_SIZE_N=config["BLOCK_SIZE_N"], + BLOCK_SIZE_K=config["BLOCK_SIZE_K"], + GROUP_SIZE_M=config["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=True, + top_k=1, + compute_type_enum=1, + use_fp8_w8a8=True, + use_int8_w8a16=False, + per_channel_quant=True, + even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0, + ) + + down_proj_out.reshape_([token_num, top_k, hidden_size]) + out = down_proj_out.sum(axis=1) + + if layer.reduce_results and layer.tp_size > 1: + tensor_model_parallel_all_reduce(out) + + return out + + class TensorWiseFP8MoEMethod(QuantMethodBase): """ Use Triton Group Gemm to compute Fused MoE. 
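The `Wfp8Afp8MoEMethod.apply` path above quantizes activations on the fly with `scaled_fp8_quant(x, use_per_token_if_dynamic=True)`, which is the same per-token dynamic FP8 scheme the new kernels in `custom_ops/gpu_ops/quantization/common.cu` implement: each token row gets its own scale `row_absmax / FP8_E4M3_MAX` (optionally clamped by `scale_ub`), and values are clipped to the e4m3 range before the cast. A minimal NumPy sketch of that math, for reference only; fp8 storage is emulated in float32 here and the function name is illustrative:

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # mirrors the FP8_E4M3_MAX macro added to helper.h

def per_token_quant_fp8_ref(x: np.ndarray, scale_ub: float = -1.0):
    """Reference for dynamic per-token FP8 quantization (one scale per row).

    Returns (x_q, scale): x_q holds the quantized values emulated in float32
    (a real kernel stores e4m3); scale has shape [num_tokens, 1].
    """
    absmax = np.abs(x).max(axis=-1, keepdims=True)       # pass 1: row-wise abs-max
    if scale_ub > 0:
        absmax = np.minimum(absmax, scale_ub)            # optional upper bound
    scale = absmax / FP8_E4M3_MAX                        # one scale per token
    scale_inv = np.divide(1.0, scale, out=np.zeros_like(scale), where=scale != 0.0)
    x_q = np.clip(x * scale_inv, -FP8_E4M3_MAX, FP8_E4M3_MAX)  # pass 2: scale + clamp
    return x_q, scale

# Dequantization check: x_q * scale ≈ x (up to e4m3 rounding in the real kernel).
```

The per-channel weight scales produced by `per_token_cast_to_fp8` in `process_weights_after_loading` play the symmetric role on the weight side, which is why the Triton group GEMM is invoked with `per_channel_quant=True`.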
@@ -524,6 +898,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): layer.top_k, layer.routed_scaling_factor, layer.gate_correction_bias, + getattr(layer, "renormalize", True), ) else: @@ -607,6 +982,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): compute_type_enum=1, use_fp8_w8a8=True, use_int8_w8a16=False, + per_channel_quant=False, even_Ks=hidden_size % config_up_gate_proj["BLOCK_SIZE_K"] == 0, ) @@ -676,6 +1052,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): compute_type_enum=1, use_fp8_w8a8=True, use_int8_w8a16=False, + per_channel_quant=False, even_Ks=moe_intermediate_size % config_down_proj["BLOCK_SIZE_K"] == 0, ) @@ -945,6 +1322,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): layer.top_k, layer.routed_scaling_factor, layer.gate_correction_bias, + getattr(layer, "renormalize", True), ) else: topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( @@ -1021,6 +1399,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): compute_type_enum=1, use_fp8_w8a8=True, use_int8_w8a16=False, + per_channel_quant=False, even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0, ) @@ -1074,6 +1453,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): compute_type_enum=1, use_fp8_w8a8=True, use_int8_w8a16=False, + per_channel_quant=False, even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0, ) diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 76c962069..f45039b76 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -66,6 +66,7 @@ def get_moe_scores( top_k, routed_scaling_factor, e_score_correction_bias, + renormalize: bool = False, ) -> paddle.Tensor: """ compute moe scores using e_score_correction_bias. @@ -79,6 +80,7 @@ def get_moe_scores( n_group if n_group > 0 else 1, topk_group if topk_group > 0 else 1, top_k, + renormalize, routed_scaling_factor, ) return scores, topk_values, topk_idx @@ -93,6 +95,7 @@ class FusedMoE(nn.Layer): self, fd_config, reduce_results: bool = True, + renormalize: bool = False, moe_intermediate_size: int = -1, num_experts: int = -1, expert_id_offset: int = 0, @@ -119,6 +122,7 @@ class FusedMoE(nn.Layer): self.fd_config = fd_config self.layer_idx = layer_idx self.reduce_results = reduce_results + self.renormalize = renormalize self.tp_rank = fd_config.parallel_config.tensor_parallel_rank self.tp_size = fd_config.parallel_config.tensor_parallel_size self.ep_size = fd_config.parallel_config.expert_parallel_size diff --git a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py index 61a7024b1..cb2e56ea0 100644 --- a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py +++ b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py @@ -59,6 +59,7 @@ def fused_moe_kernel_paddle( compute_type_enum: tl.constexpr, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, + per_channel_quant: tl.constexpr, even_Ks: tl.constexpr, ): """ @@ -121,6 +122,13 @@ def fused_moe_kernel_paddle( a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm offs_bsn = offs_bn // group_n b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + # channel-wise + elif per_channel_quant: + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + # Load per-token scale for activations + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + a_scale = tl.load(a_scale_ptrs, mask=token_mask, 
other=0.0)[:, None] else: # (Zkk): every expert has one activation scale and weight scale. a_scale = tl.load(a_scale_ptr + off_experts) diff --git a/fastdeploy/model_executor/layers/quantization/wfp8afp8.py b/fastdeploy/model_executor/layers/quantization/wfp8afp8.py index 93f96bf54..7a7e024f5 100644 --- a/fastdeploy/model_executor/layers/quantization/wfp8afp8.py +++ b/fastdeploy/model_executor/layers/quantization/wfp8afp8.py @@ -23,6 +23,7 @@ from fastdeploy.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, ) +from fastdeploy.model_executor.layers.moe import FusedMoE from fastdeploy.model_executor.layers.quantization.ops import ( cutlass_scaled_mm, scaled_fp8_quant, @@ -66,7 +67,14 @@ class WFP8AFP8Config(QuantConfigBase): def get_quant_method(self, layer) -> Optional[QuantMethodBase]: """ """ - return WFP8AFP8LinearMethod(self) + if isinstance(layer, FusedMoE): + from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import ( + Wfp8Afp8MoEMethod, + ) + + return Wfp8Afp8MoEMethod(self) + else: + return WFP8AFP8LinearMethod(self) class WFP8AFP8LinearMethod(QuantMethodBase): diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index beb348779..bbfc8f36c 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -116,6 +116,7 @@ class DeepSeekV3MoE(nn.Layer): super().__init__() self.tp_size = fd_config.parallel_config.tensor_parallel_size + self.norm_topk_prob = fd_config.model_config.norm_topk_prob weight_key_map = { "gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias", @@ -145,6 +146,7 @@ class DeepSeekV3MoE(nn.Layer): self.experts = FusedMoE( fd_config=fd_config, reduce_results=False, + renormalize=self.norm_topk_prob, moe_intermediate_size=fd_config.model_config.moe_intermediate_size, num_experts=fd_config.model_config.n_routed_experts, top_k=fd_config.model_config.num_experts_per_tok, diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index ea56e9a47..b11888a85 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -109,6 +109,8 @@ class Glm4Moe(nn.Layer): self.n_routed_experts: int = fd_config.model_config.n_routed_experts self.n_shared_experts: int = fd_config.model_config.n_shared_experts + self.norm_topk_prob = fd_config.model_config.norm_topk_prob + weight_key_map = { "gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias", "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight", @@ -133,6 +135,7 @@ class Glm4Moe(nn.Layer): self.experts = FusedMoE( fd_config, reduce_results=False, + renormalize=self.norm_topk_prob, moe_intermediate_size=fd_config.model_config.moe_intermediate_size, num_experts=fd_config.model_config.n_routed_experts, top_k=fd_config.model_config.num_experts_per_tok, diff --git a/tests/e2e/test_fake_Glm45_AIR_serving.py b/tests/e2e/test_fake_Glm45_AIR_serving.py index ff0a3f5be..58d224d91 100644 --- a/tests/e2e/test_fake_Glm45_AIR_serving.py +++ b/tests/e2e/test_fake_Glm45_AIR_serving.py @@ -115,17 +115,16 @@ def setup_and_run_server(): "--max-model-len", "32768", "--max-num-seqs", - "32", + "1", "--graph-optimization-config", '{"use_cudagraph":true}', "--load_choices", "default_v1", "--lm_head-fp32", "--quantization", - '{"quantization":"mix_quant","dense_quant_type":"wfp8afp8","moe_quant_type":"wint8"}', + "wfp8afp8", ] env = os.environ.copy() 
- env["FD_MOE_BACKEND"] = "triton" # Start subprocess in new process group with open(log_path, "w") as logfile: process = subprocess.Popen( diff --git a/tests/operators/test_noaux_tc.py b/tests/operators/test_noaux_tc.py index 06e065673..b0fe6d900 100644 --- a/tests/operators/test_noaux_tc.py +++ b/tests/operators/test_noaux_tc.py @@ -15,6 +15,7 @@ class TestMoeRouting(unittest.TestCase): self.topk_group = 4 self.top_k = 8 self.routed_scaling_factor = 1.5 + self.renormalize = True def node_limit_routing(self, gate_probs): """将所有专家分组, 只在topk_group个group内选择专家""" @@ -64,6 +65,7 @@ class TestMoeRouting(unittest.TestCase): self.topk_group, self.top_k, self.routed_scaling_factor, + self.renormalize, ) ref_topk_values, ref_topk_idx = self.ref_moe_routing()