mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
support qk norm (#3145)
This commit is contained in:
@@ -559,3 +559,37 @@ template <typename T, bool IsFP8>inline __device__ static void convert_c8(T * re
|
||||
convert_int8(result, source);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int kWarpSize = 32;
|
||||
|
||||
template<typename T>
|
||||
inline __device__ void WelfordCombine1(T b_m2, T* m2) {
|
||||
*m2 += b_m2;
|
||||
}
|
||||
|
||||
template<typename T, int thread_group_width = kWarpSize>
|
||||
__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
|
||||
*m2 = thread_m2;
|
||||
for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
|
||||
T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask);
|
||||
WelfordCombine1(b_m2, m2);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, int thread_group_width = kWarpSize>
|
||||
__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) {
|
||||
WelfordWarpReduce<T, thread_group_width>(thread_m2, m2);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__inline__ __device__ T Rsqrt(T x);
|
||||
|
||||
template <>
|
||||
__inline__ __device__ float Rsqrt<float>(float x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ double Rsqrt<double>(double x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user