Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-06 17:17:14 +08:00
Fix noaux_tc cuda Error 700 in CUDAGraph (#4174)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
@@ -571,6 +571,7 @@ std::vector<paddle::Tensor> NoauxTc(
     int n_group,
     int topk_group,
     int topk,
+    bool renormalize,
     float routed_scaling_factor);
 
 #ifdef ENABLE_FP8
@@ -26,6 +26,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
     int n_group,
     int topk_group,
     int topk,
+    bool renormalize,
     float routed_scaling_factor) {
   auto input_shape = scores_with_bias.shape();
   PD_CHECK(input_shape.size() == 2);
@@ -48,6 +49,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
       n_group,
       topk_group,
       topk,
+      renormalize,
       routed_scaling_factor,
       stream);
 
@@ -76,6 +78,7 @@ PD_BUILD_STATIC_OP(noaux_tc)
     .Attrs({"n_group: int",
             "topk_group: int",
             "topk:int",
+            "renormalize: bool",
             "routed_scaling_factor: float"})
     .SetKernelFn(PD_KERNEL(NoauxTc))
     .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
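Note: in Paddle's custom-op registration, each string in .Attrs() binds positionally to a trailing parameter of the kernel function, which is why adding "renormalize: bool" above goes hand in hand with the new `bool renormalize` parameter on NoauxTc itself. A minimal sketch of the same pattern, assuming a hypothetical op `my_op` (not FastDeploy source):

    #include "paddle/extension.h"

    // Attribute strings bind in order and by type to the kernel's trailing
    // parameters: "renormalize: bool" -> bool, "scale: float" -> float.
    std::vector<paddle::Tensor> MyOp(const paddle::Tensor& x,
                                     bool renormalize,
                                     float scale) {
      // Real computation elided; only the signature matters for registration.
      return {x};
    }

    PD_BUILD_STATIC_OP(my_op)
        .Inputs({"X"})
        .Outputs({"Out"})
        .Attrs({"renormalize: bool", "scale: float"})
        .SetKernelFn(PD_KERNEL(MyOp));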
@@ -25,6 +25,23 @@ constexpr unsigned FULL_WARP_MASK = 0xffffffff;
 constexpr int32_t BLOCK_SIZE = 512;
 constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
 
+template <typename T_OUT, typename T_IN>
+__device__ inline T_OUT cuda_cast(T_IN val) {
+  return val;
+}
+
+template <>
+__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
+  return __bfloat162float(val);
+}
+
+template <typename T>
+__device__ inline T neg_inf() {
+  // cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
+  // so we need to cast from fp32
+  return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
+}
+
 namespace warp_topk {
 
 template <int size, typename T>
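Note: the neg_inf helper added above works around a real pitfall that the kernel comment records: cuda::std::numeric_limits<T>::infinity() yields 0 for bf16/fp16, so negative infinity must be built in fp32 and narrowed. A standalone sketch of the same pattern (assumes a CUDA 11+ toolchain with libcu++):

    #include <cstdio>
    #include <cuda_bf16.h>
    #include <cuda/std/limits>

    template <typename T>
    __device__ T neg_inf_via_float() {
      // Build -inf in fp32 first, then narrow to the target type.
      return static_cast<T>(-cuda::std::numeric_limits<float>::infinity());
    }

    __global__ void show() {
      __nv_bfloat16 v = neg_inf_via_float<__nv_bfloat16>();
      printf("bf16 neg_inf round-tripped to float: %f\n", __bfloat162float(v));
    }

    int main() {
      show<<<1, 1>>>();
      cudaDeviceSynchronize();
      return 0;
    }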
@@ -41,10 +58,21 @@ constexpr __host__ __device__ bool isPowerOf2(T v) {
 }
 
 template <bool greater, typename T>
-__device__ bool is_better_than(T val, T baseline) {
+__forceinline__ __device__ bool is_better_than(T val, T baseline) {
   return (val > baseline && greater) || (val < baseline && !greater);
 }
 
+template <bool greater, typename T, typename idxT>
+__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
+                                               idxT baseline_index) {
+  bool res = (val > baseline && greater) || (val < baseline && !greater);
+  if (val == baseline) {
+    res = (index < baseline_index && greater) ||
+          (index < baseline_index && !greater);
+  }
+  return res;
+}
+
 template <typename T, typename idxT>
 int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
   int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
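Note: the new index-aware is_better_than overload is what makes the selection stable: on equal values the element with the smaller index wins, so results no longer depend on the order in which ties are encountered. The intended ordering, as a host-side C++ sketch:

    #include <cassert>

    // Stable comparison: value first, index as tie-breaker (smaller index
    // preferred), so equal scores always resolve the same way.
    template <bool greater>
    bool stable_better(float val, float baseline, int idx, int baseline_idx) {
      if (val != baseline) {
        return greater ? (val > baseline) : (val < baseline);
      }
      return idx < baseline_idx;  // deterministic winner on ties
    }

    int main() {
      assert((stable_better<true>(2.0f, 1.0f, 5, 3)));   // larger value wins
      assert((stable_better<true>(1.0f, 1.0f, 3, 5)));   // tie: smaller index wins
      assert((!stable_better<true>(1.0f, 1.0f, 5, 3)));
      return 0;
    }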
@@ -53,7 +81,8 @@ int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
                   round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
 }
 
-template <int size, bool ascending, typename T, typename idxT>
+template <int size, bool ascending, bool reverse, typename T, typename idxT,
+          bool is_stable>
 struct BitonicMerge {
   // input should be a bitonic sequence, and sort it to be a monotonic sequence
   __device__ static void merge(T* __restrict__ val_arr,
@@ -67,7 +96,15 @@ struct BitonicMerge {
       int const other_i = i + stride;
       T& val = val_arr[i];
       T& other_val = val_arr[other_i];
-      if ((val > other_val && ascending) || (val < other_val && !ascending)) {
+      bool is_better;
+      if constexpr (is_stable) {
+        is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
+                                              idx_arr[other_i]);
+      } else {
+        is_better = is_better_than<ascending>(val, other_val);
+      }
+
+      if (is_better) {
         T tmp = val;
         val = other_val;
         other_val = tmp;
@@ -78,13 +115,14 @@ struct BitonicMerge {
       }
     }
 
-    BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr, idx_arr);
-    BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr + arr_len / 2,
-                                                      idx_arr + arr_len / 2);
+    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
+        val_arr, idx_arr);
+    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
+        val_arr + arr_len / 2, idx_arr + arr_len / 2);
   }
 };
 
-template <int size, bool ascending, typename T, typename idxT>
+template <int size, bool ascending, typename T, typename idxT, bool is_stable>
 struct BitonicSort {
   __device__ static void sort(T* __restrict__ val_arr,
                               idxT* __restrict__ idx_arr) {
@@ -92,15 +130,16 @@ struct BitonicSort {
     static_assert(size >= 2 * WARP_SIZE);
     constexpr int arr_len = size / WARP_SIZE;
 
-    BitonicSort<size / 2, true, T, idxT>::sort(val_arr, idx_arr);
-    BitonicSort<size / 2, false, T, idxT>::sort(val_arr + arr_len / 2,
-                                                idx_arr + arr_len / 2);
-    BitonicMerge<size, ascending, T, idxT>::merge(val_arr, idx_arr);
+    BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
+    BitonicSort<size / 2, false, T, idxT, is_stable>::sort(
+        val_arr + arr_len / 2, idx_arr + arr_len / 2);
+    BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
+        val_arr, idx_arr);
   }
 };
 
-template <bool ascending, typename T, typename idxT>
-struct BitonicSort<32, ascending, T, idxT> {
+template <bool ascending, typename T, typename idxT, bool is_stable>
+struct BitonicSort<32, ascending, T, idxT, is_stable> {
   __device__ static void sort(T* __restrict__ val_arr,
                               idxT* __restrict__ idx_arr) {
     int const lane = threadIdx.x % WARP_SIZE;
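Note: these templates implement the classic bitonic recursion: sort the first half ascending and the second half descending (yielding a bitonic sequence), then merge with compare-exchanges at halving strides — here distributed across a warp's registers and lanes. The same shape in a host-side sketch (power-of-two n assumed):

    #include <cstdio>

    void bitonic_merge(float* a, int n, bool asc) {
      if (n == 1) return;
      for (int i = 0; i < n / 2; ++i) {
        // Compare-exchange at stride n/2, mirroring the __shfl_xor_sync pairs.
        if ((a[i] > a[i + n / 2]) == asc) {
          float t = a[i]; a[i] = a[i + n / 2]; a[i + n / 2] = t;
        }
      }
      bitonic_merge(a, n / 2, asc);
      bitonic_merge(a + n / 2, n / 2, asc);
    }

    void bitonic_sort(float* a, int n, bool asc) {
      if (n == 1) return;
      bitonic_sort(a, n / 2, true);           // ascending half
      bitonic_sort(a + n / 2, n / 2, false);  // descending half -> bitonic
      bitonic_merge(a, n, asc);
    }

    int main() {
      float v[8] = {5, 1, 7, 3, 2, 8, 6, 4};
      bitonic_sort(v, 8, true);
      for (float x : v) printf("%g ", x);  // prints 1 2 3 4 5 6 7 8
      return 0;
    }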
@@ -114,19 +153,37 @@ struct BitonicSort<32, ascending, T, idxT> {
 
         T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
         idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
-        if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) {
+
+        bool is_better;
+        if constexpr (is_stable) {
+          if constexpr (ascending) {
+            is_better = ((*val_arr > other) ||
+                         ((*val_arr == other) && (*idx_arr < other_idx))) !=
+                        (reverse != is_second);
+          } else {
+            is_better = ((*val_arr > other) ||
+                         ((*val_arr == other) && (*idx_arr > other_idx))) !=
+                        (reverse != is_second);
+          }
+        } else {
+          is_better = (*val_arr != other &&
+                       (*val_arr > other) != (reverse != is_second));
+        }
+
+        if (is_better) {
          *val_arr = other;
          *idx_arr = other_idx;
        }
      }
    }
 
    BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr);
  }
};
 
-template <bool ascending, typename T, typename idxT>
-struct BitonicMerge<32, ascending, T, idxT> {
+template <bool ascending, bool reverse, typename T, typename idxT,
+          bool is_stable>
+struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
   __device__ static void merge(T* __restrict__ val_arr,
                                idxT* __restrict__ idx_arr) {
     int const lane = threadIdx.x % WARP_SIZE;
@@ -136,7 +193,24 @@ struct BitonicMerge<32, ascending, T, idxT> {
       T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
       idxT& idx = *idx_arr;
       idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
-      if (val != other && ((val > other) == (ascending != is_second))) {
+
+      bool is_better;
+      if constexpr (is_stable) {
+        if constexpr (ascending) {
+          is_better = ((*val_arr > other) ||
+                       ((*val_arr == other) && (*idx_arr < other_idx))) ==
+                      (reverse != is_second);  // for min
+        } else {
+          is_better = ((*val_arr > other) ||
+                       ((*val_arr == other) && (*idx_arr > other_idx))) ==
+                      (reverse != is_second);  // for max
+        }
+      } else {
+        is_better =
+            (val != other && ((val > other) == (ascending != is_second)));
+      }
+
+      if (is_better) {
         val = other;
         idx = other_idx;
       }
@@ -144,34 +218,42 @@ struct BitonicMerge<32, ascending, T, idxT> {
   }
 };
 
-template <int capacity, bool greater, typename T, typename idxT>
+template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
 class WarpSort {
  public:
   __device__ WarpSort(idxT k, T dummy)
       : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
     static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
 
     for (int i = 0; i < max_arr_len_; ++i) {
       val_arr_[i] = dummy_;
+      idx_arr_[i] = 0;
     }
   }
 
   // load and merge k sorted values
   __device__ void load_sorted(T const* __restrict__ in,
-                              idxT const* __restrict__ in_idx,
-                              idxT start) {
+                              idxT const* __restrict__ in_idx, idxT start) {
     idxT idx = start + WARP_SIZE - 1 - lane_;
     for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
       if (idx < start + k_) {
         T t = in[idx];
-        if (is_better_than<greater>(t, val_arr_[i])) {
+        bool is_better;
+        if constexpr (is_stable) {
+          is_better =
+              is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
+        } else {
+          is_better = is_better_than<greater>(t, val_arr_[i]);
+        }
+        if (is_better) {
           val_arr_[i] = t;
           idx_arr_[i] = in_idx[idx];
         }
       }
     }
 
-    BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
+    BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
+        val_arr_, idx_arr_);
   }
 
   __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
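Note: the single added line `idx_arr_[i] = 0;` is easy to overlook but appears central to this commit: with stable selection enabled, the tie-break path reads idx_arr_ before every slot has been filled with real data, and an uninitialized index that later feeds scores[s_topk_idx[i]] can touch unmapped memory. Under CUDA graph capture/replay that surfaces as error 700; what that code means can be checked with a tiny host program:

    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
      // CUDA error 700 is cudaErrorIllegalAddress: a kernel loaded or stored
      // through an address it does not own, e.g. via an uninitialized index.
      cudaError_t err = static_cast<cudaError_t>(700);
      printf("%s: %s\n", cudaGetErrorName(err), cudaGetErrorString(err));
      return 0;
    }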
@@ -193,7 +275,7 @@ public:
     }
   }
 
 protected:
   static constexpr int max_arr_len_ = capacity / WARP_SIZE;
 
   T val_arr_[max_arr_len_];
@@ -205,11 +287,11 @@ protected:
 
 };  // end class WarpSort
 
-template <int capacity, bool greater, typename T, typename idxT>
-class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
+template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
+class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
  public:
   __device__ WarpSelect(idxT k, T dummy)
-      : WarpSort<capacity, greater, T, idxT>(k, dummy),
+      : WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
         k_th_(dummy),
         k_th_lane_((k - 1) % WARP_SIZE) {
     extern __shared__ char smem_buf[];  // extern __shared__ T smem_buf[];
@@ -234,7 +316,13 @@ public:
   }
 
   __device__ void add(T val, idxT idx) {
-    bool do_add = is_better_than<greater>(val, k_th_);
+    bool do_add;
+    if constexpr (is_stable) {
+      do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
+    } else {
+      do_add = is_better_than<greater>(val, k_th_);
+    }
+
     uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
     if (mask == 0) {
       return;
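Note: __ballot_sync turns each lane's do_add predicate into a 32-bit warp mask, which the surrounding code uses to count and position pending insertions. The idiom in isolation:

    #include <cstdio>

    __global__ void ballot_demo() {
      int lane = threadIdx.x % 32;
      bool pred = (lane % 3 == 0);                        // per-lane predicate
      unsigned mask = __ballot_sync(0xffffffffu, pred);   // bit i = lane i's pred
      if (lane == 0) {
        printf("mask=0x%08x, lanes set=%d\n", mask, __popc(mask));
      }
    }

    int main() {
      ballot_demo<<<1, 32>>>();
      cudaDeviceSynchronize();
      return 0;
    }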
@@ -271,37 +359,52 @@ public:
     __syncthreads();
   }
 
 private:
   __device__ void set_k_th_() {
     k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
+    if constexpr (is_stable) {
+      k_th_idx_ =
+          __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
+    }
   }
 
   __device__ void merge_buf_(T val, idxT idx) {
-    BitonicSort<WARP_SIZE, greater, T, idxT>::sort(&val, &idx);
+    BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);
 
     T& old = val_arr_[max_arr_len_ - 1];
-    if (is_better_than<greater>(val, old)) {
+
+    bool is_better;
+    if constexpr (is_stable) {
+      is_better =
+          is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
+    } else {
+      is_better = is_better_than<greater>(val, old);
+    }
+
+    if (is_better) {
       old = val;
       idx_arr_[max_arr_len_ - 1] = idx;
     }
 
-    BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
+    BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
+        val_arr_, idx_arr_);
 
     set_k_th_();
   }
 
-  using WarpSort<capacity, greater, T, idxT>::max_arr_len_;
-  using WarpSort<capacity, greater, T, idxT>::val_arr_;
-  using WarpSort<capacity, greater, T, idxT>::idx_arr_;
-  using WarpSort<capacity, greater, T, idxT>::lane_;
-  using WarpSort<capacity, greater, T, idxT>::k_;
-  using WarpSort<capacity, greater, T, idxT>::dummy_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;
 
   T* val_smem_;
   idxT* idx_smem_;
   int smem_buf_len_ = 0;
 
   T k_th_;
+  idxT k_th_idx_;
   int const k_th_lane_;
 };  // end class WarpSelect
 }  // namespace warp_topk
@@ -313,8 +416,8 @@ __device__ void topk_with_k2(T* output,
                              int32_t const lane_id,
                              int const num_experts_per_group) {
   // Get the top2 per thread
-  T largest = cuda::std::numeric_limits<T>::min();
-  T second_largest = cuda::std::numeric_limits<T>::min();
+  T largest = neg_inf<T>();
+  T second_largest = neg_inf<T>();
 
   if (num_experts_per_group > WARP_SIZE) {
     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
@@ -368,8 +471,14 @@ __global__ void topk_with_k2_kernel(T* output,
     cg::thread_block block = cg::this_thread_block();
     cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.wait;");
+#endif
     topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
   }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }
 
 template <typename T, typename IdxT>
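Note: the griddepcontrol pair is the PTX side of programmatic dependent launch on Hopper (sm_90+). As the PTX ISA describes it, griddepcontrol.wait blocks until memory written by the grids this grid depends on is visible, and griddepcontrol.launch_dependents signals that dependent grids may begin launching. A guarded wrapper in the same spirit (sketch; the macro names are illustrative, not from this PR):

    // Compiles to nothing on pre-Hopper targets, matching the #if guards above.
    #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    #define PDL_WAIT() asm volatile("griddepcontrol.wait;")
    #define PDL_LAUNCH_DEPENDENTS() asm volatile("griddepcontrol.launch_dependents;")
    #else
    #define PDL_WAIT()
    #define PDL_LAUNCH_DEPENDENTS()
    #endif

    __global__ void consumer_kernel(const float* produced, float* out, int n) {
      PDL_WAIT();  // upstream results are guaranteed visible past this point
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = produced[i] * 2.0f;
      PDL_LAUNCH_DEPENDENTS();  // let the next dependent grid begin
    }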
@@ -385,6 +494,7 @@ __global__ void group_idx_and_topk_idx_kernel(
     int64_t const topk,
     int64_t const num_experts,
     int64_t const num_experts_per_group,
+    bool const renormalize,
     double routed_scaling_factor) {
   int32_t warp_id = threadIdx.x / WARP_SIZE;
   int32_t lane_id = threadIdx.x % WARP_SIZE;
@@ -403,19 +513,29 @@ __global__ void group_idx_and_topk_idx_kernel(
 
   extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to
                                       // store the target topk idx
-  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
+  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
   T* s_topk_value =
       reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
       warp_id * topk;
+  s_topk_idx += warp_id * topk;
 
-  T value = cuda::std::numeric_limits<T>::min();
-  T topk_group_value = cuda::std::numeric_limits<T>::min();
+  T value = neg_inf<T>();
+  T topk_group_value = neg_inf<T>();
   int32_t num_equalto_topkth_group;
 
-  if ((n_group > topk_group) && (case_id < num_tokens)) {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");  // I think all prolog can be put before
+                                         // acqbulk because it's ptr arithmetic
+#endif
+
+  if (case_id < num_tokens) {
     // calculate group_idx
     int32_t target_num_min = WARP_SIZE - n_group + topk_group;
-    if (lane_id < n_group) {
+    if (lane_id < n_group &&
+        (isfinite(cuda_cast<float, T>(
+            group_scores[lane_id]))))  // The check is necessary to avoid
+                                       // abnormal input
+    {
       value = group_scores[lane_id];
     }
 
@@ -426,22 +546,23 @@ __global__ void group_idx_and_topk_idx_kernel(
       __syncwarp();  // Ensure all threads have valid data before reduction
       topk_group_value = cg::reduce(tile, value, cg::greater<T>());
       if (value == topk_group_value) {
-        value = cuda::std::numeric_limits<T>::min();
+        value = neg_inf<T>();
       }
       pre_count_equal_to_top_value = count_equal_to_top_value;
       count_equal_to_top_value = __popc(__ballot_sync(
-          FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
+          FULL_WARP_MASK, (value == neg_inf<T>())));
     }
     num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
   }
   __syncthreads();
 
-  warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t>
-      queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
+  warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
+                        /* is_stable */ true>
+      queue((int32_t)topk, neg_inf<T>());
 
   int count_equalto_topkth_group = 0;
-  bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits<T>::min());
-  if (case_id < num_tokens) {
+  bool if_proceed_next_topk = (topk_group_value != neg_inf<T>());
+  if (case_id < num_tokens && if_proceed_next_topk) {
     for (int i_group = 0; i_group < n_group; i_group++) {
       if ((group_scores[i_group] > topk_group_value) ||
           ((group_scores[i_group] == topk_group_value) &&
@@ -449,9 +570,11 @@ __global__ void group_idx_and_topk_idx_kernel(
         int32_t offset = i_group * num_experts_per_group;
         for (int32_t i = lane_id; i < align_num_experts_per_group;
              i += WARP_SIZE) {
-          T candidates = i < num_experts_per_group
-                             ? scores_with_bias[offset + i]
-                             : cuda::std::numeric_limits<T>::min();
+          T candidates =
+              (i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
+                                                 scores_with_bias[offset + i]))
+                  ? scores_with_bias[offset + i]
+                  : neg_inf<T>();
           queue.add(candidates, offset + i);
         }
         if (group_scores[i_group] == topk_group_value) {
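Note: routing every candidate through isfinite before it reaches the queue means NaN/Inf scores collapse to neg_inf<T>() and can never be selected, which keeps abnormal inputs from corrupting the top-k result. The guard in isolation (sketch):

    #include <cmath>
    #include <cuda_bf16.h>

    __device__ __nv_bfloat16 guarded_candidate(const __nv_bfloat16* scores,
                                               int i, int n) {
      // Short-circuit keeps out-of-range lanes from reading scores[i];
      // non-finite values are demoted to -inf so the queue ignores them.
      bool valid = (i < n) && isfinite(__bfloat162float(scores[i]));
      return valid ? scores[i] : __float2bfloat16(-INFINITY);
    }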
@@ -469,7 +592,7 @@ __global__ void group_idx_and_topk_idx_kernel(
   // Load the valid score value
   // Calculate the summation
   float topk_sum = 1e-20;
-  if (case_id < num_tokens) {
+  if (case_id < num_tokens && if_proceed_next_topk) {
     for (int i = lane_id;
          i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
          i += WARP_SIZE) {
@@ -478,33 +601,45 @@ __global__ void group_idx_and_topk_idx_kernel(
       if (i < topk) {
         s_topk_value[i] = value;
       }
-      topk_sum += reduce(tile, value, cg::plus<float>());
+      topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
     }
   }
 
   __syncthreads();
-  if (case_id < num_tokens) {
+
+  if (case_id < num_tokens && if_proceed_next_topk) {
     for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
       scores[i] = 0;
     }
   }
-  __threadfence();
-  __syncthreads();
+  __syncwarp();
 
   if (case_id < num_tokens) {
-    for (int i = lane_id; i < topk; i += WARP_SIZE) {
-      float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
-      scores[s_topk_idx[i]] = value;
-      if (if_proceed_next_topk) {
+    if (if_proceed_next_topk) {
+      for (int i = lane_id; i < topk; i += WARP_SIZE) {
+        float value;
+        if (renormalize) {
+          value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
+                  routed_scaling_factor;
+        } else {
+          value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
+        }
+        scores[s_topk_idx[i]] = value;
         topk_indices[i] = s_topk_idx[i];
-        topk_values[i] = static_cast<T>(value);
+        topk_values[i] = cuda_cast<T, float>(value);
       }
-    }
-    else {
+    } else {
+      for (int i = lane_id; i < topk; i += WARP_SIZE) {
         topk_indices[i] = i;
-        topk_values[i] = static_cast<float>(1.0f / topk);
+        topk_values[i] = cuda_cast<T, float>(1.0f / topk);
       }
     }
+    // Note: when if_proceed_next_topk==false, choose the first 8 experts as the
+    // default result.
   }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }
 
 template <typename T, typename IdxT>
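Note: the renormalize branch chooses between scaling the selected weights so they sum to 1 before applying the routed scaling factor (DeepSeek-style norm_topk_prob) and multiplying the raw sigmoid weights by the factor directly. The two formulas as a host-side sketch:

    #include <cstdio>

    // renormalize == true:  w_i = s_i / (sum_j s_j) * factor
    // renormalize == false: w_i = s_i * factor
    void route_weights(const float* s, float* w, int k, bool renormalize,
                       float factor) {
      float sum = 1e-20f;  // epsilon seed, matching the kernel's topk_sum
      for (int i = 0; i < k; ++i) sum += s[i];
      for (int i = 0; i < k; ++i) {
        w[i] = renormalize ? s[i] / sum * factor : s[i] * factor;
      }
    }

    int main() {
      float s[4] = {0.4f, 0.3f, 0.2f, 0.1f}, w[4];
      route_weights(s, w, 4, true, 2.5f);
      for (float x : w) printf("%g ", x);  // weights now sum to 2.5
      return 0;
    }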
@@ -518,17 +653,24 @@ void invokeNoAuxTc(T* scores,
                    int64_t const n_group,
                    int64_t const topk_group,
                    int64_t const topk,
+                   bool const renormalize,
                    double const routed_scaling_factor,
                    cudaStream_t const stream) {
   int64_t num_cases = num_tokens * n_group;
   int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
-  topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
-      group_scores,
-      scores_with_bias,
-      num_tokens,
-      num_cases,
-      n_group,
-      num_experts / n_group);
+  auto* kernel_instance1 = &topk_with_k2_kernel<T>;
+  cudaLaunchConfig_t config;
+  config.gridDim = topk_with_k2_num_blocks;
+  config.blockDim = BLOCK_SIZE;
+  config.dynamicSmemBytes = 0;
+  config.stream = stream;
+  cudaLaunchAttribute attrs[1];
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = false;
+  config.numAttrs = 1;
+  config.attrs = attrs;
+  cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
+                     num_tokens, num_cases, n_group, num_experts / n_group);
 
   int64_t topk_with_k_group_num_blocks =
       (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
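Note: cudaLaunchKernelEx (CUDA 11.8+) is the attribute-carrying replacement for the <<<...>>> launch syntax; moving to it lets these launches state the programmatic-stream-serialization attribute explicitly (set to false here) instead of relying on implicit defaults. Minimal standalone usage of the same API (a sketch):

    #include <cuda_runtime.h>

    __global__ void scale(float* x, int n, float f) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) x[i] *= f;
    }

    void launch_scale(float* x, int n, float f, cudaStream_t stream) {
      cudaLaunchConfig_t config{};
      config.gridDim = (n + 255) / 256;
      config.blockDim = 256;
      config.dynamicSmemBytes = 0;
      config.stream = stream;

      cudaLaunchAttribute attrs[1];
      attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
      attrs[0].val.programmaticStreamSerializationAllowed = false;
      config.attrs = attrs;
      config.numAttrs = 1;

      // Kernel arguments are passed positionally after the function pointer.
      cudaLaunchKernelEx(&config, scale, x, n, f);
    }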
@@ -536,21 +678,19 @@ void invokeNoAuxTc(T* scores,
       warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
                                                            topk);
 
-  group_idx_and_topk_idx_kernel<T><<<topk_with_k_group_num_blocks,
-                                     BLOCK_SIZE,
-                                     dynamic_smem_in_bytes,
-                                     stream>>>(scores,
-                                               group_scores,
-                                               topk_values,
-                                               topk_indices,
-                                               scores_with_bias,
-                                               num_tokens,
-                                               n_group,
-                                               topk_group,
-                                               topk,
-                                               num_experts,
-                                               num_experts / n_group,
-                                               routed_scaling_factor);
+  auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
+  config.gridDim = topk_with_k_group_num_blocks;
+  config.blockDim = BLOCK_SIZE;
+  config.dynamicSmemBytes = dynamic_smem_in_bytes;
+  config.stream = stream;
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = false;
+  config.numAttrs = 1;
+  config.attrs = attrs;
+  cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
+                     topk_values, topk_indices, scores_with_bias, num_tokens,
+                     n_group, topk_group, topk, num_experts,
+                     num_experts / n_group, renormalize, routed_scaling_factor);
 }
 
 #define INSTANTIATE_NOAUX_TC(T, IdxT) \
@@ -564,6 +704,7 @@ void invokeNoAuxTc(T* scores,
     int64_t const n_group, \
     int64_t const topk_group, \
     int64_t const topk, \
+    bool const renormalize, \
     double const routed_scaling_factor, \
     cudaStream_t const stream);
 
@@ -369,6 +369,7 @@ class EPRunner:
                 layer.top_k,
                 layer.routed_scaling_factor,
                 layer.gate_correction_bias,
+                getattr(layer, "renormalize", True),
             )
         else:
             topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
@@ -39,6 +39,7 @@ elif current_platform.is_iluvatar():
         moe_expert_reduce,
     )
 
+from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
 
 
@@ -226,15 +227,14 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         """
         gate_out = gate(x.cast("float32"))
         if layer.topk_method == "noaux_tc":
-            from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
-
-            gate_out, _, _ = get_moe_scores(
+            gate_out, topk_weights, topk_idx = get_moe_scores(
                 gate_out,
                 layer.n_group,
                 layer.topk_group,
                 layer.top_k,
                 layer.routed_scaling_factor,
                 layer.gate_correction_bias,
+                getattr(layer, "renormalize", True),
             )
 
             (
@@ -512,6 +512,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 layer.top_k,
                 layer.routed_scaling_factor,
                 layer.gate_correction_bias,
+                getattr(layer, "renormalize", True),
             )
         else:
             topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
@@ -263,6 +263,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
                 layer.top_k,
                 layer.routed_scaling_factor,
                 layer.gate_correction_bias,
+                getattr(layer, "renormalize", True),
             )
 
             topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
@@ -263,8 +263,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
                 layer.top_k,
                 layer.routed_scaling_factor,
                 layer.gate_correction_bias,
+                getattr(layer, "renormalize", True),
             )
-            topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
         else:
             topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
                 gate_out,
@@ -66,6 +66,7 @@ def get_moe_scores(
     top_k,
     routed_scaling_factor,
     e_score_correction_bias,
+    renormalize: bool = False,
 ) -> paddle.Tensor:
     """
     compute moe scores using e_score_correction_bias.
@@ -79,6 +80,7 @@ def get_moe_scores(
         n_group if n_group > 0 else 1,
         topk_group if topk_group > 0 else 1,
         top_k,
+        renormalize,
         routed_scaling_factor,
     )
     return scores, topk_values, topk_idx
@@ -93,6 +95,7 @@ class FusedMoE(nn.Layer):
         self,
         fd_config,
         reduce_results: bool = True,
+        renormalize: bool = False,
         moe_intermediate_size: int = -1,
         num_experts: int = -1,
         expert_id_offset: int = 0,
@@ -119,6 +122,7 @@ class FusedMoE(nn.Layer):
         self.fd_config = fd_config
         self.layer_idx = layer_idx
         self.reduce_results = reduce_results
+        self.renormalize = renormalize
         self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
         self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.ep_size = fd_config.parallel_config.expert_parallel_size
@@ -121,6 +121,7 @@ class DeepSeekV3MoE(nn.Layer):
         super().__init__()
 
         self.tp_size = fd_config.parallel_config.tensor_parallel_size
+        self.norm_topk_prob = fd_config.model_config.norm_topk_prob
 
         weight_key_map = {
             "gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
@@ -150,6 +151,7 @@ class DeepSeekV3MoE(nn.Layer):
         self.experts = FusedMoE(
             fd_config=fd_config,
             reduce_results=False,
+            renormalize=self.norm_topk_prob,
             moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
             num_experts=fd_config.model_config.n_routed_experts,
             top_k=fd_config.model_config.num_experts_per_tok,
@@ -110,6 +110,8 @@ class Glm4Moe(nn.Layer):
         self.n_routed_experts: int = fd_config.model_config.n_routed_experts
         self.n_shared_experts: int = fd_config.model_config.n_shared_experts
 
+        self.norm_topk_prob = fd_config.model_config.norm_topk_prob
+
         weight_key_map = {
             "gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
             "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
@@ -134,6 +136,7 @@ class Glm4Moe(nn.Layer):
         self.experts = FusedMoE(
             fd_config,
             reduce_results=False,
+            renormalize=self.norm_topk_prob,
             moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
             num_experts=fd_config.model_config.n_routed_experts,
             top_k=fd_config.model_config.num_experts_per_tok,
@@ -2,74 +2,103 @@ import unittest
 
 import paddle
 
-from fastdeploy.model_executor.ops.gpu import noaux_tc
+from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 
 
 class TestMoeRouting(unittest.TestCase):
     def setUp(self):
-        self.num_tokens = 10
-        self.num_experts = 64
-        self.gating_output = paddle.rand([self.num_tokens, self.num_experts])
-        self.e_score_correction_bias = paddle.rand([self.num_experts])
-        self.n_group = 8
-        self.topk_group = 4
-        self.top_k = 8
-        self.routed_scaling_factor = 1.5
+        paddle.seed(2024)
+        print(paddle.device.cuda.get_device_properties())
+        print(paddle.__git_commit__)
 
-    def node_limit_routing(self, gate_probs):
-        """Group all experts; only select experts within topk_group groups."""
-        assert len(gate_probs.shape) == 2
-        seq_length, n_experts = gate_probs.shape
-
-        group_scores = gate_probs.reshape([seq_length, 8, -1]).topk(2, axis=-1)[0].sum(axis=-1)
-        group_idx = paddle.topk(group_scores, k=4, axis=-1, sorted=True)[1]
-        group_mask = paddle.zeros_like(group_scores).put_along_axis(
-            group_idx, paddle.ones([], dtype="float32"), axis=-1
-        )
-        score_mask = group_mask.unsqueeze(-1).expand([seq_length, 8, n_experts // 8]).reshape([seq_length, -1])
-        gate_probs = gate_probs.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))
-        return gate_probs
-
-    def ref_moe_routing(self):
-        scores = paddle.nn.functional.sigmoid(self.gating_output)
-        prob_for_choice = scores + self.e_score_correction_bias.unsqueeze(0)
-        prob_for_choice = self.node_limit_routing(prob_for_choice)
-        top_logits, topk_idx_ref = paddle.topk(prob_for_choice, self.top_k, axis=1)
-
-        token_num, top_k = topk_idx_ref.shape
-        _, num_expert = prob_for_choice.shape
-        topk_idx_expanded = paddle.unsqueeze(topk_idx_ref, axis=-1)
-        indices = paddle.concat(
-            [
-                paddle.arange(token_num, dtype="int64").unsqueeze(1).tile([1, top_k]).unsqueeze(-1),
-                topk_idx_expanded,
-            ],
-            axis=-1,
-        )
-        selected_gate_probs = paddle.gather_nd(scores, indices)
-
-        selected_gate_probs_sum = paddle.sum(selected_gate_probs, axis=1, keepdim=True)
-        topk_weights_ref = selected_gate_probs / selected_gate_probs_sum
-        topk_weights_ref = topk_weights_ref * self.routed_scaling_factor
-        return topk_weights_ref, topk_idx_ref
-
-    def test_moe_select(self):
-        scores = paddle.nn.functional.sigmoid(self.gating_output)
-        scores_with_bias = scores + self.e_score_correction_bias.unsqueeze(0)
-
-        scores, topk_values, topk_idx = noaux_tc(
-            scores,
-            scores_with_bias,
-            self.n_group,
-            self.topk_group,
-            self.top_k,
-            self.routed_scaling_factor,
-        )
-
-        ref_topk_values, ref_topk_idx = self.ref_moe_routing()
-
-        paddle.allclose(topk_values, ref_topk_values)
-        paddle.allclose(topk_idx.cast(int), ref_topk_idx.cast(int))
+    def native_group_topk(
+        self,
+        gating_output: paddle.Tensor,
+        topk: int,
+        renormalize: bool,
+        num_expert_group: int,
+        topk_group: int,
+        routed_scaling_factor: float,
+        e_score_correction_bias: paddle.Tensor,
+    ):
+        original_scores = paddle.nn.functional.sigmoid(gating_output)
+        if len(e_score_correction_bias.shape) == 1:
+            e_score_correction_bias = e_score_correction_bias.unsqueeze(0)
+        scores = original_scores + e_score_correction_bias
+
+        num_token, n_experts = scores.shape
+        group_scores = scores.reshape([num_token, num_expert_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
+        group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1]  # [n, top_k_group]
+        group_mask = paddle.zeros_like(group_scores)  # [n, n_group]
+        group_mask.put_along_axis_(group_idx, 1.0, axis=-1)  # [n, n_group]
+        score_mask = (
+            group_mask.unsqueeze(-1)
+            .expand([num_token, num_expert_group, n_experts // num_expert_group])
+            .reshape([num_token, -1])
+        )
+        tmp_scores = scores.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))
+
+        topk_ids = paddle.topk(tmp_scores, topk, axis=1)[1]
+        topk_weights = paddle.take_along_axis(original_scores, topk_ids, axis=1)
+
+        if renormalize:
+            topk_weights = topk_weights / paddle.sum(topk_weights, axis=1, keepdim=True)
+
+        if routed_scaling_factor != 1.0:
+            topk_weights = topk_weights * routed_scaling_factor
+
+        return topk_weights, topk_ids
+
+    def test_group_topk(self):
+        renormalize = True
+
+        test_cases = [
+            # (num_experts, n_group, topk_group, top_k, routed_scaling_factor)
+            (128, 1, 1, 8, 1.0),  # glm45-air
+            (256, 8, 4, 8, 2.5),  # deepseek
+        ]
+
+        for case_tuple in test_cases:
+            num_experts, n_group, topk_group, top_k, routed_scaling_factor = case_tuple
+            for num_tokens in [1, 32, 64, 128]:
+                gating_output = paddle.rand([num_tokens, num_experts])
+                e_score_correction_bias = paddle.rand([1, num_experts])
+
+                ref_topk_values, ref_topk_idx = self.native_group_topk(
+                    gating_output=gating_output,
+                    topk=top_k,
+                    renormalize=renormalize,
+                    num_expert_group=n_group,
+                    topk_group=topk_group,
+                    routed_scaling_factor=routed_scaling_factor,
+                    e_score_correction_bias=e_score_correction_bias,
+                )
+
+                new_score, topk_values, topk_idx = get_moe_scores(
+                    gating_output=gating_output,
+                    n_group=n_group,
+                    topk_group=topk_group,
+                    top_k=top_k,
+                    routed_scaling_factor=routed_scaling_factor,
+                    e_score_correction_bias=e_score_correction_bias,
+                    renormalize=renormalize,
+                )
+
+                equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item()
+                equal_topk_ids = paddle.allclose(
+                    topk_idx.cast("int32"), ref_topk_idx.cast("int32"), atol=0.0, rtol=0.0
+                ).item()
+                print(
+                    f"Test Case[{case_tuple}], num_tokens = {num_tokens}, equal_topk_value: {equal_topk_value}, equal_topk_ids: {equal_topk_ids}"
+                )
+                if not equal_topk_value:
+                    print(f"ref_topk_values = {ref_topk_values}")
+                    print(f"topk_values = {topk_values}")
+                if not equal_topk_ids:
+                    print(f"ref_topk_idx = {ref_topk_idx}")
+                    print(f"topk_idx = {topk_idx}")
+                assert equal_topk_value and equal_topk_ids
 
 
 if __name__ == "__main__":