Fix noaux_tc CUDA Error 700 in CUDAGraph and Add wfp8afp8 moe quant method (#4115)

* improve per_token_quant_fp8 performance

* support moe wfp8afp8

* check glm test

* fix noaux_tc op in CUDAGraph; make noaux_tc return the correct topk values and indices

* check

* check for inf and overwrite scores in noaux_tc

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Author: chen
Committed by: GitHub
Date: 2025-09-22 21:27:37 +08:00
Parent: 6b47773bd6
Commit: f38b174a75
17 changed files with 924 additions and 125 deletions

View File

@@ -564,6 +564,7 @@ std::vector<paddle::Tensor> NoauxTc(
int n_group,
int topk_group,
int topk,
bool renormalize,
float routed_scaling_factor);
#ifdef ENABLE_FP8

View File

@@ -151,6 +151,34 @@ inline int GetGPUComputeCapability(int id) {
#endif
#ifndef FP8_E4M3_MAX
#define FP8_E4M3_MAX 448.0
#endif
#ifndef DISPATCH_FLOAT_FP6_DTYPE
#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...) \
switch (pd_dtype) { \
case phi::DataType::FLOAT32: { \
using c_type = float; \
__VA_ARGS__ \
break; \
} \
case phi::DataType::BFLOAT16: { \
using c_type = phi::dtype::bfloat16; \
__VA_ARGS__ \
break; \
} \
case phi::DataType::FLOAT16: { \
using c_type = phi::dtype::float16; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); \
} \
}
#endif
inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1)
return num;
@@ -563,3 +591,28 @@ inline int GetSMVersion() {
return sm_version;
}
__device__ __forceinline__ float warpReduceMax(float value) {
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16));
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8));
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 4));
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 2));
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 1));
return value;
}
__device__ __forceinline__ float blockReduceMax(float value) {
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
value = warpReduceMax(value);
if (laneId == 0) warpLevelMaxs[warpId] = value;
__syncthreads();
value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) value = warpReduceMax(value);
return value;
}

View File

@@ -26,6 +26,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
int n_group,
int topk_group,
int topk,
bool renormalize,
float routed_scaling_factor) {
auto input_shape = scores_with_bias.shape();
PD_CHECK(input_shape.size() == 2);
@@ -48,6 +49,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
n_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
stream);
@@ -76,6 +78,7 @@ PD_BUILD_STATIC_OP(noaux_tc)
.Attrs({"n_group: int",
"topk_group: int",
"topk:int",
"renormalize: bool",
"routed_scaling_factor: float"})
.SetKernelFn(PD_KERNEL(NoauxTc))
.SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
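
With the added attribute, the Python binding of this op now takes renormalize between topk and routed_scaling_factor. A minimal sketch of a direct call, assuming the op is exposed as fastdeploy.model_executor.ops.gpu.noaux_tc and returns (scores, topk_values, topk_idx) as get_moe_scores does further below (illustrative only, not the exact API):

import fastdeploy

def noaux_tc_route(scores, scores_with_bias, n_group, topk_group, top_k,
                   renormalize, routed_scaling_factor):
    # scores / scores_with_bias: [num_tokens, num_experts] paddle tensors
    return fastdeploy.model_executor.ops.gpu.noaux_tc(
        scores,
        scores_with_bias,
        n_group,
        topk_group,
        top_k,
        renormalize,            # new bool attribute added in this change
        routed_scaling_factor,
    )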

View File

@@ -25,6 +25,23 @@ constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t BLOCK_SIZE = 512;
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) {
return val;
}
template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
return __bfloat162float(val);
}
template <typename T>
__device__ inline T neg_inf() {
// cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
// so we need to cast from fp32
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
}
namespace warp_topk {
template <int size, typename T>
@@ -41,10 +58,21 @@ constexpr __host__ __device__ bool isPowerOf2(T v) {
}
template <bool greater, typename T>
__device__ bool is_better_than(T val, T baseline) {
__forceinline__ __device__ bool is_better_than(T val, T baseline) {
return (val > baseline && greater) || (val < baseline && !greater);
}
template <bool greater, typename T, typename idxT>
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
idxT baseline_index) {
bool res = (val > baseline && greater) || (val < baseline && !greater);
if (val == baseline) {
res = (index < baseline_index && greater) ||
(index < baseline_index && !greater);
}
return res;
}
template <typename T, typename idxT>
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
@@ -53,7 +81,8 @@ int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
}
template <int size, bool ascending, typename T, typename idxT>
template <int size, bool ascending, bool reverse, typename T, typename idxT,
bool is_stable>
struct BitonicMerge {
// input should be a bitonic sequence, and sort it to be a monotonic sequence
__device__ static void merge(T* __restrict__ val_arr,
@@ -67,7 +96,15 @@ struct BitonicMerge {
int const other_i = i + stride;
T& val = val_arr[i];
T& other_val = val_arr[other_i];
if ((val > other_val && ascending) || (val < other_val && !ascending)) {
bool is_better;
if constexpr (is_stable) {
is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
idx_arr[other_i]);
} else {
is_better = is_better_than<ascending>(val, other_val);
}
if (is_better) {
T tmp = val;
val = other_val;
other_val = tmp;
@@ -78,13 +115,14 @@ struct BitonicMerge {
}
}
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr, idx_arr);
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr + arr_len / 2,
idx_arr + arr_len / 2);
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
val_arr, idx_arr);
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
val_arr + arr_len / 2, idx_arr + arr_len / 2);
}
};
template <int size, bool ascending, typename T, typename idxT>
template <int size, bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort {
__device__ static void sort(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
@@ -92,15 +130,16 @@ struct BitonicSort {
static_assert(size >= 2 * WARP_SIZE);
constexpr int arr_len = size / WARP_SIZE;
BitonicSort<size / 2, true, T, idxT>::sort(val_arr, idx_arr);
BitonicSort<size / 2, false, T, idxT>::sort(val_arr + arr_len / 2,
idx_arr + arr_len / 2);
BitonicMerge<size, ascending, T, idxT>::merge(val_arr, idx_arr);
BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
BitonicSort<size / 2, false, T, idxT, is_stable>::sort(
val_arr + arr_len / 2, idx_arr + arr_len / 2);
BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
val_arr, idx_arr);
}
};
template <bool ascending, typename T, typename idxT>
struct BitonicSort<32, ascending, T, idxT> {
template <bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort<32, ascending, T, idxT, is_stable> {
__device__ static void sort(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
int const lane = threadIdx.x % WARP_SIZE;
@@ -114,19 +153,37 @@ struct BitonicSort<32, ascending, T, idxT> {
T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) {
bool is_better;
if constexpr (is_stable) {
if constexpr (ascending) {
is_better = ((*val_arr > other) ||
((*val_arr == other) && (*idx_arr < other_idx))) !=
(reverse != is_second);
} else {
is_better = ((*val_arr > other) ||
((*val_arr == other) && (*idx_arr > other_idx))) !=
(reverse != is_second);
}
} else {
is_better = (*val_arr != other &&
(*val_arr > other) != (reverse != is_second));
}
if (is_better) {
*val_arr = other;
*idx_arr = other_idx;
}
}
}
BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr);
BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr,
idx_arr);
}
};
template <bool ascending, typename T, typename idxT>
struct BitonicMerge<32, ascending, T, idxT> {
template <bool ascending, bool reverse, typename T, typename idxT,
bool is_stable>
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
__device__ static void merge(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
int const lane = threadIdx.x % WARP_SIZE;
@@ -136,7 +193,24 @@ struct BitonicMerge<32, ascending, T, idxT> {
T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
idxT& idx = *idx_arr;
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
if (val != other && ((val > other) == (ascending != is_second))) {
bool is_better;
if constexpr (is_stable) {
if constexpr (ascending) {
is_better = ((*val_arr > other) ||
((*val_arr == other) && (*idx_arr < other_idx))) ==
(reverse != is_second); // for min
} else {
is_better = ((*val_arr > other) ||
((*val_arr == other) && (*idx_arr > other_idx))) ==
(reverse != is_second); // for max
}
} else {
is_better =
(val != other && ((val > other) == (ascending != is_second)));
}
if (is_better) {
val = other;
idx = other_idx;
}
@@ -144,34 +218,42 @@ struct BitonicMerge<32, ascending, T, idxT> {
}
};
template <int capacity, bool greater, typename T, typename idxT>
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSort {
public:
public:
__device__ WarpSort(idxT k, T dummy)
: lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
for (int i = 0; i < max_arr_len_; ++i) {
val_arr_[i] = dummy_;
idx_arr_[i] = 0;
}
}
// load and merge k sorted values
__device__ void load_sorted(T const* __restrict__ in,
idxT const* __restrict__ in_idx,
idxT start) {
idxT const* __restrict__ in_idx, idxT start) {
idxT idx = start + WARP_SIZE - 1 - lane_;
for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
if (idx < start + k_) {
T t = in[idx];
if (is_better_than<greater>(t, val_arr_[i])) {
bool is_better;
if constexpr (is_stable) {
is_better =
is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
} else {
is_better = is_better_than<greater>(t, val_arr_[i]);
}
if (is_better) {
val_arr_[i] = t;
idx_arr_[i] = in_idx[idx];
}
}
}
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
val_arr_, idx_arr_);
}
__device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
@@ -193,7 +275,7 @@ public:
}
}
protected:
protected:
static constexpr int max_arr_len_ = capacity / WARP_SIZE;
T val_arr_[max_arr_len_];
@@ -205,11 +287,11 @@ protected:
}; // end class WarpSort
template <int capacity, bool greater, typename T, typename idxT>
class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
public:
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
public:
__device__ WarpSelect(idxT k, T dummy)
: WarpSort<capacity, greater, T, idxT>(k, dummy),
: WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
k_th_(dummy),
k_th_lane_((k - 1) % WARP_SIZE) {
extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[];
@@ -234,7 +316,13 @@ public:
}
__device__ void add(T val, idxT idx) {
bool do_add = is_better_than<greater>(val, k_th_);
bool do_add;
if constexpr (is_stable) {
do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
} else {
do_add = is_better_than<greater>(val, k_th_);
}
uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
if (mask == 0) {
return;
@@ -271,37 +359,52 @@ public:
__syncthreads();
}
private:
private:
__device__ void set_k_th_() {
k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
if constexpr (is_stable) {
k_th_idx_ =
__shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
}
}
__device__ void merge_buf_(T val, idxT idx) {
BitonicSort<WARP_SIZE, greater, T, idxT>::sort(&val, &idx);
BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);
T& old = val_arr_[max_arr_len_ - 1];
if (is_better_than<greater>(val, old)) {
bool is_better;
if constexpr (is_stable) {
is_better =
is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
} else {
is_better = is_better_than<greater>(val, old);
}
if (is_better) {
old = val;
idx_arr_[max_arr_len_ - 1] = idx;
}
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
val_arr_, idx_arr_);
set_k_th_();
}
using WarpSort<capacity, greater, T, idxT>::max_arr_len_;
using WarpSort<capacity, greater, T, idxT>::val_arr_;
using WarpSort<capacity, greater, T, idxT>::idx_arr_;
using WarpSort<capacity, greater, T, idxT>::lane_;
using WarpSort<capacity, greater, T, idxT>::k_;
using WarpSort<capacity, greater, T, idxT>::dummy_;
using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;
T* val_smem_;
idxT* idx_smem_;
int smem_buf_len_ = 0;
T k_th_;
idxT k_th_idx_;
int const k_th_lane_;
}; // end class WarpSelect
} // namespace warp_topk
@@ -313,8 +416,8 @@ __device__ void topk_with_k2(T* output,
int32_t const lane_id,
int const num_experts_per_group) {
// Get the top2 per thread
T largest = cuda::std::numeric_limits<T>::min();
T second_largest = cuda::std::numeric_limits<T>::min();
T largest = neg_inf<T>();
T second_largest = neg_inf<T>();
if (num_experts_per_group > WARP_SIZE) {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
@@ -368,8 +471,14 @@ __global__ void topk_with_k2_kernel(T* output,
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
template <typename T, typename IdxT>
@@ -385,6 +494,7 @@ __global__ void group_idx_and_topk_idx_kernel(
int64_t const topk,
int64_t const num_experts,
int64_t const num_experts_per_group,
bool const renormalize,
double routed_scaling_factor) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
@@ -403,19 +513,29 @@ __global__ void group_idx_and_topk_idx_kernel(
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
// store the target topk idx
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
T* s_topk_value =
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
warp_id * topk;
s_topk_idx += warp_id * topk;
T value = cuda::std::numeric_limits<T>::min();
T topk_group_value = cuda::std::numeric_limits<T>::min();
T value = neg_inf<T>();
T topk_group_value = neg_inf<T>();
int32_t num_equalto_topkth_group;
if ((n_group > topk_group) && (case_id < num_tokens)) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before
// acqbulk because it's ptr arithmetic
#endif
if (case_id < num_tokens) {
// calculate group_idx
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
if (lane_id < n_group) {
if (lane_id < n_group &&
(isfinite(cuda_cast<float, T>(
group_scores[lane_id])))) // The check is necessary to avoid
// abnormal input
{
value = group_scores[lane_id];
}
@@ -426,22 +546,23 @@ __global__ void group_idx_and_topk_idx_kernel(
__syncwarp(); // Ensure all threads have valid data before reduction
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
if (value == topk_group_value) {
value = cuda::std::numeric_limits<T>::min();
value = neg_inf<T>();
}
pre_count_equal_to_top_value = count_equal_to_top_value;
count_equal_to_top_value = __popc(__ballot_sync(
FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
FULL_WARP_MASK, (value == neg_inf<T>())));
}
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
}
__syncthreads();
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t>
queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
/* is_stable */ true>
queue((int32_t)topk, neg_inf<T>());
int count_equalto_topkth_group = 0;
bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits<T>::min());
if (case_id < num_tokens) {
bool if_proceed_next_topk = (topk_group_value != neg_inf<T>());
if (case_id < num_tokens && if_proceed_next_topk) {
for (int i_group = 0; i_group < n_group; i_group++) {
if ((group_scores[i_group] > topk_group_value) ||
((group_scores[i_group] == topk_group_value) &&
@@ -449,9 +570,11 @@ __global__ void group_idx_and_topk_idx_kernel(
int32_t offset = i_group * num_experts_per_group;
for (int32_t i = lane_id; i < align_num_experts_per_group;
i += WARP_SIZE) {
T candidates = i < num_experts_per_group
? scores_with_bias[offset + i]
: cuda::std::numeric_limits<T>::min();
T candidates =
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
scores_with_bias[offset + i]))
? scores_with_bias[offset + i]
: neg_inf<T>();
queue.add(candidates, offset + i);
}
if (group_scores[i_group] == topk_group_value) {
@@ -469,7 +592,7 @@ __global__ void group_idx_and_topk_idx_kernel(
// Load the valid score value
// Calculate the summation
float topk_sum = 1e-20;
if (case_id < num_tokens) {
if (case_id < num_tokens && if_proceed_next_topk) {
for (int i = lane_id;
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
i += WARP_SIZE) {
@@ -478,33 +601,45 @@ __global__ void group_idx_and_topk_idx_kernel(
if (i < topk) {
s_topk_value[i] = value;
}
topk_sum += reduce(tile, value, cg::plus<float>());
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
}
}
__syncthreads();
if (case_id < num_tokens) {
if (case_id < num_tokens && if_proceed_next_topk) {
for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
scores[i] = 0;
}
}
__threadfence();
__syncthreads();
__syncwarp();
if (case_id < num_tokens) {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
scores[s_topk_idx[i]] = value;
if (if_proceed_next_topk) {
if (if_proceed_next_topk) {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
float value;
if (renormalize) {
value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
routed_scaling_factor;
} else {
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
}
scores[s_topk_idx[i]] = value;
topk_indices[i] = s_topk_idx[i];
topk_values[i] = static_cast<T>(value);
topk_values[i] = cuda_cast<T, float>(value);
}
else {
} else {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
topk_indices[i] = i;
topk_values[i] = static_cast<float>(1.0f / topk);
topk_values[i] = cuda_cast<T, float>(1.0f / topk);
}
}
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
// default result.
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
template <typename T, typename IdxT>
@@ -518,17 +653,24 @@ void invokeNoAuxTc(T* scores,
int64_t const n_group,
int64_t const topk_group,
int64_t const topk,
bool const renormalize,
double const routed_scaling_factor,
cudaStream_t const stream) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
group_scores,
scores_with_bias,
num_tokens,
num_cases,
n_group,
num_experts / n_group);
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
cudaLaunchConfig_t config;
config.gridDim = topk_with_k2_num_blocks;
config.blockDim = BLOCK_SIZE;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = false;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
num_tokens, num_cases, n_group, num_experts / n_group);
int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
@@ -536,21 +678,19 @@ void invokeNoAuxTc(T* scores,
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk);
group_idx_and_topk_idx_kernel<T><<<topk_with_k_group_num_blocks,
BLOCK_SIZE,
dynamic_smem_in_bytes,
stream>>>(scores,
group_scores,
topk_values,
topk_indices,
scores_with_bias,
num_tokens,
n_group,
topk_group,
topk,
num_experts,
num_experts / n_group,
routed_scaling_factor);
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
config.gridDim = topk_with_k_group_num_blocks;
config.blockDim = BLOCK_SIZE;
config.dynamicSmemBytes = dynamic_smem_in_bytes;
config.stream = stream;
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = false;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
topk_values, topk_indices, scores_with_bias, num_tokens,
n_group, topk_group, topk, num_experts,
num_experts / n_group, renormalize, routed_scaling_factor);
}
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
@@ -564,6 +704,7 @@ void invokeNoAuxTc(T* scores,
int64_t const n_group, \
int64_t const topk_group, \
int64_t const topk, \
bool const renormalize, \
double const routed_scaling_factor, \
cudaStream_t const stream);
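
For a single token, the routing that invokeNoAuxTc implements can be sketched in NumPy as below. This is a reference sketch only: it assumes the group score is the sum of each group's top-2 biased scores (DeepSeek-V3-style no-aux routing) and ignores the stable tie-breaking and CUDA-graph details handled by the kernels.

import numpy as np

def noaux_tc_ref(scores, scores_with_bias, n_group, topk_group, topk,
                 renormalize, routed_scaling_factor):
    # scores / scores_with_bias: [num_experts] arrays for one token
    experts_per_group = scores.shape[0] // n_group
    grouped = scores_with_bias.reshape(n_group, experts_per_group)
    group_scores = np.sort(grouped, axis=-1)[:, -2:].sum(axis=-1)   # top-2 sum per group
    kept = np.argsort(-group_scores)[:topk_group]
    mask = np.full_like(scores_with_bias, -np.inf)
    for g in kept:
        mask[g * experts_per_group:(g + 1) * experts_per_group] = 0.0
    topk_idx = np.argsort(-(scores_with_bias + mask))[:topk]        # select on biased scores
    topk_val = scores[topk_idx]                                     # weight with unbiased scores
    if renormalize:
        topk_val = topk_val / (topk_val.sum() + 1e-20)
    return topk_idx, topk_val * routed_scaling_factor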

View File

@@ -3,6 +3,158 @@
#include "quantization/common.cuh"
// adapted from: https://github.com/sgl-project/sglang/blob/v0.5.2rc2/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
// ---------------------------------------------------------------------------
// 1. Warp-local, no shared memory
// • One warp handles one token.
// • Eight tokens per 256-thread CTA.
// ---------------------------------------------------------------------------
template <typename T, typename DST_DTYPE, int kTokensPerCTA = 8, int kVecSize = 16>
__global__ void per_token_quant_fp8_kernel(
const T* __restrict__ input,
DST_DTYPE* __restrict__ output_q,
float* __restrict__ output_s,
const float scale_ub,
const int64_t hidden_size,
const int64_t num_tokens) {
const int warp_id = threadIdx.x / WARP_SIZE; // 0-7 (8 warps)
const int lane_id = threadIdx.x & (WARP_SIZE - 1); // 0-31
const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
if (token_id >= num_tokens) return;
// Global tensors for this token
const T* token_input = input + token_id * hidden_size;
DST_DTYPE* token_output = output_q + token_id * hidden_size;
float* token_scale = output_s + token_id;
//
// Pass-1: Perform a warp reduce to find the max_value of a token's hidden_size
//
float max_value = 0.f;
using vec_t = AlignedVector<T, kVecSize>;
const int32_t num_vec_elems = hidden_size / kVecSize;
for (int32_t i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
vec_t input_vec;
Load(token_input + i * kVecSize, &input_vec);
#pragma unroll
for (uint32_t j = 0; j < kVecSize; ++j) {
max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
}
}
float warp_max = warpReduceMax(max_value);
if (scale_ub > 0){
warp_max = fminf(warp_max, scale_ub);
}
float scale;
scale = warp_max / FP8_E4M3_MAX;
// Broadcast scale
if (lane_id == 0) {
token_scale[0] = scale;
}
float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;
//
// Pass-2: quantize and write back
//
for (int i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
vec_t input_vec;
Load(token_input + i * kVecSize, &input_vec);
DST_DTYPE output_arr[kVecSize];
#pragma unroll
for (uint32_t j = 0; j < kVecSize; ++j) {
float val = static_cast<float>(input_vec[j]) * scale_inv;
val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
output_arr[j] = static_cast<DST_DTYPE>(val);
}
if constexpr (kVecSize == 16) {
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
} else {
// Use element-wise copy for vector size 8 to ensure correctness
for (int k = 0; k < kVecSize; ++k) {
token_output[i * kVecSize + k] = output_arr[k];
}
}
}
}
// ---------------------------------------------------------------------------
// 2. Baseline kernel (1 token / CTA, CUB block reduce)
// ---------------------------------------------------------------------------
template <typename T, typename DST_DTYPE, int kVecSize = 16>
__global__ void per_token_quant_fp8_small_batch_kernel(
const T* __restrict__ input,
DST_DTYPE* __restrict__ output_q,
float* __restrict__ output_s,
const float scale_ub,
const int64_t hidden_size,
const int64_t num_tokens) {
const int token_idx = blockIdx.x;
if (token_idx >= num_tokens) return;
const int tid = threadIdx.x;
const int block_dim = blockDim.x;
const T* token_input = input + token_idx * hidden_size;
DST_DTYPE* token_output = output_q + token_idx * hidden_size;
float max_value = 0.0f;
// Use template parameter for vector size
using vec_t = AlignedVector<T, kVecSize>;
const int32_t num_vec_elems = hidden_size / kVecSize;
// Find max using vectorized loads
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
vec_t input_vec;
Load(token_input + i * kVecSize, &input_vec);
#pragma unroll
for (uint32_t j = 0; j < kVecSize; ++j) {
float val = static_cast<float>(input_vec[j]);
max_value = fmaxf(max_value, fabsf(val));
}
}
max_value = blockReduceMax(max_value);
if (scale_ub > 0){
max_value = fminf(max_value, scale_ub);
}
__shared__ float scale;
if (tid == 0) {
scale = max_value / FP8_E4M3_MAX;
output_s[token_idx] = scale;
}
__syncthreads();
const float scale_inv = 1.0f / scale;
// Quantize using vectorized loads
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
vec_t input_vec;
Load(token_input + i * kVecSize, &input_vec);
DST_DTYPE output_arr[kVecSize];
#pragma unroll
for (uint32_t j = 0; j < kVecSize; ++j) {
float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
output_arr[j] = static_cast<DST_DTYPE>(val);
}
if constexpr (kVecSize == 16) {
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
} else {
// Use element-wise copy for vector size 8 to ensure correctness
for (int k = 0; k < kVecSize; ++k) {
token_output[i * kVecSize + k] = output_arr[k];
}
}
}
}
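
Both kernels above implement the same per-token quantization and differ only in how the per-token absmax is reduced (one warp per token vs. one CTA per token). A minimal NumPy reference of that math, for illustration only:

import numpy as np

FP8_E4M3_MAX = 448.0

def per_token_quant_fp8_ref(x, scale_ub=-1.0):
    # x: [num_tokens, hidden_size] float array
    amax = np.abs(x).max(axis=-1)                    # per-token absmax (warp/block reduce max)
    if scale_ub > 0:
        amax = np.minimum(amax, scale_ub)
    scale = amax / FP8_E4M3_MAX                      # one float32 scale per token
    safe = np.where(scale == 0.0, 1.0, scale)
    q = np.clip(x / safe[:, None], -FP8_E4M3_MAX, FP8_E4M3_MAX)
    q = np.where(scale[:, None] == 0.0, 0.0, q)      # q is cast to float8_e4m3 on device
    return q, scale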
namespace fastdeploy {
template <typename scalar_t, typename fp8_type>
@@ -179,39 +331,78 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out, // [..., d]
auto rank = input.dims().size();
int const hidden_size = input.dims()[rank - 1];
int const num_tokens = input.numel() / hidden_size;
cudaStream_t stream = input.stream();
if (hidden_size % 8 == 0){
int device = 0;
cudaGetDevice(&device);
int sm_count = 0;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
const int TOKENS_PER_CTA = 8;
const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
const bool use_vec16 = (hidden_size % 16 == 0);
DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
if (use_warp_kernel) {
// -------- warp-local ---------------------------------------------------
constexpr int THREADS = TOKENS_PER_CTA * WARP_SIZE; // 256
dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
dim3 block(THREADS);
if (use_vec16) {
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
reinterpret_cast<float*>(scales.data<float>()),
scale_ub,
hidden_size,
num_tokens);
} else {
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
reinterpret_cast<float*>(scales.data<float>()),
scale_ub,
hidden_size,
num_tokens);
}
} else {
// -------- baseline -----------------------------------------------------
constexpr int THREADS = 256;
dim3 grid(num_tokens);
dim3 block(THREADS);
if (use_vec16) {
per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 16><<<grid, block, 0, stream>>>(
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
reinterpret_cast<float*>(scales.data<float>()),
scale_ub,
hidden_size,
num_tokens);
} else {
per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 8><<<grid, block, 0, stream>>>(
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
reinterpret_cast<float*>(scales.data<float>()),
scale_ub,
hidden_size,
num_tokens);
}
}
});
return;
}
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
cudaStream_t stream = input.stream();
DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
input.data<scalar_t>(), scale_ub,
hidden_size);
});
switch (input.dtype()) {
case paddle::DataType::FLOAT32: {
using scalar_t = float;
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
input.data<scalar_t>(), scale_ub,
hidden_size);
break;
}
case paddle::DataType::FLOAT16: {
using scalar_t = phi::dtype::float16;
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
input.data<scalar_t>(), scale_ub,
hidden_size);
break;
}
case paddle::DataType::BFLOAT16: {
using scalar_t = phi::dtype::bfloat16;
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
input.data<scalar_t>(), scale_ub,
hidden_size);
break;
}
default:
PD_THROW("Only supported attr of input type in [fp32, fp16, bf16].");
}
}
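
The dispatch above (taken only when hidden_size % 8 == 0; otherwise the original element-wise kernel runs) picks between the two kernels with a simple occupancy heuristic. The equivalent logic, as a sketch:

TOKENS_PER_CTA = 8

def choose_quant_kernel(num_tokens, hidden_size, sm_count):
    # warp-per-token kernel only when there are enough tokens to give every SM
    # at least two CTAs of 8 tokens; otherwise fall back to one CTA per token
    use_warp_kernel = num_tokens >= sm_count * 2 * TOKENS_PER_CTA
    vec_size = 16 if hidden_size % 16 == 0 else 8
    return ("warp_local" if use_warp_kernel else "small_batch"), vec_size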
PD_BUILD_STATIC_OP(static_scaled_fp8_quant)

View File

@@ -300,6 +300,7 @@ class EPRunner:
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(

View File

@@ -227,13 +227,14 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
"""
gate_out = gate(x.cast("float32"))
if layer.topk_method == "noaux_tc":
gate_out, _, _ = get_moe_scores(
gate_out, topk_weights, topk_idx = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
(

View File

@@ -490,6 +490,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(

View File

@@ -262,6 +262,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)

View File

@@ -32,6 +32,7 @@ try:
except ImportError:
pass
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant
class TritonWeightOnlyMoEMethod(QuantMethodBase):
@@ -258,8 +259,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
@@ -327,6 +328,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
compute_type_enum=1,
use_fp8_w8a8=False,
use_int8_w8a16=True,
per_channel_quant=False,
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
)
@@ -379,6 +381,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
compute_type_enum=1,
use_fp8_w8a8=False,
use_int8_w8a16=True,
per_channel_quant=False,
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
)
@@ -390,6 +393,377 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
return out
class Wfp8Afp8MoEMethod(QuantMethodBase):
"""
Use Triton Group Gemm to compute Fused wfp8afp8 Quant MoE.
"""
def __init__(self, quant_config):
"""
Triton Group Gemm to compute Fused MoE.
"""
self.quant_config = quant_config
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"up_gate_proj_weight_scale",
"down_proj_weight_scale",
]
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
"""process_prequanted_weights"""
raise NotImplementedError
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
Triton MoE create weight process.
"""
self.up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
self.down_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size,
]
self.up_gate_proj_scale_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
1,
]
self.down_proj_scale_shape = [
layer.num_local_experts,
layer.hidden_size,
1,
]
if self.quant_config.is_checkpoint_bf16:
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.down_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size],
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
)
else:
self.weight_dtype = paddle.float8_e4m3fn
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
up_gate_proj_scale_name = self.added_scale_attrs[0]
down_proj_scale_name = self.added_scale_attrs[1]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
up_gate_proj_scale_name,
layer.create_parameter(
shape=self.up_gate_proj_scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_scale_name,
layer.create_parameter(
shape=self.down_proj_scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
layer.up_gate_proj_weight.tensor_track = None
else:
weight_type = "down"
layer.down_proj_weight.tensor_track = None
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
weight_dtype = paddle.float8_e4m3fn
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
scale_dtype = "float32"
# 2. create tmp tensors
weight = paddle.empty(shape=weight_shape, dtype=weight_dtype)
scale = paddle.empty(shape=scale_shape, dtype=scale_dtype)
# 3.quantize weight
from fastdeploy.model_executor.layers.utils import per_token_cast_to_fp8
for expert_id in range(layer.num_experts):
weight_quant, scale[expert_id] = per_token_cast_to_fp8(
getattr(layer, weight_name)[expert_id].transpose([1, 0]).contiguous(),
)
weight[expert_id].copy_(weight_quant, False)
getattr(layer, weight_name).value().get_tensor()._clear()
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=scale_shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(weight, False)
getattr(layer, scale_name).copy_(scale, False)
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
"""
check layer is valid for this method
"""
assert up_gate_proj_weights[0].shape == [
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
assert down_proj_weights[0].shape == [
layer.hidden_size,
layer.moe_intermediate_size,
]
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
"""
gate_out = gate(x.cast("float32"))
token_num = x.shape[0]
top_k = layer.top_k
num_local_experts = layer.num_local_experts
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape
if layer.topk_method == "noaux_tc":
gate_out, topk_weights, topk_ids = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
layer.top_k,
True, # apply_norm_weight
False,
)
config = {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4,
}
if token_num <= E:
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
}
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
)
max_possible_num_post_padded = sorted_token_ids.shape[0]
grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
* ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]),
)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)
from .triton_moe_kernels import fused_moe_kernel_paddle
x_q, x_scale = scaled_fp8_quant(x, use_per_token_if_dynamic=True)
fused_moe_kernel_paddle[grid](
x_q,
layer.up_gate_proj_weight,
up_gate_proj_out,
x_scale,
layer.up_gate_proj_weight_scale,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
max_possible_num_post_padded,
token_num * top_k,
N=moe_intermediate_size * 2,
K=hidden_size,
stride_am=x_q.strides[0],
stride_ak=x_q.strides[1],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[2],
stride_bn=layer.up_gate_proj_weight.strides[1],
stride_cm=up_gate_proj_out.strides[0],
stride_cn=up_gate_proj_out.strides[1],
#
stride_asm=x_scale.strides[0],
stride_ask=x_scale.strides[1],
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
stride_bsk=layer.up_gate_proj_weight_scale.strides[2],
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
GROUP_SIZE_M=config["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=False,
top_k=top_k,
compute_type_enum=1,
use_fp8_w8a8=True,
use_int8_w8a16=False,
per_channel_quant=True,
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
)
down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out)
down_proj_out = paddle.empty(
(token_num * top_k, hidden_size),
dtype=x.dtype,
)
grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
* ceil_div(hidden_size, config["BLOCK_SIZE_N"]),
)
x_q, x_scale = scaled_fp8_quant(down_proj_input, use_per_token_if_dynamic=True)
fused_moe_kernel_paddle[grid](
x_q,
layer.down_proj_weight,
down_proj_out,
x_scale,
layer.down_proj_weight_scale,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
max_possible_num_post_padded,
token_num * top_k,
N=hidden_size,
K=moe_intermediate_size,
stride_am=x_q.strides[0],
stride_ak=x_scale.strides[1],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[2],
stride_bn=layer.down_proj_weight.strides[1],
stride_cm=down_proj_out.strides[0],
stride_cn=down_proj_out.strides[1],
stride_asm=x_scale.strides[0],
stride_ask=x_scale.strides[1],
stride_bse=layer.down_proj_weight_scale.strides[0],
stride_bsk=layer.down_proj_weight_scale.strides[2],
stride_bsn=layer.down_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
GROUP_SIZE_M=config["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=True,
top_k=1,
compute_type_enum=1,
use_fp8_w8a8=True,
use_int8_w8a16=False,
per_channel_quant=True,
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
)
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce(out)
return out
class TensorWiseFP8MoEMethod(QuantMethodBase):
"""
Use Triton Group Gemm to compute Fused MoE.
@@ -524,6 +898,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
@@ -607,6 +982,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
compute_type_enum=1,
use_fp8_w8a8=True,
use_int8_w8a16=False,
per_channel_quant=False,
even_Ks=hidden_size % config_up_gate_proj["BLOCK_SIZE_K"] == 0,
)
@@ -676,6 +1052,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
compute_type_enum=1,
use_fp8_w8a8=True,
use_int8_w8a16=False,
per_channel_quant=False,
even_Ks=moe_intermediate_size % config_down_proj["BLOCK_SIZE_K"] == 0,
)
@@ -945,6 +1322,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
@@ -1021,6 +1399,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
compute_type_enum=1,
use_fp8_w8a8=True,
use_int8_w8a16=False,
per_channel_quant=False,
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
)
@@ -1074,6 +1453,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
compute_type_enum=1,
use_fp8_w8a8=True,
use_int8_w8a16=False,
per_channel_quant=False,
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
)

View File

@@ -66,6 +66,7 @@ def get_moe_scores(
top_k,
routed_scaling_factor,
e_score_correction_bias,
renormalize: bool = False,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
@@ -79,6 +80,7 @@ def get_moe_scores(
n_group if n_group > 0 else 1,
topk_group if topk_group > 0 else 1,
top_k,
renormalize,
routed_scaling_factor,
)
return scores, topk_values, topk_idx
@@ -93,6 +95,7 @@ class FusedMoE(nn.Layer):
self,
fd_config,
reduce_results: bool = True,
renormalize: bool = False,
moe_intermediate_size: int = -1,
num_experts: int = -1,
expert_id_offset: int = 0,
@@ -119,6 +122,7 @@ class FusedMoE(nn.Layer):
self.fd_config = fd_config
self.layer_idx = layer_idx
self.reduce_results = reduce_results
self.renormalize = renormalize
self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.ep_size = fd_config.parallel_config.expert_parallel_size

View File

@@ -59,6 +59,7 @@ def fused_moe_kernel_paddle(
compute_type_enum: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
even_Ks: tl.constexpr,
):
"""
@@ -121,6 +122,13 @@ def fused_moe_kernel_paddle(
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
# channel-wise
elif per_channel_quant:
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
else:
# (Zkk): every expert has one activation scale and weight scale.
a_scale = tl.load(a_scale_ptr + off_experts)
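
In the new per_channel_quant branch, b_scale carries one scale per output channel of the expert weight and a_scale one scale per token; both are applied to the fp32 accumulator. A NumPy sketch of that dequantization (array names are hypothetical):

import numpy as np

def w8a8_per_channel_dequant_ref(a_q, b_q, a_scale, b_scale):
    # a_q: [M, K] quantized activations, b_q: [K, N] one expert's quantized weight,
    # a_scale: [M] per-token scales, b_scale: [N] per-output-channel scales
    acc = a_q.astype(np.float32) @ b_q.astype(np.float32)
    return acc * a_scale[:, None] * b_scale[None, :]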

View File

@@ -23,6 +23,7 @@ from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
)
from fastdeploy.model_executor.layers.moe import FusedMoE
from fastdeploy.model_executor.layers.quantization.ops import (
cutlass_scaled_mm,
scaled_fp8_quant,
@@ -66,7 +67,14 @@ class WFP8AFP8Config(QuantConfigBase):
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
""" """
return WFP8AFP8LinearMethod(self)
if isinstance(layer, FusedMoE):
from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
Wfp8Afp8MoEMethod,
)
return Wfp8Afp8MoEMethod(self)
else:
return WFP8AFP8LinearMethod(self)
class WFP8AFP8LinearMethod(QuantMethodBase):

View File

@@ -116,6 +116,7 @@ class DeepSeekV3MoE(nn.Layer):
super().__init__()
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.norm_topk_prob = fd_config.model_config.norm_topk_prob
weight_key_map = {
"gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
@@ -145,6 +146,7 @@ class DeepSeekV3MoE(nn.Layer):
self.experts = FusedMoE(
fd_config=fd_config,
reduce_results=False,
renormalize=self.norm_topk_prob,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.n_routed_experts,
top_k=fd_config.model_config.num_experts_per_tok,

View File

@@ -109,6 +109,8 @@ class Glm4Moe(nn.Layer):
self.n_routed_experts: int = fd_config.model_config.n_routed_experts
self.n_shared_experts: int = fd_config.model_config.n_shared_experts
self.norm_topk_prob = fd_config.model_config.norm_topk_prob
weight_key_map = {
"gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
@@ -133,6 +135,7 @@ class Glm4Moe(nn.Layer):
self.experts = FusedMoE(
fd_config,
reduce_results=False,
renormalize=self.norm_topk_prob,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.n_routed_experts,
top_k=fd_config.model_config.num_experts_per_tok,

View File

@@ -115,17 +115,16 @@ def setup_and_run_server():
"--max-model-len",
"32768",
"--max-num-seqs",
"32",
"1",
"--graph-optimization-config",
'{"use_cudagraph":true}',
"--load_choices",
"default_v1",
"--lm_head-fp32",
"--quantization",
'{"quantization":"mix_quant","dense_quant_type":"wfp8afp8","moe_quant_type":"wint8"}',
"wfp8afp8",
]
env = os.environ.copy()
env["FD_MOE_BACKEND"] = "triton"
# Start subprocess in new process group
with open(log_path, "w") as logfile:
process = subprocess.Popen(

View File

@@ -15,6 +15,7 @@ class TestMoeRouting(unittest.TestCase):
self.topk_group = 4
self.top_k = 8
self.routed_scaling_factor = 1.5
self.renormalize = True
def node_limit_routing(self, gate_probs):
"""将所有专家分组, 只在topk_group个group内选择专家"""
@@ -64,6 +65,7 @@ class TestMoeRouting(unittest.TestCase):
self.topk_group,
self.top_k,
self.routed_scaling_factor,
self.renormalize,
)
ref_topk_values, ref_topk_idx = self.ref_moe_routing()
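
With renormalize=True the kernel divides each token's selected scores by their sum before applying routed_scaling_factor, so the returned per-token top-k weights should sum to roughly routed_scaling_factor. A hypothetical check along those lines (helper name and tolerance are assumptions):

import numpy as np

def check_renormalized_weights(topk_values, routed_scaling_factor, rtol=1e-2):
    # topk_values: [num_tokens, top_k] weights returned by the routing op
    per_token_sum = np.asarray(topk_values, dtype=np.float32).sum(axis=-1)
    np.testing.assert_allclose(per_token_sum, routed_scaling_factor, rtol=rtol)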