Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 00:06:38 +08:00)

Compare commits: 10 commits (fix-gpu-me ... release/2.)

Commits (SHA1):
- e42dc8c694
- 63a03ee152
- 9cc2c99539
- 31e32b5821
- aebe12a58d
- 8fdb950e9f
- a460462d2a
- cb8d87b945
- de4feff147
- f38b174a75
@@ -564,6 +564,7 @@ std::vector<paddle::Tensor> NoauxTc(
    int n_group,
    int topk_group,
    int topk,
    bool renormalize,
    float routed_scaling_factor);

#ifdef ENABLE_FP8

@@ -615,6 +616,8 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle);

void free_shared_buffer(int64_t buffer);

void clear_ipc_handles(int64_t _fa);

// speculative decoding Kernel
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
    const paddle::Tensor& input_ids,

@@ -1203,6 +1206,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

  m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");

  m.def("clear_ipc_handles", &clear_ipc_handles, "clear_ipc_handles");

  m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");

  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");

@@ -122,10 +122,14 @@ void register_graph_buffers(fptr_t _fa,
  for (int i = 0; i < handles.size(); i++) {
    bytes.emplace_back(handles[i].begin(), handles[i].end());
  }
  bytes.reserve(handles.size());
  fa->register_graph_buffers(bytes, offsets);
}

void clear_ipc_handles(fptr_t _fa) {
  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
  fa->clear_ipc_handles();
}

std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
    int64_t size) {

@@ -517,10 +517,15 @@ class CustomAllreduce {
#undef KL
  }

  ~CustomAllreduce() {
  void clear_ipc_handles(){
    for (auto [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
    ipc_handles_.clear();
  }

  ~CustomAllreduce() {
    clear_ipc_handles();
  }
};
}  // namespace paddle

@@ -151,6 +151,34 @@ inline int GetGPUComputeCapability(int id) {

#endif

#ifndef FP8_E4M3_MAX
#define FP8_E4M3_MAX 448.0
#endif

#ifndef DISPATCH_FLOAT_FP6_DTYPE
#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...) \
  switch (pd_dtype) { \
    case phi::DataType::FLOAT32: { \
      using c_type = float; \
      __VA_ARGS__ \
      break; \
    } \
    case phi::DataType::BFLOAT16: { \
      using c_type = phi::dtype::bfloat16; \
      __VA_ARGS__ \
      break; \
    } \
    case phi::DataType::FLOAT16: { \
      using c_type = phi::dtype::float16; \
      __VA_ARGS__ \
      break; \
    } \
    default: { \
      PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); \
    } \
  }
#endif

inline constexpr uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1)
    return num;

@@ -563,3 +591,28 @@ inline int GetSMVersion() {
  return sm_version;
}

__device__ __forceinline__ float warpReduceMax(float value) {
  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16));
  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8));
  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 4));
  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 2));
  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 1));
  return value;
}

__device__ __forceinline__ float blockReduceMax(float value) {
  static __shared__ float warpLevelMaxs[WARP_SIZE];
  const int laneId = threadIdx.x % WARP_SIZE;
  const int warpId = threadIdx.x / WARP_SIZE;

  value = warpReduceMax(value);

  if (laneId == 0) warpLevelMaxs[warpId] = value;
  __syncthreads();

  value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
  if (warpId == 0) value = warpReduceMax(value);

  return value;
}
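Not part of the diff: a minimal usage sketch of the two helpers added above (DISPATCH_FLOAT_FP6_DTYPE and blockReduceMax). The kernel, tensor names, and launch shape here are illustrative assumptions, not code from the commits.

// --- illustrative sketch only (not in the diff) ---
template <typename T>
__global__ void row_abs_max_kernel(const T* __restrict__ x,
                                   float* __restrict__ row_max,
                                   const int64_t row_len) {
  // One block per row: strided scan, then a block-wide max reduction.
  float v = 0.f;
  const T* row = x + blockIdx.x * row_len;
  for (int64_t i = threadIdx.x; i < row_len; i += blockDim.x) {
    v = fmaxf(v, fabsf(static_cast<float>(row[i])));
  }
  v = blockReduceMax(v);  // helper introduced in this diff
  if (threadIdx.x == 0) row_max[blockIdx.x] = v;
}

void LaunchRowAbsMax(const paddle::Tensor& x, paddle::Tensor& row_max,
                     cudaStream_t stream) {
  const auto rank = x.dims().size();
  const int64_t row_len = x.dims()[rank - 1];
  const int64_t rows = x.numel() / row_len;
  // DISPATCH_FLOAT_FP6_DTYPE expands the body once per supported dtype
  // (fp32 / fp16 / bf16) with `scalar_t` bound to the concrete C++ type.
  DISPATCH_FLOAT_FP6_DTYPE(x.dtype(), scalar_t, {
    row_abs_max_kernel<scalar_t><<<rows, 256, 0, stream>>>(
        x.data<scalar_t>(), row_max.data<float>(), row_len);
  });
}
// --- end sketch ---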
@@ -26,6 +26,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
    int n_group,
    int topk_group,
    int topk,
    bool renormalize,
    float routed_scaling_factor) {
  auto input_shape = scores_with_bias.shape();
  PD_CHECK(input_shape.size() == 2);

@@ -48,6 +49,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
      n_group,
      topk_group,
      topk,
      renormalize,
      routed_scaling_factor,
      stream);

@@ -76,6 +78,7 @@ PD_BUILD_STATIC_OP(noaux_tc)
    .Attrs({"n_group: int",
            "topk_group: int",
            "topk:int",
            "renormalize: bool",
            "routed_scaling_factor: float"})
    .SetKernelFn(PD_KERNEL(NoauxTc))
    .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
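Not part of the diff: the new `renormalize: bool` attribute toggles the final expert-weight computation inside group_idx_and_topk_idx_kernel further down, condensed here as comments (names as in that kernel):

// Per selected expert i:
//   renormalize == true:  value_i = s_topk_value[i] / topk_sum * routed_scaling_factor,
//                         where topk_sum = 1e-20 + sum of the selected top-k scores
//   renormalize == false: value_i = s_topk_value[i] * routed_scaling_factor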
@@ -25,6 +25,23 @@ constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t BLOCK_SIZE = 512;
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) {
  return val;
}

template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
  return __bfloat162float(val);
}

template <typename T>
__device__ inline T neg_inf() {
  // cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
  // so we need to cast from fp32
  return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
}

namespace warp_topk {

template <int size, typename T>

@@ -41,10 +58,21 @@ constexpr __host__ __device__ bool isPowerOf2(T v) {
}

template <bool greater, typename T>
__device__ bool is_better_than(T val, T baseline) {
__forceinline__ __device__ bool is_better_than(T val, T baseline) {
  return (val > baseline && greater) || (val < baseline && !greater);
}

template <bool greater, typename T, typename idxT>
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
                                               idxT baseline_index) {
  bool res = (val > baseline && greater) || (val < baseline && !greater);
  if (val == baseline) {
    res = (index < baseline_index && greater) ||
          (index < baseline_index && !greater);
  }
  return res;
}

template <typename T, typename idxT>
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
  int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;

@@ -53,7 +81,8 @@ int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
      round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
}

template <int size, bool ascending, typename T, typename idxT>
template <int size, bool ascending, bool reverse, typename T, typename idxT,
          bool is_stable>
struct BitonicMerge {
  // input should be a bitonic sequence, and sort it to be a monotonic sequence
  __device__ static void merge(T* __restrict__ val_arr,

@@ -67,7 +96,15 @@ struct BitonicMerge {
      int const other_i = i + stride;
      T& val = val_arr[i];
      T& other_val = val_arr[other_i];
      if ((val > other_val && ascending) || (val < other_val && !ascending)) {
      bool is_better;
      if constexpr (is_stable) {
        is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
                                              idx_arr[other_i]);
      } else {
        is_better = is_better_than<ascending>(val, other_val);
      }

      if (is_better) {
        T tmp = val;
        val = other_val;
        other_val = tmp;

@@ -78,13 +115,14 @@ struct BitonicMerge {
      }
    }

    BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr, idx_arr);
    BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr + arr_len / 2,
                                                      idx_arr + arr_len / 2);
    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
        val_arr, idx_arr);
    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
        val_arr + arr_len / 2, idx_arr + arr_len / 2);
  }
};

template <int size, bool ascending, typename T, typename idxT>
template <int size, bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort {
  __device__ static void sort(T* __restrict__ val_arr,
                              idxT* __restrict__ idx_arr) {

@@ -92,15 +130,16 @@ struct BitonicSort {
    static_assert(size >= 2 * WARP_SIZE);
    constexpr int arr_len = size / WARP_SIZE;

    BitonicSort<size / 2, true, T, idxT>::sort(val_arr, idx_arr);
    BitonicSort<size / 2, false, T, idxT>::sort(val_arr + arr_len / 2,
                                                idx_arr + arr_len / 2);
    BitonicMerge<size, ascending, T, idxT>::merge(val_arr, idx_arr);
    BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
    BitonicSort<size / 2, false, T, idxT, is_stable>::sort(
        val_arr + arr_len / 2, idx_arr + arr_len / 2);
    BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
        val_arr, idx_arr);
  }
};

template <bool ascending, typename T, typename idxT>
struct BitonicSort<32, ascending, T, idxT> {
template <bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort<32, ascending, T, idxT, is_stable> {
  __device__ static void sort(T* __restrict__ val_arr,
                              idxT* __restrict__ idx_arr) {
    int const lane = threadIdx.x % WARP_SIZE;

@@ -114,19 +153,37 @@ struct BitonicSort<32, ascending, T, idxT> {

      T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
      idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
      if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) {

      bool is_better;
      if constexpr (is_stable) {
        if constexpr (ascending) {
          is_better = ((*val_arr > other) ||
                       ((*val_arr == other) && (*idx_arr < other_idx))) !=
                      (reverse != is_second);
        } else {
          is_better = ((*val_arr > other) ||
                       ((*val_arr == other) && (*idx_arr > other_idx))) !=
                      (reverse != is_second);
        }
      } else {
        is_better = (*val_arr != other &&
                     (*val_arr > other) != (reverse != is_second));
      }
      if (is_better) {
        *val_arr = other;
        *idx_arr = other_idx;
      }
    }
  }

  BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr);
  BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr,
                                                                    idx_arr);
  }
};

template <bool ascending, typename T, typename idxT>
struct BitonicMerge<32, ascending, T, idxT> {
template <bool ascending, bool reverse, typename T, typename idxT,
          bool is_stable>
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
  __device__ static void merge(T* __restrict__ val_arr,
                               idxT* __restrict__ idx_arr) {
    int const lane = threadIdx.x % WARP_SIZE;

@@ -136,7 +193,24 @@ struct BitonicMerge<32, ascending, T, idxT> {
    T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
    idxT& idx = *idx_arr;
    idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
    if (val != other && ((val > other) == (ascending != is_second))) {

    bool is_better;
    if constexpr (is_stable) {
      if constexpr (ascending) {
        is_better = ((*val_arr > other) ||
                     ((*val_arr == other) && (*idx_arr < other_idx))) ==
                    (reverse != is_second);  // for min
      } else {
        is_better = ((*val_arr > other) ||
                     ((*val_arr == other) && (*idx_arr > other_idx))) ==
                    (reverse != is_second);  // for max
      }
    } else {
      is_better =
          (val != other && ((val > other) == (ascending != is_second)));
    }

    if (is_better) {
      val = other;
      idx = other_idx;
    }

@@ -144,34 +218,42 @@ struct BitonicMerge<32, ascending, T, idxT> {
  }
};

template <int capacity, bool greater, typename T, typename idxT>
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSort {
public:
 public:
  __device__ WarpSort(idxT k, T dummy)
      : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
    static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));

    for (int i = 0; i < max_arr_len_; ++i) {
      val_arr_[i] = dummy_;
      idx_arr_[i] = 0;
    }
  }

  // load and merge k sorted values
  __device__ void load_sorted(T const* __restrict__ in,
                              idxT const* __restrict__ in_idx,
                              idxT start) {
                              idxT const* __restrict__ in_idx, idxT start) {
    idxT idx = start + WARP_SIZE - 1 - lane_;
    for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
      if (idx < start + k_) {
        T t = in[idx];
        if (is_better_than<greater>(t, val_arr_[i])) {
        bool is_better;
        if constexpr (is_stable) {
          is_better =
              is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
        } else {
          is_better = is_better_than<greater>(t, val_arr_[i]);
        }
        if (is_better) {
          val_arr_[i] = t;
          idx_arr_[i] = in_idx[idx];
        }
      }
    }

    BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
    BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
        val_arr_, idx_arr_);
  }

  __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {

@@ -193,7 +275,7 @@ public:
    }
  }

protected:
 protected:
  static constexpr int max_arr_len_ = capacity / WARP_SIZE;

  T val_arr_[max_arr_len_];

@@ -205,11 +287,11 @@ protected:

};  // end class WarpSort

template <int capacity, bool greater, typename T, typename idxT>
class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
public:
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
 public:
  __device__ WarpSelect(idxT k, T dummy)
      : WarpSort<capacity, greater, T, idxT>(k, dummy),
      : WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
        k_th_(dummy),
        k_th_lane_((k - 1) % WARP_SIZE) {
    extern __shared__ char smem_buf[];  // extern __shared__ T smem_buf[];

@@ -234,7 +316,13 @@ public:
  }

  __device__ void add(T val, idxT idx) {
    bool do_add = is_better_than<greater>(val, k_th_);
    bool do_add;
    if constexpr (is_stable) {
      do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
    } else {
      do_add = is_better_than<greater>(val, k_th_);
    }

    uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
    if (mask == 0) {
      return;

@@ -271,37 +359,52 @@ public:
    __syncthreads();
  }

private:
 private:
  __device__ void set_k_th_() {
    k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
    if constexpr (is_stable) {
      k_th_idx_ =
          __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
    }
  }

  __device__ void merge_buf_(T val, idxT idx) {
    BitonicSort<WARP_SIZE, greater, T, idxT>::sort(&val, &idx);
    BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);

    T& old = val_arr_[max_arr_len_ - 1];
    if (is_better_than<greater>(val, old)) {

    bool is_better;
    if constexpr (is_stable) {
      is_better =
          is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
    } else {
      is_better = is_better_than<greater>(val, old);
    }

    if (is_better) {
      old = val;
      idx_arr_[max_arr_len_ - 1] = idx;
    }

    BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
    BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
        val_arr_, idx_arr_);

    set_k_th_();
  }

  using WarpSort<capacity, greater, T, idxT>::max_arr_len_;
  using WarpSort<capacity, greater, T, idxT>::val_arr_;
  using WarpSort<capacity, greater, T, idxT>::idx_arr_;
  using WarpSort<capacity, greater, T, idxT>::lane_;
  using WarpSort<capacity, greater, T, idxT>::k_;
  using WarpSort<capacity, greater, T, idxT>::dummy_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;

  T* val_smem_;
  idxT* idx_smem_;
  int smem_buf_len_ = 0;

  T k_th_;
  idxT k_th_idx_;
  int const k_th_lane_;
};  // end class WarpSelect
}  // namespace warp_topk

@@ -313,8 +416,8 @@ __device__ void topk_with_k2(T* output,
                             int32_t const lane_id,
                             int const num_experts_per_group) {
  // Get the top2 per thread
  T largest = cuda::std::numeric_limits<T>::min();
  T second_largest = cuda::std::numeric_limits<T>::min();
  T largest = neg_inf<T>();
  T second_largest = neg_inf<T>();

  if (num_experts_per_group > WARP_SIZE) {
    for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {

@@ -368,8 +471,14 @@ __global__ void topk_with_k2_kernel(T* output,
  cg::thread_block block = cg::this_thread_block();
  cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.wait;");
#endif
    topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

template <typename T, typename IdxT>

@@ -385,6 +494,7 @@ __global__ void group_idx_and_topk_idx_kernel(
    int64_t const topk,
    int64_t const num_experts,
    int64_t const num_experts_per_group,
    bool const renormalize,
    double routed_scaling_factor) {
  int32_t warp_id = threadIdx.x / WARP_SIZE;
  int32_t lane_id = threadIdx.x % WARP_SIZE;

@@ -403,19 +513,29 @@ __global__ void group_idx_and_topk_idx_kernel(

  extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to
                                      // store the target topk idx
  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
  T* s_topk_value =
      reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
      warp_id * topk;
  s_topk_idx += warp_id * topk;

  T value = cuda::std::numeric_limits<T>::min();
  T topk_group_value = cuda::std::numeric_limits<T>::min();
  T value = neg_inf<T>();
  T topk_group_value = neg_inf<T>();
  int32_t num_equalto_topkth_group;

  if ((n_group > topk_group) && (case_id < num_tokens)) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");  // I think all prolog can be put before
                                         // acqbulk because it's ptr arithmetic
#endif

  if (case_id < num_tokens) {
    // calculate group_idx
    int32_t target_num_min = WARP_SIZE - n_group + topk_group;
    if (lane_id < n_group) {
    if (lane_id < n_group &&
        (isfinite(cuda_cast<float, T>(
            group_scores[lane_id]))))  // The check is necessary to avoid
                                       // abnormal input
    {
      value = group_scores[lane_id];
    }

@@ -426,22 +546,23 @@ __global__ void group_idx_and_topk_idx_kernel(
      __syncwarp();  // Ensure all threads have valid data before reduction
      topk_group_value = cg::reduce(tile, value, cg::greater<T>());
      if (value == topk_group_value) {
        value = cuda::std::numeric_limits<T>::min();
        value = neg_inf<T>();
      }
      pre_count_equal_to_top_value = count_equal_to_top_value;
      count_equal_to_top_value = __popc(__ballot_sync(
          FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
          FULL_WARP_MASK, (value == neg_inf<T>())));
    }
    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
  }
  __syncthreads();

  warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t>
      queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
  warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
                        /* is_stable */ true>
      queue((int32_t)topk, neg_inf<T>());

  int count_equalto_topkth_group = 0;
  bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits<T>::min());
  if (case_id < num_tokens) {
  bool if_proceed_next_topk = (topk_group_value != neg_inf<T>());
  if (case_id < num_tokens && if_proceed_next_topk) {
    for (int i_group = 0; i_group < n_group; i_group++) {
      if ((group_scores[i_group] > topk_group_value) ||
          ((group_scores[i_group] == topk_group_value) &&

@@ -449,9 +570,11 @@ __global__ void group_idx_and_topk_idx_kernel(
        int32_t offset = i_group * num_experts_per_group;
        for (int32_t i = lane_id; i < align_num_experts_per_group;
             i += WARP_SIZE) {
          T candidates = i < num_experts_per_group
                             ? scores_with_bias[offset + i]
                             : cuda::std::numeric_limits<T>::min();
          T candidates =
              (i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
                                                 scores_with_bias[offset + i]))
                  ? scores_with_bias[offset + i]
                  : neg_inf<T>();
          queue.add(candidates, offset + i);
        }
        if (group_scores[i_group] == topk_group_value) {

@@ -469,7 +592,7 @@ __global__ void group_idx_and_topk_idx_kernel(
  // Load the valid score value
  // Calculate the summation
  float topk_sum = 1e-20;
  if (case_id < num_tokens) {
  if (case_id < num_tokens && if_proceed_next_topk) {
    for (int i = lane_id;
         i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
         i += WARP_SIZE) {

@@ -478,33 +601,45 @@ __global__ void group_idx_and_topk_idx_kernel(
      if (i < topk) {
        s_topk_value[i] = value;
      }
      topk_sum += reduce(tile, value, cg::plus<float>());
      topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
    }
  }

  __syncthreads();
  if (case_id < num_tokens) {

  if (case_id < num_tokens && if_proceed_next_topk) {
    for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
      scores[i] = 0;
    }
  }
  __threadfence();
  __syncthreads();
  __syncwarp();

  if (case_id < num_tokens) {
    for (int i = lane_id; i < topk; i += WARP_SIZE) {
      float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
      scores[s_topk_idx[i]] = value;
    if (if_proceed_next_topk) {
    if (if_proceed_next_topk) {
      for (int i = lane_id; i < topk; i += WARP_SIZE) {
        float value;
        if (renormalize) {
          value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
                  routed_scaling_factor;
        } else {
          value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
        }
        scores[s_topk_idx[i]] = value;
        topk_indices[i] = s_topk_idx[i];
        topk_values[i] = static_cast<T>(value);
        topk_values[i] = cuda_cast<T, float>(value);
      }
      else {
    } else {
      for (int i = lane_id; i < topk; i += WARP_SIZE) {
        topk_indices[i] = i;
        topk_values[i] = static_cast<float>(1.0f / topk);
        topk_values[i] = cuda_cast<T, float>(1.0f / topk);
      }
    }
    // Note: when if_proceed_next_topk==false, choose the first 8 experts as the
    // default result.
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

template <typename T, typename IdxT>

@@ -518,17 +653,24 @@ void invokeNoAuxTc(T* scores,
                   int64_t const n_group,
                   int64_t const topk_group,
                   int64_t const topk,
                   bool const renormalize,
                   double const routed_scaling_factor,
                   cudaStream_t const stream) {
  int64_t num_cases = num_tokens * n_group;
  int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
  topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
      group_scores,
      scores_with_bias,
      num_tokens,
      num_cases,
      n_group,
      num_experts / n_group);
  auto* kernel_instance1 = &topk_with_k2_kernel<T>;
  cudaLaunchConfig_t config;
  config.gridDim = topk_with_k2_num_blocks;
  config.blockDim = BLOCK_SIZE;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = false;
  config.numAttrs = 1;
  config.attrs = attrs;
  cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
                     num_tokens, num_cases, n_group, num_experts / n_group);

  int64_t topk_with_k_group_num_blocks =
      (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;

@@ -536,21 +678,19 @@ void invokeNoAuxTc(T* scores,
      warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
                                                           topk);

  group_idx_and_topk_idx_kernel<T><<<topk_with_k_group_num_blocks,
                                     BLOCK_SIZE,
                                     dynamic_smem_in_bytes,
                                     stream>>>(scores,
                                               group_scores,
                                               topk_values,
                                               topk_indices,
                                               scores_with_bias,
                                               num_tokens,
                                               n_group,
                                               topk_group,
                                               topk,
                                               num_experts,
                                               num_experts / n_group,
                                               routed_scaling_factor);
  auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
  config.gridDim = topk_with_k_group_num_blocks;
  config.blockDim = BLOCK_SIZE;
  config.dynamicSmemBytes = dynamic_smem_in_bytes;
  config.stream = stream;
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = false;
  config.numAttrs = 1;
  config.attrs = attrs;
  cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
                     topk_values, topk_indices, scores_with_bias, num_tokens,
                     n_group, topk_group, topk, num_experts,
                     num_experts / n_group, renormalize, routed_scaling_factor);
}

#define INSTANTIATE_NOAUX_TC(T, IdxT) \

@@ -564,6 +704,7 @@ void invokeNoAuxTc(T* scores,
      int64_t const n_group, \
      int64_t const topk_group, \
      int64_t const topk, \
      bool const renormalize, \
      double const routed_scaling_factor, \
      cudaStream_t const stream);
@@ -3,6 +3,158 @@

#include "quantization/common.cuh"

// adapted from: https://github.com/sgl-project/sglang/blob/v0.5.2rc2/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu

// ---------------------------------------------------------------------------
// 1. Warp-local, no shared memory
//    • One warp handles one token.
//    • Eight tokens per 256-thread CTA.
// ---------------------------------------------------------------------------
template <typename T, typename DST_DTYPE, int kTokensPerCTA = 8, int kVecSize = 16>
__global__ void per_token_quant_fp8_kernel(
    const T* __restrict__ input,
    DST_DTYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const float scale_ub,
    const int64_t hidden_size,
    const int64_t num_tokens) {
  const int warp_id = threadIdx.x / WARP_SIZE;        // 0-7 (8 warps)
  const int lane_id = threadIdx.x & (WARP_SIZE - 1);  // 0-31
  const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
  if (token_id >= num_tokens) return;

  // Global tensors for this token
  const T* token_input = input + token_id * hidden_size;
  DST_DTYPE* token_output = output_q + token_id * hidden_size;
  float* token_scale = output_s + token_id;

  //
  // Pass-1: Perform a warp reduce to find the max_value of a token's hidden_size
  //
  float max_value = 0.f;
  using vec_t = AlignedVector<T, kVecSize>;
  const int32_t num_vec_elems = hidden_size / kVecSize;

  for (int32_t i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
    vec_t input_vec;
    Load(token_input + i * kVecSize, &input_vec);

#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
    }
  }

  float warp_max = warpReduceMax(max_value);
  if (scale_ub > 0){
    warp_max = fminf(warp_max, scale_ub);
  }
  float scale;
  scale = warp_max / FP8_E4M3_MAX;
  // Broadcast scale
  if (lane_id == 0) {
    token_scale[0] = scale;
  }
  float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;

  //
  // Pass-2: quantize and write back
  //
  for (int i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
    vec_t input_vec;
    Load(token_input + i * kVecSize, &input_vec);
    DST_DTYPE output_arr[kVecSize];
#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      float val = static_cast<float>(input_vec[j]) * scale_inv;
      val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
      output_arr[j] = static_cast<DST_DTYPE>(val);
    }
    if constexpr (kVecSize == 16) {
      *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
    } else {
      // Use element-wise copy for vector size 8 to ensure correctness
      for (int k = 0; k < kVecSize; ++k) {
        token_output[i * kVecSize + k] = output_arr[k];
      }
    }
  }
}

// ---------------------------------------------------------------------------
// 2. Baseline kernel (1 token / CTA, CUB block reduce)
// ---------------------------------------------------------------------------
template <typename T, typename DST_DTYPE, int kVecSize = 16>
__global__ void per_token_quant_fp8_small_batch_kernel(
    const T* __restrict__ input,
    DST_DTYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const float scale_ub,
    const int64_t hidden_size,
    const int64_t num_tokens) {
  const int token_idx = blockIdx.x;
  if (token_idx >= num_tokens) return;

  const int tid = threadIdx.x;
  const int block_dim = blockDim.x;

  const T* token_input = input + token_idx * hidden_size;
  DST_DTYPE* token_output = output_q + token_idx * hidden_size;

  float max_value = 0.0f;

  // Use template parameter for vector size
  using vec_t = AlignedVector<T, kVecSize>;
  const int32_t num_vec_elems = hidden_size / kVecSize;

  // Find max using vectorized loads
  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
    vec_t input_vec;
    Load(token_input + i * kVecSize, &input_vec);

#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      float val = static_cast<float>(input_vec[j]);
      max_value = fmaxf(max_value, fabsf(val));
    }
  }

  max_value = blockReduceMax(max_value);
  if (scale_ub > 0){
    max_value = fminf(max_value, scale_ub);
  }
  __shared__ float scale;
  if (tid == 0) {
    scale = max_value / FP8_E4M3_MAX;
    output_s[token_idx] = scale;
  }
  __syncthreads();

  const float scale_inv = 1.0f / scale;

  // Quantize using vectorized loads
  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
    vec_t input_vec;
    Load(token_input + i * kVecSize, &input_vec);

    DST_DTYPE output_arr[kVecSize];
#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
      output_arr[j] = static_cast<DST_DTYPE>(val);
    }

    if constexpr (kVecSize == 16) {
      *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
    } else {
      // Use element-wise copy for vector size 8 to ensure correctness
      for (int k = 0; k < kVecSize; ++k) {
        token_output[i * kVecSize + k] = output_arr[k];
      }
    }
  }
}

namespace fastdeploy {

template <typename scalar_t, typename fp8_type>

@@ -179,39 +331,78 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out, // [..., d]
  auto rank = input.dims().size();
  int const hidden_size = input.dims()[rank - 1];
  int const num_tokens = input.numel() / hidden_size;
  cudaStream_t stream = input.stream();

  if (hidden_size % 8 == 0){
    int device = 0;
    cudaGetDevice(&device);
    int sm_count = 0;
    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
    const int TOKENS_PER_CTA = 8;
    const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
    const bool use_vec16 = (hidden_size % 16 == 0);
    DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
      if (use_warp_kernel) {
        // -------- warp-local ---------------------------------------------------
        constexpr int THREADS = TOKENS_PER_CTA * WARP_SIZE;  // 256
        dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
        dim3 block(THREADS);

        if (use_vec16) {
          per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
              reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
              reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
              reinterpret_cast<float*>(scales.data<float>()),
              scale_ub,
              hidden_size,
              num_tokens);
        } else {
          per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
              reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
              reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
              reinterpret_cast<float*>(scales.data<float>()),
              scale_ub,
              hidden_size,
              num_tokens);
        }
      } else {
        // -------- baseline -----------------------------------------------------
        constexpr int THREADS = 256;
        dim3 grid(num_tokens);
        dim3 block(THREADS);

        if (use_vec16) {
          per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 16><<<grid, block, 0, stream>>>(
              reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
              reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
              reinterpret_cast<float*>(scales.data<float>()),
              scale_ub,
              hidden_size,
              num_tokens);
        } else {
          per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 8><<<grid, block, 0, stream>>>(
              reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
              reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
              reinterpret_cast<float*>(scales.data<float>()),
              scale_ub,
              hidden_size,
              num_tokens);
        }
      }
    });
    return;
  }

  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));

  cudaStream_t stream = input.stream();
  DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
    fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
                                     input.data<scalar_t>(), scale_ub,
                                     hidden_size);
  });

  switch (input.dtype()) {
    case paddle::DataType::FLOAT32: {
      using scalar_t = float;
      fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
          <<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
                                       input.data<scalar_t>(), scale_ub,
                                       hidden_size);
      break;
    }
    case paddle::DataType::FLOAT16: {
      using scalar_t = phi::dtype::float16;
      fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
          <<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
                                       input.data<scalar_t>(), scale_ub,
                                       hidden_size);
      break;
    }
    case paddle::DataType::BFLOAT16: {
      using scalar_t = phi::dtype::bfloat16;
      fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
          <<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
                                       input.data<scalar_t>(), scale_ub,
                                       hidden_size);
      break;
    }
    default:
      PD_THROW("Only supported attr of input type in [fp32, fp16, bf16].");
  }
}

PD_BUILD_STATIC_OP(static_scaled_fp8_quant)
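Not part of the diff: a CPU reference sketch of the per-token scaling rule both kernels above implement (scale_ub clipping and the final fp8 cast omitted; names are illustrative):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr float kFp8E4m3Max = 448.0f;  // matches FP8_E4M3_MAX above

// One token: scale = max|x| / 448, q = clamp(x / scale, -448, 448).
void per_token_quant_reference(const std::vector<float>& token,
                               std::vector<float>& q, float& scale) {
  float max_abs = 0.f;
  for (float v : token) max_abs = std::max(max_abs, std::fabs(v));
  scale = max_abs / kFp8E4m3Max;  // written to output_s in the kernels
  const float scale_inv = (scale == 0.f) ? 0.f : 1.f / scale;
  q.resize(token.size());
  for (std::size_t i = 0; i < token.size(); ++i) {
    q[i] = std::clamp(token[i] * scale_inv, -kFp8E4m3Max, kFp8E4m3Max);
  }
}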
@@ -201,12 +201,12 @@ class CacheTransferManager:
    def _init_gpu_cache(self, args):

        if not args.create_cache_tensor:
            logger.info("Waiting for runners to create kv cache.")
            logger.info(f"[rank {self.rank}/{self.n_ranks}] Waiting for runners to create kv cache.")
            while self.cache_ready_signal.value[self.rank] != 1:
                time.sleep(1)
            logger.info("OK! Stop waiting.")
                time.sleep(0.1)
            logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")

        logger.info("Initializing kv cache for all layers.")
        logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
        paddle.set_device(f"gpu:{self.device}")
        for i in range(args.num_layers + self.num_extra_layers):
            num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks

@@ -215,13 +215,13 @@ class CacheTransferManager:
            val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"

            if args.create_cache_tensor:
                logger.info(f"..creating kv cache for layer {i}: {cache_shape}")
                logger.info(f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {cache_shape}")
                key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
                val_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
                set_data_ipc(key_cache, key_name)
                set_data_ipc(val_cache, val_name)
            else:
                logger.info(f"..attaching kv cache for layer {i}: {cache_shape}")
                logger.info(f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {cache_shape}")
                key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
                val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
                key_cache = share_external_data(key_cache, key_name, cache_shape)

@@ -233,20 +233,22 @@ class CacheTransferManager:
            self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])

        if args.create_cache_tensor:
            logger.info("✅ kv cache is ready!")
            logger.info("[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!")
            self.cache_ready_signal.value[self.rank] = 1

        cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
        logger.info(f"device :{self.device}")
        logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
        logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
        logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}")
        logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}")
        logger.info(
            f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
        )

    def _init_cpu_cache(self, args):
        if args.num_cpu_blocks == 0:
            logger.info("💡 no swap space (cpu cache) is specified.")
            logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.")
            self.swap_space_ready_signal.value[self.rank] = 1
            return
        logger.info("Initializing swap space (cpu cache) for all layers.")
        logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing swap space (cpu cache) for all layers.")
        paddle.set_device("cpu")
        self.k_dst_ptrs = []
        self.v_dst_ptrs = []

@@ -254,12 +256,14 @@ class CacheTransferManager:
            key_name = f"key_caches_{i}_rank{self.rank}"
            val_name = f"value_caches_{i}_rank{self.rank}"
            need_to_allocate_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block
            logger.info(f"..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB")
            logger.info(
                f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB"
            )
            self.cpu_cache_kvs[key_name] = cuda_host_alloc(need_to_allocate_bytes)
            self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
            self.cpu_cache_kvs[val_name] = cuda_host_alloc(need_to_allocate_bytes)
            self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
        logger.info("✅ swap space (cpu cache) is ready!")
        logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
        self.swap_space_ready_signal.value[self.rank] = 1

    def _do_swap_to_cpu_task(

@@ -473,6 +477,10 @@ class CacheTransferManager:
        while True:
            if kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
                try:
                    logger.info(
                        f"[rank {self.rank}/{self.n_ranks}] Start clearing caches {self.cache_ready_signal.value}"
                    )
                    # clear cpu caches
                    if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
                        paddle.set_device("cpu")
                        for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:

@@ -486,37 +494,58 @@ class CacheTransferManager:
                        while np.sum(self.swap_space_ready_signal.value) != 0:
                            time.sleep(0.1)

                    # clear gpu caches
                    paddle.set_device(f"gpu:{self.device}")
                    for name, tensor in self.gpu_cache_kvs.items():
                        unset_data_ipc(tensor, name, True, False)
                    self.gpu_cache_kvs.clear()
                    self.gpu_cache_k_tensors.clear()
                    self.gpu_cache_v_tensors.clear()

                    # reset cache_ready_signal
                    self.cache_ready_signal.value[self.rank] = 0
                    if np.sum(self.cache_ready_signal.value) == 0:
                        logger.info(
                            f"[rank {self.rank}/{self.n_ranks}] Finish clearing caches {self.cache_ready_signal.value}"
                        )

                    # wait for all ranks caches to be cleared
                    if np.sum(self.cache_ready_signal.value) != 0:
                        time.sleep(0.1)

                    # reset kv_cache_status_signal
                    kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
                    logger.info("All ranks finish clearing caches")

                except Exception as e:
                    logger.error(f"Failed to clear caches: {e}")
                    logger.error(f"[rank {self.rank}/{self.n_ranks}] Failed to clear caches: {e}")

            elif kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
                try:
                    logger.info(
                        f"[rank {self.rank}/{self.n_ranks}] Start restoring caches {self.cache_ready_signal.value}"
                    )
                    # restore cpu cache
                    if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
                        self._init_cpu_cache(args)
                        while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
                            time.sleep(0.1)

                    # restore gpu cache and set cache_ready_signal
                    self._init_gpu_cache(args)
                    logger.info(
                        f"[rank {self.rank}/{self.n_ranks}] Finish restoring caches {self.cache_ready_signal.value}"
                    )

                    # wait for all ranks caches to be ready
                    while np.sum(self.cache_ready_signal.value) != args.mp_num:
                        time.sleep(0.1)

                    # set kv_cache_status_signal
                    logger.info("All ranks finish restoring caches")
                    kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL

                except Exception as e:
                    logger.error(f"Failed to restore caches: {e}")
                    logger.error(f"[rank {self.rank}/{self.n_ranks}] Failed to restore caches: {e}")

            time.sleep(0.1)
@@ -42,6 +42,12 @@ def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
        _TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)


def custom_ar_clear_ipc_handles():
    global _TP_AR
    if _TP_AR is not None:
        _TP_AR.clear_ipc_handles()


try:

    @paddle.jit.marker.unified

@@ -25,6 +25,7 @@ from paddle.distributed.communication.group import Group
from fastdeploy.distributed.custom_all_reduce import cuda_wrapper
from fastdeploy.model_executor.ops.gpu import (
    all_reduce,
    clear_ipc_handles,
    dispose,
    get_graph_buffer_ipc_meta,
    init_custom_all_reduce,

@@ -220,6 +221,9 @@ class CustomAllreduce:
        else:
            return self.all_reduce(input, input, registered=False)

    def clear_ipc_handles(self):
        clear_ipc_handles(self._ptr)

    def close(self):
        if self._ptr:
            dispose(self._ptr)
@@ -801,6 +801,19 @@ class EngineSevice:
    def check_and_free_block_tables(self):
        self.resource_manager.check_and_free_block_tables()

    def clear_data(self):
        try:
            llm_logger.info("Clear Data: Start")
            self.token_processor.clear_data()
            self.engine_worker_queue.clear_data()
            self.send_response_server.req_dict.clear()
            self.recv_request_server.req_dict.clear()
            llm_logger.info("Clear Data: Successfully")
            return True
        except Exception as e:
            llm_logger.error(f"Clear data error: {e}")
            return False

    def _exit_sub_services(self):
        """
        exit sub services

@@ -222,7 +222,9 @@ class LLMEngine:
        if sampling_params is not None:
            request.sampling_params = sampling_params
        request.preprocess_start_time = time.time()

        chat_template_kwargs = kwargs.get("chat_template_kwargs") or {}
        chat_template_kwargs["chat_template"] = kwargs.get("chat_template")
        kwargs["chat_template_kwargs"] = chat_template_kwargs
        request = self.data_processor.process_request(request, self.cfg.max_model_len, **kwargs)
        request.prompt_token_ids_len = len(request.prompt_token_ids)
        request.need_prefill_tokens = request.prompt_token_ids_len

@@ -234,9 +236,6 @@ class LLMEngine:
                request.get("max_tokens"),
            ),
        )
        if request.get("reasoning_max_tokens") is None:
            default_reasoning_max_tokens = max(int(request.get("max_tokens") * 0.8), 1)
            request.set("reasoning_max_tokens", default_reasoning_max_tokens)
        min_tokens = request.get("min_tokens")
        if input_ids_len + min_tokens >= self.cfg.max_model_len:
            error_msg = (
@@ -159,8 +159,6 @@ class SamplingParams:
    def __post_init__(self):
        if self.seed is None:
            self.seed = random.randint(0, 922337203685477580)
        if self.max_tokens is not None and self.reasoning_max_tokens is None:
            self.reasoning_max_tokens = max(int(self.max_tokens * 0.8), 1)
        self._verify_args()

    def _verify_args(self) -> None:

@@ -512,6 +512,10 @@ class ResourceManagerV1(ResourceManager):
    def finish_requests_async(self, request_ids: Union[str, Iterable[str]]):
        return self.finish_execution_pool.submit(self.finish_requests, request_ids)

    def clear_data(self):
        self.waiting: deque[Request] = deque()
        self.to_be_rescheduled_request_id_set = set()

    def finish_requests(self, request_ids: Union[str, Iterable[str]]):
        llm_logger.info(f"recycle resources for requests: {request_ids}")
        try:

@@ -141,6 +141,9 @@ class EngineClient:
        self.zmq_client = ZmqIpcClient(model, mode)
        self.zmq_client.connect()

    def check_model_weight_status(self):
        return self.model_weights_status_signal.value[0] < 0

    async def format_and_add_data(self, prompts: dict):
        """
        Format the request data and send the request to the server.

@@ -169,6 +172,9 @@ class EngineClient:

        task["preprocess_start_time"] = time.time()
        try:
            chat_template_kwargs = task.get("chat_template_kwargs", {})
            chat_template_kwargs.update({"chat_template": task.get("chat_template"), "tools": task.get("tools")})
            task["chat_template_kwargs"] = chat_template_kwargs
            if inspect.iscoroutinefunction(self.data_processor.process_request_dict):
                await self.data_processor.process_request_dict(task, self.max_model_len)
            else:
@@ -480,6 +480,7 @@ def reset_scheduler():

    if llm_engine is None:
        return Response("Engine not loaded", status_code=500)
    llm_engine.engine.clear_data()
    llm_engine.engine.scheduler.reset()
    return Response("Scheduler Reset Successfully", status_code=200)

@@ -498,6 +499,7 @@ def control_scheduler(request: ControlSchedulerRequest):
        return JSONResponse(content=content.model_dump(), status_code=500)

    if request.reset:
        llm_engine.engine.clear_data()
        llm_engine.engine.scheduler.reset()

    if request.load_shards_num or request.reallocate_shard:

@@ -210,6 +210,8 @@ class OpenAIServingChat:
                decoder_base_url=self.tokenizer_base_url,
            )
            while num_choices > 0:
                if self.engine_client.check_model_weight_status():
                    raise ValueError("Engine is clearing model weight")
                try:
                    response = await asyncio.wait_for(response_queue.get(), timeout=10)
                    current_waiting_time = 0

@@ -425,6 +427,8 @@ class OpenAIServingChat:
                decoder_base_url=self.tokenizer_base_url,
            )
            while True:
                if self.engine_client.check_model_weight_status():
                    raise ValueError("Engine is clearing model weight")
                try:
                    response = await asyncio.wait_for(response_queue.get(), timeout=10)
                    current_waiting_time = 0

@@ -216,6 +216,8 @@ class OpenAIServingCompletion:
        completion_batched_token_ids = [[] for _ in range(num_choices)]
        current_waiting_time = 0
        while num_choices > 0:
            if self.engine_client.check_model_weight_status():
                raise ValueError("Engine is clearing model weight")
            try:
                response = await asyncio.wait_for(response_queue.get(), timeout=10)
                current_waiting_time = 0

@@ -333,6 +335,8 @@ class OpenAIServingCompletion:
            )
            current_waiting_time = 0
            while num_choices > 0:
                if self.engine_client.check_model_weight_status():
                    raise ValueError("Engine is clearing model weight")
                try:
                    response = await asyncio.wait_for(response_queue.get(), timeout=10)
                    current_waiting_time = 0
@@ -88,7 +88,6 @@ class Ernie4_5Processor(BaseDataProcessor):
            str: error message
        """
        data_processor_logger.info(f"Start processing request: {request}")
        request.chat_template = kwargs.get("chat_template")
        request = self._apply_default_parameters(request)
        if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
            request.eos_token_ids = self.eos_token_ids

@@ -127,7 +126,7 @@ class Ernie4_5Processor(BaseDataProcessor):
            )
        elif request.messages is not None:
            task = request.to_dict()
            chat_template_kwargs = kwargs.get("chat_template_kwargs")
            chat_template_kwargs = kwargs.get("chat_template_kwargs", {})
            if chat_template_kwargs:
                if isinstance(chat_template_kwargs, dict):
                    for k, v in chat_template_kwargs.items():

@@ -135,7 +134,7 @@ class Ernie4_5Processor(BaseDataProcessor):
                            task[k] = v
                else:
                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
            request.prompt_token_ids = self.messages2ids(task)
            request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs)
        else:
            raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")

@@ -205,7 +204,7 @@ class Ernie4_5Processor(BaseDataProcessor):
            req_id = request.get("request_id", None)
            data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
        elif request.get("messages"):
            chat_template_kwargs = request.get("chat_template_kwargs")
            chat_template_kwargs = request.get("chat_template_kwargs", {})
            if chat_template_kwargs:
                if isinstance(chat_template_kwargs, dict):
                    for k, v in chat_template_kwargs.items():

@@ -213,7 +212,7 @@ class Ernie4_5Processor(BaseDataProcessor):
                            request[k] = v
                else:
                    raise ValueError("Invalid input: chat_template_kwargs must be a dict")
            request["prompt_token_ids"] = self.messages2ids(request)
            request["prompt_token_ids"] = self.messages2ids(request, **chat_template_kwargs)
        else:
            raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")

@@ -379,7 +378,7 @@ class Ernie4_5Processor(BaseDataProcessor):
            del self.tool_parser_dict[req_id]
        return response_dict

    def messages2ids(self, request_or_messages):
    def messages2ids(self, request_or_messages, **kwargs):
        """
        Convert multi-turn messages into ID sequences.

@@ -397,7 +396,7 @@ class Ernie4_5Processor(BaseDataProcessor):
            tokenize=False,
            split_special_tokens=False,
            add_special_tokens=False,
            chat_template=request_or_messages.get("chat_template", None),
            **kwargs,
        )
        request_or_messages["text_after_process"] = spliced_message
        req_id = None

@@ -113,7 +113,6 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):

    def process_request(self, request, max_model_len=None, **kwargs):
        """process the input data"""
        request.chat_template = kwargs.get("chat_template")
        task = request.to_dict()
        task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
        self.process_request_dict(task, max_model_len)

@@ -250,8 +250,8 @@ class DataProcessor:
                "video",
            ]:
                image_message_list.append(item)

        prompt_token_ids = self.apply_chat_template(request)
        chat_template_kwargs = request.get("chat_template_kwargs", {})
        prompt_token_ids = self.apply_chat_template(request, **chat_template_kwargs)
        if len(prompt_token_ids) == 0:
            raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
        image_start_index = 0

@@ -480,7 +480,7 @@ class DataProcessor:
                break
        self.tokenizer = Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path)

    def apply_chat_template(self, request):
    def apply_chat_template(self, request, **kwargs):
        """
        Convert multi-turn messages into ID sequences.

@@ -498,7 +498,7 @@ class DataProcessor:
            request,
            tokenize=False,
            add_generation_prompt=request.get("add_generation_prompt", True),
            chat_template=request.get("chat_template", None),
            **kwargs,
        )
        prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace(
            "<|video@placeholder|>", ""
@@ -185,6 +185,9 @@ class DataProcessor(BaseDataProcessor):
|
||||
from paddleformers.trl.llm_utils import get_eos_token_id
|
||||
|
||||
self.eos_token_ids = get_eos_token_id(self.tokenizer, self.generation_config)
|
||||
data_processor_logger.info(
|
||||
f"The eos_token_ids obtained by merging tokenizer and generation_config is {self.eos_token_ids}"
|
||||
)
|
||||
self.eos_token_id_len = len(self.eos_token_ids)
|
||||
self.pad_token_id = self.get_pad_id()
|
||||
self.reasoning_parser = None
|
||||
@@ -205,7 +208,6 @@ class DataProcessor(BaseDataProcessor):
|
||||
str: error message
|
||||
"""
|
||||
data_processor_logger.info(f"Start processing request: {request}")
|
||||
request.chat_template = kwargs.get("chat_template")
|
||||
request = self._apply_default_parameters(request)
|
||||
if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
|
||||
request.eos_token_ids = self.eos_token_ids
|
||||
@@ -239,7 +241,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
if self.tokenizer.chat_template is None:
|
||||
raise ValueError("This model does not support chat_template.")
|
||||
task = request.to_dict()
|
||||
chat_template_kwargs = kwargs.get("chat_template_kwargs")
|
||||
chat_template_kwargs = kwargs.get("chat_template_kwargs", {})
|
||||
if chat_template_kwargs:
|
||||
if isinstance(chat_template_kwargs, dict):
|
||||
for k, v in chat_template_kwargs.items():
|
||||
@@ -248,7 +250,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
else:
|
||||
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
|
||||
task.setdefault("enable_thinking", True)
|
||||
request.prompt_token_ids = self.messages2ids(task)
|
||||
request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs)
|
||||
else:
|
||||
raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
|
||||
|
||||
@@ -313,7 +315,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
elif request.get("messages"):
|
||||
if self.tokenizer.chat_template is None:
|
||||
raise ValueError("This model does not support chat_template.")
|
||||
chat_template_kwargs = request.get("chat_template_kwargs")
|
||||
chat_template_kwargs = request.get("chat_template_kwargs", {})
|
||||
if chat_template_kwargs:
|
||||
if isinstance(chat_template_kwargs, dict):
|
||||
for k, v in chat_template_kwargs.items():
|
||||
@@ -322,7 +324,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
else:
|
||||
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
|
||||
request.setdefault("enable_thinking", True)
|
||||
request["prompt_token_ids"] = self.messages2ids(request)
|
||||
request["prompt_token_ids"] = self.messages2ids(request, **chat_template_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
|
||||
|
||||
@@ -396,7 +398,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
is_end = response_dict["finished"]
|
||||
req_id = response_dict["request_id"]
|
||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
if token_ids[-1] in self.eos_token_ids:
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
if is_end:
|
||||
@@ -434,7 +436,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
token_ids = response_dict["outputs"]["token_ids"]
|
||||
|
||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
if token_ids[-1] in self.eos_token_ids:
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
response_dict["outputs"]["raw_prediction"] = delta_text
|
||||
@@ -527,7 +529,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
|
||||
return tokens["input_ids"][0]
|
||||
|
||||
def messages2ids(self, request):
|
||||
def messages2ids(self, request, **kwargs):
|
||||
"""
|
||||
Convert multi-turn messages into ID sequences.
|
||||
|
||||
@@ -544,7 +546,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
split_special_tokens=False,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pd",
|
||||
chat_template=request.get("chat_template", None),
|
||||
**kwargs,
|
||||
)
|
||||
request["text_after_process"] = spliced_message
|
||||
req_id = None
|
||||
|
@@ -392,6 +392,13 @@ class EngineWorkerQueue:
|
||||
llm_logger.debug("get tasks from queue success")
|
||||
return item
|
||||
|
||||
def clear_data(self):
|
||||
self.lock.acquire()
|
||||
self.tasks[:] = list()
|
||||
self.client_read_flag[:] = [1] * self.num_client
|
||||
self.lock.release()
|
||||
llm_logger.info("clear data for engine worker queue")
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Exit the worker queue gracefully.
|
||||
|
@@ -23,7 +23,10 @@ import paddle.nn.layer
|
||||
from paddle.device.cuda import graphs
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.distributed.communication import capture_custom_allreduce
|
||||
from fastdeploy.distributed.communication import (
|
||||
capture_custom_allreduce,
|
||||
custom_ar_clear_ipc_handles,
|
||||
)
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.log")
|
||||
@@ -208,6 +211,7 @@ class CudaGraphPiecewiseBackend:
|
||||
def clear_graph(self):
|
||||
""" """
|
||||
# Clear graphs
|
||||
custom_ar_clear_ipc_handles()
|
||||
for id, entry in self.concrete_size_entries.items():
|
||||
if entry.cuda_graph:
|
||||
del entry.cuda_graph
|
||||
|
@@ -300,6 +300,7 @@ class EPRunner:
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
getattr(layer, "renormalize", True),
|
||||
)
|
||||
else:
|
||||
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
|
@@ -227,13 +227,14 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
"""
|
||||
gate_out = gate(x.cast("float32"))
|
||||
if layer.topk_method == "noaux_tc":
|
||||
gate_out, _, _ = get_moe_scores(
|
||||
gate_out, topk_weights, topk_idx = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
layer.topk_group,
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
getattr(layer, "renormalize", True),
|
||||
)
|
||||
|
||||
(
|
||||
|
@@ -490,6 +490,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
getattr(layer, "renormalize", True),
|
||||
)
|
||||
else:
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
|
@@ -262,6 +262,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
getattr(layer, "renormalize", True),
|
||||
)
|
||||
|
||||
topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
|
||||
|
@@ -32,6 +32,7 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||
from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant
|
||||
|
||||
|
||||
class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
@@ -258,8 +259,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
getattr(layer, "renormalize", True),
|
||||
)
|
||||
topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
|
||||
else:
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
@@ -327,6 +328,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a16=True,
|
||||
per_channel_quant=False,
|
||||
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
@@ -379,6 +381,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a16=True,
|
||||
per_channel_quant=False,
|
||||
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
@@ -390,6 +393,377 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
return out
|
||||
|
||||
|
||||
class Wfp8Afp8MoEMethod(QuantMethodBase):
|
||||
"""
|
||||
Use Triton Group Gemm to compute Fused wfp8afp8 Quant MoE.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config):
|
||||
"""
|
||||
Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
self.quant_config = quant_config
|
||||
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
|
||||
self.added_scale_attrs = [
|
||||
"up_gate_proj_weight_scale",
|
||||
"down_proj_weight_scale",
|
||||
]
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
|
||||
"""process_prequanted_weights"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
|
||||
"""
|
||||
Triton MoE create weight process.
|
||||
"""
|
||||
self.up_gate_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2,
|
||||
layer.hidden_size,
|
||||
]
|
||||
self.down_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size,
|
||||
]
|
||||
self.up_gate_proj_scale_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2,
|
||||
1,
|
||||
]
|
||||
self.down_proj_scale_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
1,
|
||||
]
|
||||
if self.quant_config.is_checkpoint_bf16:
|
||||
layer.up_gate_proj_weight = layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
|
||||
dtype=layer.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
|
||||
layer.down_proj_weight = layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size],
|
||||
dtype=layer.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.up_gate_proj_weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.down_proj_weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.weight_dtype = paddle.float8_e4m3fn
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
up_gate_proj_scale_name = self.added_scale_attrs[0]
|
||||
down_proj_scale_name = self.added_scale_attrs[1]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.up_gate_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.down_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# weight_scale
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_scale_name,
|
||||
layer.create_parameter(
|
||||
shape=self.up_gate_proj_scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_scale_name,
|
||||
layer.create_parameter(
|
||||
shape=self.down_proj_scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
""" """
|
||||
if not self.quant_config.is_checkpoint_bf16:
|
||||
return
|
||||
weight_id_map = {"gate_up": 0, "down": 1}
|
||||
if (
|
||||
hasattr(layer.up_gate_proj_weight, "tensor_track")
|
||||
and layer.up_gate_proj_weight.tensor_track is not None
|
||||
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
|
||||
):
|
||||
weight_type = "gate_up"
|
||||
layer.up_gate_proj_weight.tensor_track = None
|
||||
else:
|
||||
weight_type = "down"
|
||||
layer.down_proj_weight.tensor_track = None
|
||||
|
||||
# weight
|
||||
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
|
||||
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
|
||||
weight_dtype = paddle.float8_e4m3fn
|
||||
# scale
|
||||
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
|
||||
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
|
||||
scale_dtype = "float32"
|
||||
|
||||
# 2. create tmp weight/scale tensors
|
||||
|
||||
weight = paddle.empty(shape=weight_shape, dtype=weight_dtype)
|
||||
scale = paddle.empty(shape=scale_shape, dtype=scale_dtype)
|
||||
|
||||
# 3.quantize weight
|
||||
from fastdeploy.model_executor.layers.utils import per_token_cast_to_fp8
|
||||
|
||||
for expert_id in range(layer.num_experts):
|
||||
weight_quant, scale[expert_id] = per_token_cast_to_fp8(
|
||||
getattr(layer, weight_name)[expert_id].transpose([1, 0]).contiguous(),
|
||||
)
|
||||
weight[expert_id].copy_(weight_quant, False)
|
||||
getattr(layer, weight_name).value().get_tensor()._clear()
|
||||
|
||||
# create weight
|
||||
setattr(
|
||||
layer,
|
||||
weight_name,
|
||||
layer.create_parameter(
|
||||
shape=weight_shape,
|
||||
dtype=weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# create scale
|
||||
setattr(
|
||||
layer,
|
||||
scale_name,
|
||||
layer.create_parameter(
|
||||
shape=scale_shape,
|
||||
dtype=scale_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
getattr(layer, weight_name).copy_(weight, False)
|
||||
getattr(layer, scale_name).copy_(scale, False)
|
||||
|
||||
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
|
||||
"""
|
||||
check layer is valid for this method
|
||||
"""
|
||||
assert up_gate_proj_weights[0].shape == [
|
||||
layer.moe_intermediate_size * 2,
|
||||
layer.hidden_size,
|
||||
]
|
||||
assert down_proj_weights[0].shape == [
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size,
|
||||
]
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
"""
|
||||
gate_out = gate(x.cast("float32"))
|
||||
token_num = x.shape[0]
|
||||
top_k = layer.top_k
|
||||
num_local_experts = layer.num_local_experts
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape
|
||||
|
||||
if layer.topk_method == "noaux_tc":
|
||||
gate_out, topk_weights, topk_ids = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
layer.topk_group,
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
getattr(layer, "renormalize", True),
|
||||
)
|
||||
else:
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
layer.top_k,
|
||||
True, # apply_norm_weight
|
||||
False,
|
||||
)
|
||||
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4,
|
||||
}
|
||||
if token_num <= E:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4,
|
||||
}
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
|
||||
topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
|
||||
)
|
||||
max_possible_num_post_padded = sorted_token_ids.shape[0]
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
|
||||
* ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
up_gate_proj_out = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
|
||||
x_q, x_scale = scaled_fp8_quant(x, use_per_token_if_dynamic=True)
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x_q,
|
||||
layer.up_gate_proj_weight,
|
||||
up_gate_proj_out,
|
||||
x_scale,
|
||||
layer.up_gate_proj_weight_scale,
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
max_possible_num_post_padded,
|
||||
token_num * top_k,
|
||||
N=moe_intermediate_size * 2,
|
||||
K=hidden_size,
|
||||
stride_am=x_q.strides[0],
|
||||
stride_ak=x_q.strides[1],
|
||||
stride_be=layer.up_gate_proj_weight.strides[0],
|
||||
stride_bk=layer.up_gate_proj_weight.strides[2],
|
||||
stride_bn=layer.up_gate_proj_weight.strides[1],
|
||||
stride_cm=up_gate_proj_out.strides[0],
|
||||
stride_cn=up_gate_proj_out.strides[1],
|
||||
#
|
||||
stride_asm=x_scale.strides[0],
|
||||
stride_ask=x_scale.strides[1],
|
||||
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
|
||||
stride_bsk=layer.up_gate_proj_weight_scale.strides[2],
|
||||
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
|
||||
group_n=-1,
|
||||
group_k=-1,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config["GROUP_SIZE_M"],
|
||||
MUL_ROUTED_WEIGHT=False,
|
||||
top_k=top_k,
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=True,
|
||||
use_int8_w8a16=False,
|
||||
per_channel_quant=True,
|
||||
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out)
|
||||
|
||||
down_proj_out = paddle.empty(
|
||||
(token_num * top_k, hidden_size),
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
|
||||
* ceil_div(hidden_size, config["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
x_q, x_scale = scaled_fp8_quant(down_proj_input, use_per_token_if_dynamic=True)
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x_q,
|
||||
layer.down_proj_weight,
|
||||
down_proj_out,
|
||||
x_scale,
|
||||
layer.down_proj_weight_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
max_possible_num_post_padded,
|
||||
token_num * top_k,
|
||||
N=hidden_size,
|
||||
K=moe_intermediate_size,
|
||||
stride_am=x_q.strides[0],
|
||||
stride_ak=x_scale.strides[1],
|
||||
stride_be=layer.down_proj_weight.strides[0],
|
||||
stride_bk=layer.down_proj_weight.strides[2],
|
||||
stride_bn=layer.down_proj_weight.strides[1],
|
||||
stride_cm=down_proj_out.strides[0],
|
||||
stride_cn=down_proj_out.strides[1],
|
||||
stride_asm=x_scale.strides[0],
|
||||
stride_ask=x_scale.strides[1],
|
||||
stride_bse=layer.down_proj_weight_scale.strides[0],
|
||||
stride_bsk=layer.down_proj_weight_scale.strides[2],
|
||||
stride_bsn=layer.down_proj_weight_scale.strides[1],
|
||||
group_n=-1,
|
||||
group_k=-1,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config["GROUP_SIZE_M"],
|
||||
MUL_ROUTED_WEIGHT=True,
|
||||
top_k=1,
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=True,
|
||||
use_int8_w8a16=False,
|
||||
per_channel_quant=True,
|
||||
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
down_proj_out.reshape_([token_num, top_k, hidden_size])
|
||||
out = down_proj_out.sum(axis=1)
|
||||
|
||||
if layer.reduce_results and layer.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(out)
|
||||
|
||||
return out
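Note on the quantization scheme above: the new Wfp8Afp8MoEMethod path quantizes activations per token and weights per output channel before the Triton group GEMM. A minimal sketch of the per-token FP8 E4M3 step, assuming a max representable value of 448; the helper below is illustrative, not the FastDeploy per_token_cast_to_fp8/scaled_fp8_quant implementation:
import paddle
FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
def per_token_fp8_quant(x: paddle.Tensor):
    # One scale per row so that x / scale fits the e4m3 range.
    scale = x.abs().max(axis=-1, keepdim=True) / FP8_E4M3_MAX
    scale = paddle.clip(scale, min=1e-12)  # guard against all-zero rows
    x_q = (x / scale).astype(paddle.float8_e4m3fn)
    return x_q, scale.astype("float32")  # dequantize as x_q.cast(x.dtype) * scale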
|
||||
|
||||
|
||||
class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
"""
|
||||
Use Triton Group Gemm to compute Fused MoE.
|
||||
@@ -524,6 +898,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
getattr(layer, "renormalize", True),
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -607,6 +982,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=True,
|
||||
use_int8_w8a16=False,
|
||||
per_channel_quant=False,
|
||||
even_Ks=hidden_size % config_up_gate_proj["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
@@ -676,6 +1052,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=True,
|
||||
use_int8_w8a16=False,
|
||||
per_channel_quant=False,
|
||||
even_Ks=moe_intermediate_size % config_down_proj["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
@@ -945,6 +1322,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
getattr(layer, "renormalize", True),
|
||||
)
|
||||
else:
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
@@ -1021,6 +1399,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=True,
|
||||
use_int8_w8a16=False,
|
||||
per_channel_quant=False,
|
||||
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
@@ -1074,6 +1453,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=True,
|
||||
use_int8_w8a16=False,
|
||||
per_channel_quant=False,
|
||||
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
|
@@ -66,6 +66,7 @@ def get_moe_scores(
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
renormalize: bool = False,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
compute moe scores using e_score_correction_bias.
|
||||
@@ -79,6 +80,7 @@ def get_moe_scores(
|
||||
n_group if n_group > 0 else 1,
|
||||
topk_group if topk_group > 0 else 1,
|
||||
top_k,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores, topk_values, topk_idx
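For readers of this hunk: get_moe_scores now takes a renormalize flag that is forwarded to the noaux_tc kernel. A minimal sketch of what renormalizing the selected top-k weights amounts to (illustrative only, not the kernel code):
import paddle
def renormalize_topk(topk_weights: paddle.Tensor) -> paddle.Tensor:
    # With renormalize=True each token's selected expert weights are rescaled to sum to 1;
    # with False they are used as produced by the group-limited top-k selection.
    return topk_weights / topk_weights.sum(axis=-1, keepdim=True)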
|
||||
@@ -93,6 +95,7 @@ class FusedMoE(nn.Layer):
|
||||
self,
|
||||
fd_config,
|
||||
reduce_results: bool = True,
|
||||
renormalize: bool = False,
|
||||
moe_intermediate_size: int = -1,
|
||||
num_experts: int = -1,
|
||||
expert_id_offset: int = 0,
|
||||
@@ -119,6 +122,7 @@ class FusedMoE(nn.Layer):
|
||||
self.fd_config = fd_config
|
||||
self.layer_idx = layer_idx
|
||||
self.reduce_results = reduce_results
|
||||
self.renormalize = renormalize
|
||||
self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_size
|
||||
self.ep_size = fd_config.parallel_config.expert_parallel_size
|
||||
|
@@ -59,6 +59,7 @@ def fused_moe_kernel_paddle(
|
||||
compute_type_enum: tl.constexpr,
|
||||
use_fp8_w8a8: tl.constexpr,
|
||||
use_int8_w8a16: tl.constexpr,
|
||||
per_channel_quant: tl.constexpr,
|
||||
even_Ks: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
@@ -121,6 +122,13 @@ def fused_moe_kernel_paddle(
|
||||
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
|
||||
offs_bsn = offs_bn // group_n
|
||||
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
|
||||
# channel-wise
|
||||
elif per_channel_quant:
|
||||
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
|
||||
b_scale = tl.load(b_scale_ptrs)
|
||||
# Load per-token scale for activations
|
||||
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
|
||||
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
|
||||
else:
|
||||
# (Zkk): every expert has one activation scale and weight scale.
|
||||
a_scale = tl.load(a_scale_ptr + off_experts)
|
||||
|
@@ -23,6 +23,7 @@ from fastdeploy.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.moe import FusedMoE
|
||||
from fastdeploy.model_executor.layers.quantization.ops import (
|
||||
cutlass_scaled_mm,
|
||||
scaled_fp8_quant,
|
||||
@@ -66,7 +67,14 @@ class WFP8AFP8Config(QuantConfigBase):
|
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
""" """
|
||||
return WFP8AFP8LinearMethod(self)
|
||||
if isinstance(layer, FusedMoE):
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
|
||||
Wfp8Afp8MoEMethod,
|
||||
)
|
||||
|
||||
return Wfp8Afp8MoEMethod(self)
|
||||
else:
|
||||
return WFP8AFP8LinearMethod(self)
|
||||
|
||||
|
||||
class WFP8AFP8LinearMethod(QuantMethodBase):
|
||||
|
@@ -116,6 +116,7 @@ class DeepSeekV3MoE(nn.Layer):
|
||||
super().__init__()
|
||||
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_size
|
||||
self.norm_topk_prob = fd_config.model_config.norm_topk_prob
|
||||
|
||||
weight_key_map = {
|
||||
"gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
|
||||
@@ -145,6 +146,7 @@ class DeepSeekV3MoE(nn.Layer):
|
||||
self.experts = FusedMoE(
|
||||
fd_config=fd_config,
|
||||
reduce_results=False,
|
||||
renormalize=self.norm_topk_prob,
|
||||
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
|
||||
num_experts=fd_config.model_config.n_routed_experts,
|
||||
top_k=fd_config.model_config.num_experts_per_tok,
|
||||
|
@@ -23,7 +23,7 @@ from paddle import nn
|
||||
from paddle.autograd import PyLayer
|
||||
from paddle.distributed.fleet.utils import recompute
|
||||
|
||||
from fastdeploy.model_executor.layers.utils import _set_var_distributed, get_tensor
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.models.ernie4_5_vl.dist_utils import (
|
||||
RowSequenceParallelLinear,
|
||||
all_gather_group,
|
||||
@@ -197,19 +197,7 @@ class VariableResolutionResamplerModel(nn.Layer):
|
||||
self.after_norm = RMSNorm(out_config)
|
||||
|
||||
if self.tensor_parallel_degree > 1:
|
||||
for idx in [2, 3]:
|
||||
mark_as_sequence_parallel_parameter(self.spatial_linear[idx].weight)
|
||||
mark_as_sequence_parallel_parameter(self.spatial_linear[idx].bias)
|
||||
_set_var_distributed(self.spatial_linear[idx].weight, split_axis=0)
|
||||
_set_var_distributed(self.spatial_linear[idx].bias, split_axis=0)
|
||||
if self.use_temporal_conv:
|
||||
for idx in [0, 2, 3]:
|
||||
mark_as_sequence_parallel_parameter(self.temporal_linear[idx].weight)
|
||||
mark_as_sequence_parallel_parameter(self.temporal_linear[idx].bias)
|
||||
|
||||
mark_as_sequence_parallel_parameter(self.mlp.weight)
|
||||
mark_as_sequence_parallel_parameter(self.mlp.bias)
|
||||
mark_as_sequence_parallel_parameter(self.after_norm.weight)
|
||||
set_weight_attrs(self.spatial_linear[0].weight, {"output_dim": False})
|
||||
|
||||
def spatial_conv_reshape(self, x, spatial_conv_size):
|
||||
|
@@ -109,6 +109,8 @@ class Glm4Moe(nn.Layer):
|
||||
self.n_routed_experts: int = fd_config.model_config.n_routed_experts
|
||||
self.n_shared_experts: int = fd_config.model_config.n_shared_experts
|
||||
|
||||
self.norm_topk_prob = fd_config.model_config.norm_topk_prob
|
||||
|
||||
weight_key_map = {
|
||||
"gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
|
||||
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
|
||||
@@ -133,6 +135,7 @@ class Glm4Moe(nn.Layer):
|
||||
self.experts = FusedMoE(
|
||||
fd_config,
|
||||
reduce_results=False,
|
||||
renormalize=self.norm_topk_prob,
|
||||
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
|
||||
num_experts=fd_config.model_config.n_routed_experts,
|
||||
top_k=fd_config.model_config.num_experts_per_tok,
|
||||
|
@@ -464,6 +464,31 @@ class TokenProcessor:
|
||||
main_process_metrics.request_inference_time.observe(current_time - task.inference_start_time)
|
||||
main_process_metrics.request_generation_tokens.observe(self.tokens_counter[task.request_id])
|
||||
|
||||
def clear_data(self):
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
self.resource_manager.clear_data()
|
||||
for i in range(self.cfg.max_num_seqs):
|
||||
if self.resource_manager.stop_flags[i]:
|
||||
continue
|
||||
task = self.resource_manager.tasks_list[i]
|
||||
result = RequestOutput(
|
||||
request_id=task.request_id,
|
||||
outputs=CompletionOutput(
|
||||
index=i,
|
||||
send_idx=self.tokens_counter[task.request_id],
|
||||
token_ids=task.eos_token_ids,
|
||||
draft_token_ids=[],
|
||||
),
|
||||
finished=True,
|
||||
metrics=RequestMetrics(
|
||||
arrival_time=time.time(),
|
||||
request_start_time=task.arrival_time,
|
||||
),
|
||||
)
|
||||
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
|
||||
self._recycle_resources(task.request_id, i, task, result, is_prefill)
|
||||
llm_logger.warning(f"clear data for task {task.request_id}")
|
||||
|
||||
def _record_speculative_decoding_mertics(self, accept_num):
|
||||
"""Record metrics of speculative decoding"""
|
||||
if not hasattr(main_process_metrics, "spec_decode_draft_acceptance_rate"):
|
||||
|
@@ -66,6 +66,7 @@ class DynamicWeightManager:
|
||||
paddle.device.cuda.empty_cache()
|
||||
|
||||
if not self.first_load:
|
||||
paddle.distributed.restart_process_group()
|
||||
paddle.distributed.restart_process_group(self.parallel_config.tp_group)
|
||||
if self.parallel_config.enable_expert_parallel:
|
||||
paddle.distributed.restart_process_group(self.parallel_config.ep_group)
|
||||
@@ -115,7 +116,7 @@ class DynamicWeightManager:
|
||||
self._verify_parameters("clearance")
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
paddle.distributed.barrier(self.parallel_config.tp_group)
|
||||
paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
|
||||
paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
|
||||
if self.parallel_config.enable_expert_parallel:
|
||||
paddle.distributed.barrier(self.parallel_config.ep_group)
|
||||
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
|
||||
@@ -222,12 +223,14 @@ class DynamicWeightManager:
|
||||
while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
|
||||
if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
|
||||
logger.info("infer engine stopped! start to load new checkpoint...")
|
||||
model_runner.clear_requests()
|
||||
model_runner.update_parameters(pid)
|
||||
while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
|
||||
time.sleep(0.01)
|
||||
logger.info("finished loading new checkpoint")
|
||||
elif model_weights_status.value[0] == ModelWeightsStatus.CLEARING:
|
||||
logger.info("infer engine stopped! start to clear checkpoint...")
|
||||
model_runner.clear_requests()
|
||||
model_runner.clear_parameters(pid)
|
||||
while model_weights_status.value[0] != ModelWeightsStatus.CLEARED:
|
||||
time.sleep(0.01)
|
||||
|
@@ -1028,12 +1028,12 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
create_cache_tensor = profile or self.parallel_config.splitwise_role == "mixed"
|
||||
|
||||
if not create_cache_tensor:
|
||||
logger.info("Waiting for cache managers to create kv cache..")
|
||||
logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
|
||||
while cache_ready_signal.value[self.local_rank] != 1:
|
||||
time.sleep(1)
|
||||
logger.info("OK! Stop waiting.")
|
||||
logger.info(f"OK! Stop waiting. {cache_ready_signal.value}")
|
||||
|
||||
logger.info("Initializing kv cache for all layers.")
|
||||
logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
|
||||
cache_kvs_list = []
|
||||
for i in range(self.model_config.num_hidden_layers):
|
||||
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
@@ -1054,8 +1054,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["caches"] = cache_kvs_list
|
||||
|
||||
if not profile and create_cache_tensor:
|
||||
logger.info("✅ kv cache is ready!")
|
||||
cache_ready_signal.value[self.local_rank] = 1
|
||||
logger.info(f"✅ kv cache is ready! {cache_ready_signal.value}")
|
||||
|
||||
paddle.device.cuda.empty_cache()
|
||||
|
||||
@@ -1704,6 +1704,10 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.forward_meta.clear_caches()
|
||||
paddle.device.cuda.empty_cache()
|
||||
|
||||
def clear_requests(self):
|
||||
"""Dynamic model loader use to clear requests use for RL"""
|
||||
self.share_inputs["stop_flags"][:] = True
|
||||
|
||||
def clear_parameters(self, pid):
|
||||
"""Dynamic model loader use to clear parameters use for RL"""
|
||||
# Clear CUDAGraph
|
||||
|
@@ -337,6 +337,8 @@ class PaddleDisWorkerProc:
|
||||
self.worker.model_runner,
|
||||
self.parallel_config.engine_worker_queue_port,
|
||||
)
|
||||
logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
|
||||
self.task_queue.clear_data()
|
||||
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
|
||||
logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
|
||||
|
||||
|
@@ -115,17 +115,16 @@ def setup_and_run_server():
|
||||
"--max-model-len",
|
||||
"32768",
|
||||
"--max-num-seqs",
|
||||
"32",
|
||||
"1",
|
||||
"--graph-optimization-config",
|
||||
'{"use_cudagraph":true}',
|
||||
"--load_choices",
|
||||
"default_v1",
|
||||
"--lm_head-fp32",
|
||||
"--quantization",
|
||||
'{"quantization":"mix_quant","dense_quant_type":"wfp8afp8","moe_quant_type":"wint8"}',
|
||||
"wfp8afp8",
|
||||
]
|
||||
env = os.environ.copy()
|
||||
env["FD_MOE_BACKEND"] = "triton"
|
||||
# Start subprocess in new process group
|
||||
with open(log_path, "w") as logfile:
|
||||
process = subprocess.Popen(
|
||||
|
tests/entrypoints/test_engine_client.py (new file, 36 lines)
@@ -0,0 +1,36 @@
import unittest
from unittest.mock import MagicMock, patch

from fastdeploy.entrypoints.engine_client import EngineClient


class TestEngineClient(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        # Create a mocked EngineClient instance
        with patch.object(EngineClient, "__init__", return_value=None) as mock_init:
            self.engine_client = EngineClient("model_path")
            mock_init.side_effect = lambda *args, **kwargs: print(f"__init__ called with {args}, {kwargs}")

        self.engine_client.data_processor = MagicMock()
        self.engine_client.zmq_client = MagicMock()
        self.engine_client.max_model_len = 1024
        self.engine_client.enable_mm = False

    async def test_add_request(self):
        request = {
            "chat_template_kwargs": {"enable_thinking": True},
            "prompt_token_ids": [1],
            "chat_template": "Hello",
            "max_tokens": 20,
            "tools": [1],
        }

        await self.engine_client.add_requests(request)
        assert "chat_template" in request["chat_template_kwargs"], "'chat_template' not found in 'chat_template_kwargs'"
        assert "tools" in request["chat_template_kwargs"], "'tools' not found in 'chat_template_kwargs'"
        assert request["chat_template_kwargs"]["chat_template"] == "Hello"
        assert request["chat_template_kwargs"]["tools"] == [1]


if __name__ == "__main__":
    unittest.main()
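The new test is self-contained and can be run on its own; one assumed invocation (not part of the patch):
python -m unittest tests/entrypoints/test_engine_client.py -v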
|
@@ -17,6 +17,8 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
self.processor.decode_status = {}
|
||||
self.processor.reasoning_end_dict = {}
|
||||
self.processor.tool_parser_dict = {}
|
||||
self.processor.generation_config = MagicMock()
|
||||
self.processor.eos_token_ids = [1]
|
||||
|
||||
# Mock the ids2tokens method
|
||||
def mock_ids2tokens(token_ids, task_id):
|
||||
@@ -24,6 +26,18 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
|
||||
self.processor.ids2tokens = mock_ids2tokens
|
||||
|
||||
def mock_messages2ids(request, **kwargs):
|
||||
if "chat_template" in kwargs:
|
||||
return [1]
|
||||
else:
|
||||
return [0]
|
||||
|
||||
def mock_apply_default_parameters(request):
|
||||
return request
|
||||
|
||||
self.processor.messages2ids = mock_messages2ids
|
||||
self.processor._apply_default_parameters = mock_apply_default_parameters
|
||||
|
||||
# Mock the reasoning parser
|
||||
self.mock_reasoning_parser = MagicMock()
|
||||
self.mock_reasoning_parser.__class__.__name__ = "ErnieX1ReasoningParser"
|
||||
@@ -49,6 +63,17 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
# Verify the result
|
||||
self.assertEqual(result["outputs"]["raw_prediction"], "delta_text")
|
||||
|
||||
def test_process_request_dict(self):
|
||||
request_dict = {
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"chat_template_kwargs": {"chat_template": "Hello!"},
|
||||
"eos_token_ids": [1],
|
||||
"temperature": 1,
|
||||
"top_p": 1,
|
||||
}
|
||||
result = self.processor.process_request_dict(request_dict, 100)
|
||||
self.assertEqual(result["prompt_token_ids"], [1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
tests/input/test_text_processor.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import unittest
from unittest.mock import MagicMock, patch

from fastdeploy.engine.request import Request
from fastdeploy.input.text_processor import DataProcessor


class TestDataProcessorProcess(unittest.TestCase):
    def setUp(self):
        # Create a mocked DataProcessor instance
        with patch.object(DataProcessor, "__init__", return_value=None) as mock_init:
            self.processor = DataProcessor("model_path")
            mock_init.side_effect = lambda *args, **kwargs: print(f"__init__ called with {args}, {kwargs}")

        # Set the attributes required by the tests
        self.processor.tokenizer = MagicMock()
        self.processor.tokenizer.eos_token_id = 1
        self.processor.decode_status = {}
        self.processor.reasoning_end_dict = {}
        self.processor.tool_parser_dict = {}
        self.processor.generation_config = MagicMock()
        self.processor.eos_token_ids = [1]

        def mock_messages2ids(request, **kwargs):
            if "chat_template" in kwargs:
                return [1]
            else:
                return [0]

        def mock_apply_default_parameters(request):
            return request

        self.processor.messages2ids = mock_messages2ids
        self.processor._apply_default_parameters = mock_apply_default_parameters

    def test_process_request(self):
        request = Request.from_dict(
            {
                "request_id": "123",
                "messages": [{"role": "user", "content": "Hello!"}],
                "eos_token_ids": [1],
                "temperature": 1,
                "top_p": 1,
            }
        )
        chat_template_kwargs = {"chat_template": "Hello!"}
        result = self.processor.process_request(request, 100, chat_template_kwargs=chat_template_kwargs)
        self.assertEqual(result.prompt_token_ids, [1])

    def test_process_request_dict(self):
        request_dict = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "chat_template_kwargs": {"chat_template": "Hello!"},
            "eos_token_ids": [1],
            "temperature": 1,
            "top_p": 1,
        }
        result = self.processor.process_request_dict(request_dict, 100)
        self.assertEqual(result["prompt_token_ids"], [1])


if __name__ == "__main__":
    unittest.main()
|
@@ -15,6 +15,7 @@ class TestMoeRouting(unittest.TestCase):
|
||||
self.topk_group = 4
|
||||
self.top_k = 8
|
||||
self.routed_scaling_factor = 1.5
|
||||
self.renormalize = True
|
||||
|
||||
def node_limit_routing(self, gate_probs):
|
||||
"""将所有专家分组, 只在topk_group个group内选择专家"""
|
||||
@@ -64,6 +65,7 @@ class TestMoeRouting(unittest.TestCase):
|
||||
self.topk_group,
|
||||
self.top_k,
|
||||
self.routed_scaling_factor,
|
||||
self.renormalize,
|
||||
)
|
||||
|
||||
ref_topk_values, ref_topk_idx = self.ref_moe_routing()
|
||||
|
@@ -3,15 +3,11 @@ import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||||
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
from fastdeploy.entrypoints.chat_utils import load_chat_template
|
||||
from fastdeploy.entrypoints.llm import LLM
|
||||
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor
|
||||
from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor
|
||||
from fastdeploy.input.text_processor import DataProcessor
|
||||
|
||||
|
||||
class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
|
||||
@@ -108,91 +104,6 @@ class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
|
||||
chat_completion = await self.chat_completion_handler.create_chat_completion(request)
|
||||
self.assertEqual("hello", chat_completion["chat_template"])
|
||||
|
||||
@patch("fastdeploy.input.ernie4_5_vl_processor.Ernie4_5_VLProcessor.__init__")
|
||||
def test_ernie4_5_vl_processor(self, mock_class):
|
||||
mock_class.return_value = None
|
||||
ernie4_5_vl_processor = Ernie4_5_VLProcessor()
|
||||
mock_request = Request.from_dict({"request_id": "123"})
|
||||
|
||||
def mock_apply_default_parameters(request):
|
||||
return request
|
||||
|
||||
def mock_process_request(request, max_model_len):
|
||||
return request
|
||||
|
||||
ernie4_5_vl_processor._apply_default_parameters = mock_apply_default_parameters
|
||||
ernie4_5_vl_processor.process_request_dict = mock_process_request
|
||||
result = ernie4_5_vl_processor.process_request(mock_request, chat_template="hello")
|
||||
self.assertEqual("hello", result.chat_template)
|
||||
|
||||
@patch("fastdeploy.input.text_processor.DataProcessor.__init__")
|
||||
def test_text_processor_process_request(self, mock_class):
|
||||
mock_class.return_value = None
|
||||
text_processor = DataProcessor()
|
||||
mock_request = Request.from_dict(
|
||||
{"request_id": "123", "prompt": "hi", "max_tokens": 128, "temperature": 1, "top_p": 1}
|
||||
)
|
||||
|
||||
def mock_apply_default_parameters(request):
|
||||
return request
|
||||
|
||||
def mock_process_request(request, max_model_len):
|
||||
return request
|
||||
|
||||
def mock_text2ids(text, max_model_len):
|
||||
return [1]
|
||||
|
||||
text_processor._apply_default_parameters = mock_apply_default_parameters
|
||||
text_processor.process_request_dict = mock_process_request
|
||||
text_processor.text2ids = mock_text2ids
|
||||
text_processor.eos_token_ids = [1]
|
||||
result = text_processor.process_request(mock_request, chat_template="hello")
|
||||
self.assertEqual("hello", result.chat_template)
|
||||
|
||||
@patch("fastdeploy.input.ernie4_5_processor.Ernie4_5Processor.__init__")
|
||||
def test_ernie4_5_processor_process(self, mock_class):
|
||||
mock_class.return_value = None
|
||||
ernie4_5_processor = Ernie4_5Processor()
|
||||
mock_request = Request.from_dict(
|
||||
{"request_id": "123", "messages": ["hi"], "max_tokens": 128, "temperature": 1, "top_p": 1}
|
||||
)
|
||||
|
||||
def mock_apply_default_parameters(request):
|
||||
return request
|
||||
|
||||
def mock_process_request(request, max_model_len):
|
||||
return request
|
||||
|
||||
def mock_messages2ids(text):
|
||||
return [1]
|
||||
|
||||
ernie4_5_processor._apply_default_parameters = mock_apply_default_parameters
|
||||
ernie4_5_processor.process_request_dict = mock_process_request
|
||||
ernie4_5_processor.messages2ids = mock_messages2ids
|
||||
ernie4_5_processor.eos_token_ids = [1]
|
||||
ernie4_5_processor.reasoning_parser = MagicMock()
|
||||
result = ernie4_5_processor.process_request(mock_request, chat_template="hello")
|
||||
self.assertEqual("hello", result.chat_template)
|
||||
|
||||
@patch("fastdeploy.entrypoints.llm.LLM.__init__")
|
||||
def test_llm_load(self, mock_class):
|
||||
mock_class.return_value = None
|
||||
llm = LLM()
|
||||
llm.llm_engine = MagicMock()
|
||||
llm.default_sampling_params = MagicMock()
|
||||
llm.chat_template = "hello"
|
||||
|
||||
def mock_run_engine(req_ids, **kwargs):
|
||||
return req_ids
|
||||
|
||||
def mock_add_request(**kwargs):
|
||||
return kwargs.get("chat_template")
|
||||
|
||||
llm._run_engine = mock_run_engine
|
||||
llm._add_request = mock_add_request
|
||||
result = llm.chat(["hello"], sampling_params=SamplingParams(1))
|
||||
self.assertEqual("hello", result)
|
||||
|
||||
@patch("fastdeploy.entrypoints.llm.LLM.__init__")
|
||||
def test_llm(self, mock_class):
|
||||
mock_class.return_value = None
|
||||
|