mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-28 18:51:58 +08:00
[OPs] MoE support wfp8afp8(channelwise) and improve per_token_quant_fp8 (#4238)
This commit is contained in:
@@ -151,6 +151,34 @@ inline int GetGPUComputeCapability(int id) {
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef FP8_E4M3_MAX
|
||||
#define FP8_E4M3_MAX 448.0
|
||||
#endif
|
||||
|
||||
#ifndef DISPATCH_FLOAT_FP6_DTYPE
|
||||
#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...) \
|
||||
switch (pd_dtype) { \
|
||||
case phi::DataType::FLOAT32: { \
|
||||
using c_type = float; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case phi::DataType::BFLOAT16: { \
|
||||
using c_type = phi::dtype::bfloat16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case phi::DataType::FLOAT16: { \
|
||||
using c_type = phi::dtype::float16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1)
|
||||
return num;
|
||||
@@ -573,3 +601,28 @@ inline bool GetMlaUseTensorcore() {
|
||||
flags_mla_use_tensorcore && enable_mla_tensorcore;
|
||||
return mla_use_tensorcore;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float warpReduceMax(float value) {
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 4));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 2));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 1));
|
||||
return value;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float blockReduceMax(float value) {
|
||||
static __shared__ float warpLevelMaxs[WARP_SIZE];
|
||||
const int laneId = threadIdx.x % WARP_SIZE;
|
||||
const int warpId = threadIdx.x / WARP_SIZE;
|
||||
|
||||
value = warpReduceMax(value);
|
||||
|
||||
if (laneId == 0) warpLevelMaxs[warpId] = value;
|
||||
__syncthreads();
|
||||
|
||||
value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
|
||||
if (warpId == 0) value = warpReduceMax(value);
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user