mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
support w4afp8 EP inference (#3044)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
This commit is contained in:
@@ -192,7 +192,8 @@ paddle::Tensor MoeExpertFFNFunc(
|
|||||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||||
const std::string& quant_method, const bool used_in_ep_low_latency);
|
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||||
|
const int estimate_total_token_nums);
|
||||||
|
|
||||||
paddle::Tensor MoeExpertFFNWint2Func(
|
paddle::Tensor MoeExpertFFNWint2Func(
|
||||||
const paddle::Tensor& permute_input,
|
const paddle::Tensor& permute_input,
|
||||||
|
@@ -193,6 +193,12 @@ public:
|
|||||||
typedef uint8_t data_t;
|
typedef uint8_t data_t;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <> class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
|
||||||
|
public:
|
||||||
|
typedef __nv_fp8_e4m3 DataType;
|
||||||
|
typedef paddle::float8_e4m3fn data_t;
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T, int Size> struct alignas(sizeof(T) * Size) AlignedVector {
|
template <typename T, int Size> struct alignas(sizeof(T) * Size) AlignedVector {
|
||||||
T val[Size];
|
T val[Size];
|
||||||
|
|
||||||
|
@@ -314,7 +314,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename T, typename OutT, int NUM_EXPERTS_PER_RANK = 8, int RoundType = 1>
|
template <typename T, typename OutT, int NUM_EXPERTS_PER_RANK = 8, int Kthread = 512, int RoundType = 1>
|
||||||
__global__ void permute_x_kernel(const T *src_x,
|
__global__ void permute_x_kernel(const T *src_x,
|
||||||
const int64_t *topk_idx,
|
const int64_t *topk_idx,
|
||||||
const float *topk_weights,
|
const float *topk_weights,
|
||||||
@@ -330,9 +330,9 @@ __global__ void permute_x_kernel(const T *src_x,
|
|||||||
int *dst_indices,
|
int *dst_indices,
|
||||||
int *cumsum_idx_gpu,
|
int *cumsum_idx_gpu,
|
||||||
int64_t *token_nums_per_expert_cumsum,
|
int64_t *token_nums_per_expert_cumsum,
|
||||||
int64_t *expert_idx_per_token,
|
int64_t *expert_idx_per_token, // [num_rows, moe_topk]
|
||||||
float max_bound = 127.0,
|
float max_bound = 127.0,
|
||||||
float min_bound = -127.0) { // [num_rows, moe_topk]
|
float min_bound = -127.0) {
|
||||||
const int src_token_idx = blockIdx.x;
|
const int src_token_idx = blockIdx.x;
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
constexpr int vec_size = sizeof(int4) / sizeof(T);
|
constexpr int vec_size = sizeof(int4) / sizeof(T);
|
||||||
@@ -375,10 +375,17 @@ __global__ void permute_x_kernel(const T *src_x,
|
|||||||
if (up_gate_proj_in_scale) {
|
if (up_gate_proj_in_scale) {
|
||||||
for (int i = 0; i < vec_size; i++) {
|
for (int i = 0; i < vec_size; i++) {
|
||||||
float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast<float>(src_vec[i]);
|
float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast<float>(src_vec[i]);
|
||||||
if (RoundType == 0) {
|
if constexpr (std::is_same<OutT, int8_t>::value) {
|
||||||
res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
|
// w4aint8
|
||||||
|
if (RoundType == 0) {
|
||||||
|
res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
|
||||||
|
} else {
|
||||||
|
res_vec[i] = static_cast<OutT>(ClipFunc<float>(round(quant_value), min_bound, max_bound));
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
res_vec[i] = static_cast<OutT>(round(quant_value));
|
// w4afp8
|
||||||
|
float value = ClipFunc<float>(quant_value, min_bound, max_bound);
|
||||||
|
res_vec[i] = static_cast<OutT>(value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -418,6 +425,10 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
|||||||
typedef typename traits_::DataType DataType_;
|
typedef typename traits_::DataType DataType_;
|
||||||
typedef typename traits_::data_t data_t;
|
typedef typename traits_::data_t data_t;
|
||||||
|
|
||||||
|
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
|
||||||
|
typedef typename traits_fp8::DataType DataType_fp8;
|
||||||
|
typedef typename traits_fp8::data_t data_t_fp8;
|
||||||
|
|
||||||
auto stream = input.stream();
|
auto stream = input.stream();
|
||||||
auto place = input.place();
|
auto place = input.place();
|
||||||
const int gridx = min(132 * 8, num_rows);
|
const int gridx = min(132 * 8, num_rows);
|
||||||
@@ -465,6 +476,50 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
|||||||
-127.0
|
-127.0
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
} else if (moe_quant_type == "w4afp8") {
|
||||||
|
if (num_experts_per_rank == 8) {
|
||||||
|
permute_x_kernel<data_t, data_t_fp8, 8, 512><<<gridx, 512, 0, stream>>>(
|
||||||
|
input.data<data_t>(),
|
||||||
|
topk_ids.data<int64_t>(),
|
||||||
|
topk_weights.data<float>(),
|
||||||
|
token_nums_per_expert.data<int>(),
|
||||||
|
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||||
|
moe_topk,
|
||||||
|
num_rows,
|
||||||
|
token_nums_this_rank,
|
||||||
|
hidden_size,
|
||||||
|
permute_input->data<data_t_fp8>(),
|
||||||
|
permute_indices_per_token->data<int>(),
|
||||||
|
dst_weights->data<float>(),
|
||||||
|
dst_indices->data<int>(),
|
||||||
|
cumsum_idx_gpu->data<int>(),
|
||||||
|
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||||
|
expert_idx_per_token->data<int64_t>(),
|
||||||
|
448.0f,
|
||||||
|
-448.0f
|
||||||
|
);
|
||||||
|
} else if (num_experts_per_rank == 16) {
|
||||||
|
permute_x_kernel<data_t, data_t_fp8, 16, 512><<<gridx, 512, 0, stream>>>(
|
||||||
|
input.data<data_t>(),
|
||||||
|
topk_ids.data<int64_t>(),
|
||||||
|
topk_weights.data<float>(),
|
||||||
|
token_nums_per_expert.data<int>(),
|
||||||
|
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||||
|
moe_topk,
|
||||||
|
num_rows,
|
||||||
|
token_nums_this_rank,
|
||||||
|
hidden_size,
|
||||||
|
permute_input->data<data_t_fp8>(),
|
||||||
|
permute_indices_per_token->data<int>(),
|
||||||
|
dst_weights->data<float>(),
|
||||||
|
dst_indices->data<int>(),
|
||||||
|
cumsum_idx_gpu->data<int>(),
|
||||||
|
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||||
|
expert_idx_per_token->data<int64_t>(),
|
||||||
|
448.0f,
|
||||||
|
-448.0f
|
||||||
|
);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if (num_experts_per_rank == 8) {
|
if (num_experts_per_rank == 8) {
|
||||||
permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
|
permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
|
||||||
@@ -538,7 +593,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
|
|||||||
|
|
||||||
auto permute_input = GetEmptyTensor(
|
auto permute_input = GetEmptyTensor(
|
||||||
{token_nums_this_rank, hidden_size},
|
{token_nums_this_rank, hidden_size},
|
||||||
moe_quant_type == "w4a8" ? paddle::DataType::INT8 : input_type,
|
moe_quant_type == "w4a8" ? paddle::DataType::INT8 : moe_quant_type == "w4afp8" ? paddle::DataType::FLOAT8_E4M3FN : input_type,
|
||||||
place);
|
place);
|
||||||
auto num_experts_per_rank_tensor = GetEmptyTensor(
|
auto num_experts_per_rank_tensor = GetEmptyTensor(
|
||||||
{num_experts_per_rank},
|
{num_experts_per_rank},
|
||||||
|
@@ -88,7 +88,7 @@ struct nv_type_traits<int8_t> {
|
|||||||
constexpr int kLogN = 7; \
|
constexpr int kLogN = 7; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
} else { \
|
} else { \
|
||||||
PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupport!", logN)); \
|
PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupported!", logN)); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DISPATCH_SP_VS(vec_size, VEC_SIZE, ...) \
|
#define DISPATCH_SP_VS(vec_size, VEC_SIZE, ...) \
|
||||||
@@ -108,7 +108,7 @@ struct nv_type_traits<int8_t> {
|
|||||||
constexpr int VEC_SIZE = 1; \
|
constexpr int VEC_SIZE = 1; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
} else { \
|
} else { \
|
||||||
PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupport!", vec_size)); \
|
PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupported!", vec_size)); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DISPATCH_logN(logN, kLogN, ...) \
|
#define DISPATCH_logN(logN, kLogN, ...) \
|
||||||
@@ -605,26 +605,6 @@ void moe_fast_hardamard_kernel(const T *x,
|
|||||||
exchange_smem_pre<kNChunks, kChunksPerSmemSize, VecSize, kWarpSize, kNWarps, false, vec_t>(x_vals, smem_exchange);
|
exchange_smem_pre<kNChunks, kChunksPerSmemSize, VecSize, kWarpSize, kNWarps, false, vec_t>(x_vals, smem_exchange);
|
||||||
}
|
}
|
||||||
if constexpr (kNChunks > 1) {
|
if constexpr (kNChunks > 1) {
|
||||||
// T x_vals_transposed[VecSize][kNChunks] = {init_value};
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int c = 0; c < kNChunks; ++c) {
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < VecSize; ++i) { x_vals_transposed[i][c] = x_vals[c][i]; }
|
|
||||||
// }
|
|
||||||
// if constexpr (kNChunks == 28) {
|
|
||||||
// hadamard_mult_thread_chunk_28<VecSize>(x_vals_transposed);
|
|
||||||
// } else if constexpr (kNChunks == 36) {
|
|
||||||
// hadamard_mult_thread_chunk_36<VecSize>(x_vals_transposed);
|
|
||||||
// } else {
|
|
||||||
// constexpr int kLogNChunks = cilog2(kNChunks);
|
|
||||||
// static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
|
|
||||||
// hadamard_mult_thread<kLogNChunks, VecSize>(x_vals_transposed);
|
|
||||||
// }
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int c = 0; c < kNChunks; ++c) {
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < VecSize; ++i) { x_vals[c][i] = x_vals_transposed[i][c]; }
|
|
||||||
// }
|
|
||||||
if constexpr (kNChunks == 28) {
|
if constexpr (kNChunks == 28) {
|
||||||
hadamard_mult_thread_28_transpose<T, VecSize>(x_vals);
|
hadamard_mult_thread_28_transpose<T, VecSize>(x_vals);
|
||||||
} else if constexpr (kNChunks == 36) {
|
} else if constexpr (kNChunks == 36) {
|
||||||
|
@@ -72,6 +72,287 @@ __host__ __device__ constexpr static U arrayConvert(T const& input)
|
|||||||
return u;
|
return u;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct uint8 {
|
||||||
|
uint4 u;
|
||||||
|
uint4 v;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int BYTES> struct BytesToType {};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct BytesToType<32> {
|
||||||
|
using Type = uint8;
|
||||||
|
static_assert(sizeof(Type) == 32);
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct BytesToType<16> {
|
||||||
|
using Type = uint4;
|
||||||
|
static_assert(sizeof(Type) == 16);
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct BytesToType<8> {
|
||||||
|
using Type = uint64_t;
|
||||||
|
static_assert(sizeof(Type) == 8);
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct BytesToType<4> {
|
||||||
|
using Type = uint32_t;
|
||||||
|
static_assert(sizeof(Type) == 4);
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct BytesToType<2> {
|
||||||
|
using Type = uint16_t;
|
||||||
|
static_assert(sizeof(Type) == 2);
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct BytesToType<1> {
|
||||||
|
using Type = uint8_t;
|
||||||
|
static_assert(sizeof(Type) == 1);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <template <typename> class ReductionOp, typename T, int block_size>
|
||||||
|
__inline__ __device__ T BlockAllReduce(T val) {
|
||||||
|
typedef cub::BlockReduce<T, block_size> BlockReduce;
|
||||||
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||||
|
__shared__ T result_broadcast;
|
||||||
|
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
result_broadcast = result;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
return result_broadcast;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct SumOp {
|
||||||
|
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InType, typename OutType>
|
||||||
|
__forceinline__ __device__ OutType QuantHelperFunc(const InType input,
|
||||||
|
const float scale,
|
||||||
|
const float max_bound,
|
||||||
|
const float min_bound) {
|
||||||
|
float quant_value = max_bound * scale * static_cast<float>(input);
|
||||||
|
return static_cast<OutType>(ClipFunc<float>(quant_value, min_bound, max_bound));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename OutT, int VecSize, int Kthread>
|
||||||
|
__global__ void masked_quantize_moe_input_kernel(const T* permuted_inputs,
|
||||||
|
const int64_t* expert_idx_per_token,
|
||||||
|
const float* quant_scales,
|
||||||
|
const float quant_max_bound,
|
||||||
|
const float quant_min_bound,
|
||||||
|
const int64_t token_num,
|
||||||
|
const int64_t dim,
|
||||||
|
float* permuted_input_row_sum,
|
||||||
|
const int64_t* recv_expert_count,
|
||||||
|
const int num_max_tokens_per_expert,
|
||||||
|
OutT* out) {
|
||||||
|
using LoadT = AlignedVector<T, VecSize>;
|
||||||
|
using LoadOutT = AlignedVector<OutT, VecSize>;
|
||||||
|
LoadT input_vec;
|
||||||
|
LoadOutT output_vec;
|
||||||
|
float scale_factor = -7.0f / 512.0f;
|
||||||
|
using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type;
|
||||||
|
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
|
||||||
|
const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert;
|
||||||
|
const auto expert_id = token_idx / num_max_tokens_per_expert;
|
||||||
|
if (token_idx_in_expert >= recv_expert_count[expert_id]) {
|
||||||
|
auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
|
||||||
|
auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x;
|
||||||
|
token_idx += num_iters_to_next_expert * gridDim.x;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
int64_t expert_idx = expert_idx_per_token[token_idx];
|
||||||
|
float quant_scale = quant_scales[expert_idx];
|
||||||
|
float thread_row_sum = 0.0f;
|
||||||
|
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
|
||||||
|
int64_t offset = token_idx * dim + idx * VecSize;
|
||||||
|
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < VecSize; i++) {
|
||||||
|
output_vec[i] = QuantHelperFunc<T, OutT>(input_vec[i], quant_scale, quant_max_bound, quant_min_bound);
|
||||||
|
thread_row_sum += static_cast<float>(output_vec[i]);
|
||||||
|
}
|
||||||
|
*(reinterpret_cast<vec_t*>(&out[offset])) = *(reinterpret_cast<const vec_t*>(&output_vec));
|
||||||
|
}
|
||||||
|
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
|
||||||
|
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename OutT, int VecSize, int Kthread>
|
||||||
|
__global__ void quantize_moe_input_kernel(const T* permuted_inputs,
|
||||||
|
const int64_t* expert_idx_per_token,
|
||||||
|
const float* quant_scales,
|
||||||
|
const float quant_max_bound,
|
||||||
|
const float quant_min_bound,
|
||||||
|
const int64_t token_num,
|
||||||
|
const int64_t dim,
|
||||||
|
float* permuted_input_row_sum,
|
||||||
|
const int64_t* recv_expert_count,
|
||||||
|
const int num_max_tokens_per_expert,
|
||||||
|
OutT* out) {
|
||||||
|
using LoadT = AlignedVector<T, VecSize>;
|
||||||
|
using LoadOutT = AlignedVector<OutT, VecSize>;
|
||||||
|
LoadT input_vec;
|
||||||
|
LoadOutT output_vec;
|
||||||
|
using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type;
|
||||||
|
float scale_factor = -7.0f / 512.0f;
|
||||||
|
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
|
||||||
|
int64_t expert_idx = expert_idx_per_token[token_idx];
|
||||||
|
float quant_scale = quant_scales[expert_idx];
|
||||||
|
float thread_row_sum = 0.0f;
|
||||||
|
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
|
||||||
|
int64_t offset = token_idx * dim + idx * VecSize;
|
||||||
|
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < VecSize; i++) {
|
||||||
|
output_vec[i] = QuantHelperFunc<T, OutT>(input_vec[i], quant_scale, quant_max_bound, quant_min_bound);
|
||||||
|
thread_row_sum += static_cast<float>(output_vec[i]);
|
||||||
|
}
|
||||||
|
*(reinterpret_cast<vec_t*>(&out[offset])) = *(reinterpret_cast<const vec_t*>(&output_vec));
|
||||||
|
}
|
||||||
|
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
|
||||||
|
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename OutT>
|
||||||
|
void quantize_moe_input(
|
||||||
|
const T* permuted_inputs,
|
||||||
|
const int64_t* expert_idx_per_token,
|
||||||
|
const float* quant_scales,
|
||||||
|
const float quant_max_bound,
|
||||||
|
const float quant_min_bound,
|
||||||
|
const int64_t token_num,
|
||||||
|
const int64_t dim,
|
||||||
|
float* permuted_input_row_sum,
|
||||||
|
const int64_t* recv_expert_count,
|
||||||
|
const int num_max_tokens_per_expert,
|
||||||
|
bool used_in_ep_low_latency,
|
||||||
|
OutT* out,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
constexpr int VecSize = 16 / sizeof(T);
|
||||||
|
constexpr int threads_per_block = 128;
|
||||||
|
const int dev_id = 0;
|
||||||
|
int sm_count;
|
||||||
|
int act_blocks_per_sm;
|
||||||
|
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||||
|
assert(dim % VecSize == 0);
|
||||||
|
auto kernel = used_in_ep_low_latency ? masked_quantize_moe_input_kernel<T, OutT, VecSize, threads_per_block> : quantize_moe_input_kernel<T, OutT, VecSize, threads_per_block>;
|
||||||
|
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||||
|
&act_blocks_per_sm, kernel, threads_per_block, 0);
|
||||||
|
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
|
||||||
|
dim3 grid;
|
||||||
|
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
|
||||||
|
kernel<<<grid, threads_per_block, 0, stream>>>(
|
||||||
|
permuted_inputs,
|
||||||
|
expert_idx_per_token,
|
||||||
|
quant_scales,
|
||||||
|
quant_max_bound,
|
||||||
|
quant_min_bound,
|
||||||
|
token_num,
|
||||||
|
dim,
|
||||||
|
permuted_input_row_sum,
|
||||||
|
recv_expert_count,
|
||||||
|
num_max_tokens_per_expert,
|
||||||
|
out);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int VecSize, int Kthread>
|
||||||
|
__global__ void masked_compute_row_sum_kernel(
|
||||||
|
const T* permuted_inputs,
|
||||||
|
const int64_t token_num,
|
||||||
|
const int64_t dim,
|
||||||
|
float* permuted_input_row_sum,
|
||||||
|
const int64_t* recv_expert_count,
|
||||||
|
const int num_max_tokens_per_expert) {
|
||||||
|
using LoadT = AlignedVector<T, VecSize>;
|
||||||
|
LoadT input_vec;
|
||||||
|
float scale_factor = -7.0f / 512.0f;
|
||||||
|
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
|
||||||
|
const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert;
|
||||||
|
const auto expert_id = token_idx / num_max_tokens_per_expert;
|
||||||
|
if (token_idx_in_expert >= recv_expert_count[expert_id]) {
|
||||||
|
auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
|
||||||
|
auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x;
|
||||||
|
token_idx += num_iters_to_next_expert * gridDim.x;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
float thread_row_sum = 0.0f;
|
||||||
|
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
|
||||||
|
int64_t offset = token_idx * dim + idx * VecSize;
|
||||||
|
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < VecSize; i++) {
|
||||||
|
thread_row_sum += static_cast<float>(input_vec[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
|
||||||
|
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int VecSize, int Kthread>
|
||||||
|
__global__ void compute_row_sum_kernel(
|
||||||
|
const T* permuted_inputs,
|
||||||
|
const int64_t token_num,
|
||||||
|
const int64_t dim,
|
||||||
|
float* permuted_input_row_sum,
|
||||||
|
const int64_t* recv_expert_count,
|
||||||
|
const int num_max_tokens_per_expert) {
|
||||||
|
using LoadT = AlignedVector<T, VecSize>;
|
||||||
|
LoadT input_vec;
|
||||||
|
float scale_factor = -7.0f / 512.0f;
|
||||||
|
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
|
||||||
|
float thread_row_sum = 0.0f;
|
||||||
|
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
|
||||||
|
int64_t offset = token_idx * dim + idx * VecSize;
|
||||||
|
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < VecSize; i++) {
|
||||||
|
thread_row_sum += static_cast<float>(input_vec[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
|
||||||
|
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void compute_row_sum(
|
||||||
|
const T* permuted_inputs,
|
||||||
|
const int64_t token_num,
|
||||||
|
const int64_t dim,
|
||||||
|
float* permuted_input_row_sum,
|
||||||
|
const int64_t* recv_expert_count,
|
||||||
|
const int num_max_tokens_per_expert,
|
||||||
|
bool used_in_ep_low_latency,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
constexpr int VecSize = 16 / sizeof(T);
|
||||||
|
constexpr int threads_per_block = 128;
|
||||||
|
const int dev_id = 0;
|
||||||
|
int sm_count;
|
||||||
|
int act_blocks_per_sm;
|
||||||
|
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||||
|
assert(dim % VecSize == 0);
|
||||||
|
auto kernel = used_in_ep_low_latency ? masked_compute_row_sum_kernel<T, VecSize, threads_per_block> : compute_row_sum_kernel<T, VecSize, threads_per_block>;
|
||||||
|
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||||
|
&act_blocks_per_sm, kernel, threads_per_block, 0);
|
||||||
|
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
|
||||||
|
dim3 grid;
|
||||||
|
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
|
||||||
|
kernel<<<grid, threads_per_block, 0, stream>>>(
|
||||||
|
permuted_inputs,
|
||||||
|
token_num,
|
||||||
|
dim,
|
||||||
|
permuted_input_row_sum,
|
||||||
|
recv_expert_count,
|
||||||
|
num_max_tokens_per_expert);
|
||||||
|
}
|
||||||
|
|
||||||
// ====================== Softmax things ===============================
|
// ====================== Softmax things ===============================
|
||||||
// We have our own implementation of softmax here so we can support transposing
|
// We have our own implementation of softmax here so we can support transposing
|
||||||
// the output in the softmax kernel when we extend this module to support
|
// the output in the softmax kernel when we extend this module to support
|
||||||
|
@@ -20,6 +20,7 @@
|
|||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "moe/fast_hardamard_kernel.h"
|
#include "moe/fast_hardamard_kernel.h"
|
||||||
#include "moe/fused_moe_helper.h"
|
#include "moe/fused_moe_helper.h"
|
||||||
|
#include "w4afp8_gemm/w4afp8_gemm.h"
|
||||||
|
|
||||||
template <paddle::DataType T>
|
template <paddle::DataType T>
|
||||||
void MoeFFNKernel(const paddle::Tensor& permute_input,
|
void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||||
@@ -33,7 +34,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
|||||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||||
const std::string& quant_method,
|
const std::string& quant_method,
|
||||||
paddle::Tensor ffn_out,
|
paddle::Tensor ffn_out,
|
||||||
bool used_in_ep_low_latency) {
|
bool used_in_ep_low_latency,
|
||||||
|
const int estimate_total_token_nums) {
|
||||||
using namespace phi;
|
using namespace phi;
|
||||||
typedef PDTraits<T> traits_;
|
typedef PDTraits<T> traits_;
|
||||||
typedef typename traits_::DataType DataType_;
|
typedef typename traits_::DataType DataType_;
|
||||||
@@ -60,19 +62,22 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
|||||||
constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024; // for nf4 stream-k
|
constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024; // for nf4 stream-k
|
||||||
Allocator* allocator = paddle::GetAllocator(place);
|
Allocator* allocator = paddle::GetAllocator(place);
|
||||||
Allocator::AllocationPtr workspace;
|
Allocator::AllocationPtr workspace;
|
||||||
if (quant_method == "weight_only_int4" || quant_method == "w4a8") {
|
if (quant_method == "weight_only_int4" || quant_method == "w4a8" || quant_method == "w4afp8") {
|
||||||
inter_dim = inter_dim * 2;
|
inter_dim = inter_dim * 2;
|
||||||
}
|
}
|
||||||
if (quant_method == "w4a8") {
|
if (quant_method == "w4a8" || quant_method == "w4afp8") {
|
||||||
workspace = allocator->Allocate(
|
workspace = allocator->Allocate(
|
||||||
SizeOf(paddle::DataType::INT8) * workspace_size);
|
SizeOf(paddle::DataType::INT8) * workspace_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t inter_size = inter_dim;
|
const int64_t inter_size = inter_dim;
|
||||||
|
|
||||||
|
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
|
||||||
|
typedef typename traits_fp8::DataType DataType_fp8;
|
||||||
|
typedef typename traits_fp8::data_t data_t_fp8;
|
||||||
|
|
||||||
int num_experts_ = num_experts;
|
int num_experts_ = num_experts;
|
||||||
int num_max_tokens_per_expert;
|
int num_max_tokens_per_expert = 256;
|
||||||
int expanded_active_expert_rows;
|
int expanded_active_expert_rows;
|
||||||
|
|
||||||
paddle::Tensor fc1_out_tensor;
|
paddle::Tensor fc1_out_tensor;
|
||||||
@@ -161,13 +166,49 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
|||||||
reinterpret_cast<NvType *>(fc1_out),
|
reinterpret_cast<NvType *>(fc1_out),
|
||||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||||
total_rows_in_ll_else_minus1,
|
total_rows_in_ll_else_minus1,
|
||||||
tune_total_rows,
|
used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
|
||||||
inter_size,
|
inter_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
reinterpret_cast<char*>(workspace->ptr()),
|
reinterpret_cast<char*>(workspace->ptr()),
|
||||||
workspace_size,
|
workspace_size,
|
||||||
num_experts,
|
num_experts,
|
||||||
stream);
|
stream);
|
||||||
|
} else if (quant_method == "w4afp8") {
|
||||||
|
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
|
||||||
|
typedef typename traits_fp8::DataType DataType_fp8;
|
||||||
|
typedef typename traits_fp8::data_t data_t_fp8;
|
||||||
|
|
||||||
|
Allocator::AllocationPtr ffn1_input_row_sum;
|
||||||
|
ffn1_input_row_sum = allocator->Allocate(
|
||||||
|
sizeof(float) * expanded_active_expert_rows);
|
||||||
|
|
||||||
|
compute_row_sum(
|
||||||
|
permute_input.data<data_t_fp8>(),
|
||||||
|
expanded_active_expert_rows,
|
||||||
|
hidden_size,
|
||||||
|
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
|
||||||
|
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||||
|
num_max_tokens_per_expert,
|
||||||
|
used_in_ep_low_latency,
|
||||||
|
stream);
|
||||||
|
|
||||||
|
|
||||||
|
float* row_scale = nullptr;
|
||||||
|
DisPatchW4AFp8GemmWrapper(
|
||||||
|
reinterpret_cast<const DataType_fp8 *>(permute_input.data<data_t_fp8>()),
|
||||||
|
reinterpret_cast<const DataType_fp8 *>(up_gate_proj_weight.data<int8_t>()),
|
||||||
|
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||||
|
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
|
||||||
|
row_scale,
|
||||||
|
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
|
||||||
|
->data<float>(),
|
||||||
|
reinterpret_cast<NvType *>(fc1_out),
|
||||||
|
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
|
||||||
|
num_max_tokens_per_expert,
|
||||||
|
num_experts,
|
||||||
|
inter_size,
|
||||||
|
hidden_size,
|
||||||
|
stream);
|
||||||
} else {
|
} else {
|
||||||
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
|
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
|
||||||
fp16_moe_gemm_runner.moe_gemm_bias_act(
|
fp16_moe_gemm_runner.moe_gemm_bias_act(
|
||||||
@@ -194,7 +235,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
|||||||
act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||||
}
|
}
|
||||||
auto act_out = act_out_tensor.data<data_t>();
|
auto act_out = act_out_tensor.data<data_t>();
|
||||||
|
|
||||||
if (quant_method == "weight_only_int8") {
|
if (quant_method == "weight_only_int8") {
|
||||||
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
|
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
|
||||||
int8_moe_gemm_runner.moe_gemm(
|
int8_moe_gemm_runner.moe_gemm(
|
||||||
@@ -267,13 +307,73 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
|||||||
reinterpret_cast<NvType *>(ffn_out_data),
|
reinterpret_cast<NvType *>(ffn_out_data),
|
||||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||||
total_rows_in_ll_else_minus1,
|
total_rows_in_ll_else_minus1,
|
||||||
tune_total_rows,
|
used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
inter_size / 2,
|
inter_size / 2,
|
||||||
reinterpret_cast<char*>(workspace->ptr()),
|
reinterpret_cast<char*>(workspace->ptr()),
|
||||||
workspace_size,
|
workspace_size,
|
||||||
num_experts,
|
num_experts,
|
||||||
stream);
|
stream);
|
||||||
|
} else if (quant_method == "w4afp8") {
|
||||||
|
data_t *ffn2_shift = nullptr;
|
||||||
|
data_t *ffn2_smooth = nullptr;
|
||||||
|
float* row_scale = nullptr;
|
||||||
|
Allocator::AllocationPtr fp8_act_out;
|
||||||
|
fp8_act_out = allocator->Allocate(
|
||||||
|
SizeOf(paddle::DataType::INT8) * act_out_tensor.numel());
|
||||||
|
Allocator::AllocationPtr ffn2_input_row_sum;
|
||||||
|
ffn2_input_row_sum = allocator->Allocate(
|
||||||
|
sizeof(float) * expanded_active_expert_rows);
|
||||||
|
|
||||||
|
// note(yuanxiaolan): optimize this
|
||||||
|
MoeFastHardamardWrapper<data_t, data_t>(
|
||||||
|
act_out_tensor.data<data_t>(),
|
||||||
|
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
|
||||||
|
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||||
|
ffn2_shift, // ffn2_shift->data<T>(),
|
||||||
|
ffn2_smooth, // ffn2_smooth->data<T>(),
|
||||||
|
nullptr,
|
||||||
|
1,
|
||||||
|
448.0f,
|
||||||
|
-448.0f,
|
||||||
|
expanded_active_expert_rows,
|
||||||
|
inter_size / 2,
|
||||||
|
num_max_tokens_per_expert,
|
||||||
|
used_in_ep_low_latency,
|
||||||
|
act_out_tensor.data<data_t>(),
|
||||||
|
stream
|
||||||
|
);
|
||||||
|
|
||||||
|
quantize_moe_input<data_t, data_t_fp8>(act_out_tensor.data<data_t>(),
|
||||||
|
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
|
||||||
|
down_proj_in_scale ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())->data<float>() : nullptr,
|
||||||
|
448.0f,
|
||||||
|
-448.0f,
|
||||||
|
expanded_active_expert_rows,
|
||||||
|
inter_size / 2,
|
||||||
|
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
|
||||||
|
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||||
|
num_max_tokens_per_expert,
|
||||||
|
used_in_ep_low_latency,
|
||||||
|
reinterpret_cast<data_t_fp8 *>(fp8_act_out->ptr()),
|
||||||
|
stream
|
||||||
|
);
|
||||||
|
|
||||||
|
DisPatchW4AFp8GemmWrapper(
|
||||||
|
reinterpret_cast<const DataType_fp8 *>(fp8_act_out->ptr()),
|
||||||
|
reinterpret_cast<const DataType_fp8 *>(down_proj_weight.data<int8_t>()),
|
||||||
|
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||||
|
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
|
||||||
|
row_scale,
|
||||||
|
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
|
||||||
|
->data<float>(),
|
||||||
|
reinterpret_cast<NvType*>(ffn_out_data),
|
||||||
|
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
|
||||||
|
num_max_tokens_per_expert,
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
inter_size / 2,
|
||||||
|
stream);
|
||||||
} else {
|
} else {
|
||||||
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
|
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
|
||||||
fp16_moe_gemm_runner.moe_gemm(
|
fp16_moe_gemm_runner.moe_gemm(
|
||||||
@@ -302,10 +402,12 @@ paddle::Tensor MoeExpertFFNFunc(
|
|||||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||||
const std::string& quant_method, const bool used_in_ep_low_latency) {
|
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||||
|
const int estimate_total_token_nums) {
|
||||||
|
|
||||||
cudaCheckError();
|
const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() :
|
||||||
const auto t_type = quant_method == "w4a8" ? up_gate_proj_scale.get().dtype() : permute_input.dtype();
|
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
|
||||||
|
permute_input.dtype();
|
||||||
auto ffn_out = paddle::empty_like(permute_input, t_type);
|
auto ffn_out = paddle::empty_like(permute_input, t_type);
|
||||||
|
|
||||||
switch (t_type) {
|
switch (t_type) {
|
||||||
@@ -320,7 +422,9 @@ paddle::Tensor MoeExpertFFNFunc(
|
|||||||
down_proj_in_scale,
|
down_proj_in_scale,
|
||||||
expert_idx_per_token,
|
expert_idx_per_token,
|
||||||
quant_method,
|
quant_method,
|
||||||
ffn_out, used_in_ep_low_latency);
|
ffn_out,
|
||||||
|
used_in_ep_low_latency,
|
||||||
|
estimate_total_token_nums);
|
||||||
break;
|
break;
|
||||||
case paddle::DataType::FLOAT16:
|
case paddle::DataType::FLOAT16:
|
||||||
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
|
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
|
||||||
@@ -333,7 +437,9 @@ paddle::Tensor MoeExpertFFNFunc(
|
|||||||
down_proj_in_scale,
|
down_proj_in_scale,
|
||||||
expert_idx_per_token,
|
expert_idx_per_token,
|
||||||
quant_method,
|
quant_method,
|
||||||
ffn_out, used_in_ep_low_latency);
|
ffn_out,
|
||||||
|
used_in_ep_low_latency,
|
||||||
|
estimate_total_token_nums);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
PD_THROW("Unsupported data type for MoeExpertFFN");
|
PD_THROW("Unsupported data type for MoeExpertFFN");
|
||||||
@@ -351,7 +457,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
|||||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||||
const std::string& quant_method, const bool used_in_ep_low_latency) {
|
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||||
|
const int estimate_total_token_nums) {
|
||||||
return {MoeExpertFFNFunc(permute_input,
|
return {MoeExpertFFNFunc(permute_input,
|
||||||
tokens_expert_prefix_sum,
|
tokens_expert_prefix_sum,
|
||||||
up_gate_proj_weight,
|
up_gate_proj_weight,
|
||||||
@@ -361,7 +468,9 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
|||||||
down_proj_scale,
|
down_proj_scale,
|
||||||
down_proj_in_scale,
|
down_proj_in_scale,
|
||||||
expert_idx_per_token,
|
expert_idx_per_token,
|
||||||
quant_method, used_in_ep_low_latency)};
|
quant_method,
|
||||||
|
used_in_ep_low_latency,
|
||||||
|
estimate_total_token_nums)};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||||
@@ -375,7 +484,8 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
|||||||
const paddle::optional<std::vector<int64_t>>& down_proj_in_scale_shape,
|
const paddle::optional<std::vector<int64_t>>& down_proj_in_scale_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
|
const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
|
||||||
const std::string& quant_method,
|
const std::string& quant_method,
|
||||||
const bool used_in_ep_low_latency) {
|
const bool used_in_ep_low_latency,
|
||||||
|
const int estimate_total_token_nums) {
|
||||||
return {permute_input_shape};
|
return {permute_input_shape};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,8 +498,9 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
|||||||
const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
|
const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
|
||||||
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
|
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
|
||||||
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
|
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
|
||||||
const std::string &quant_method, const bool used_in_ep_low_latency) {
|
const std::string &quant_method, const bool used_in_ep_low_latency,
|
||||||
if (quant_method == "w4a8") {
|
const int estimate_total_token_nums) {
|
||||||
|
if (quant_method == "w4a8" || quant_method == "w4afp8") {
|
||||||
return {up_gate_proj_scale_dtype.get()};
|
return {up_gate_proj_scale_dtype.get()};
|
||||||
} else {
|
} else {
|
||||||
return {permute_input_dtype};
|
return {permute_input_dtype};
|
||||||
@@ -460,7 +571,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
|
|||||||
paddle::Optional("down_proj_in_scale"),
|
paddle::Optional("down_proj_in_scale"),
|
||||||
paddle::Optional("expert_idx_per_token")})
|
paddle::Optional("expert_idx_per_token")})
|
||||||
.Outputs({"output_tensor"})
|
.Outputs({"output_tensor"})
|
||||||
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool"})
|
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int"})
|
||||||
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
|
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
|
||||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
|
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
|
||||||
|
@@ -103,7 +103,7 @@ struct CollectiveMainloopFwd {
|
|||||||
LayoutT layout_C;
|
LayoutT layout_C;
|
||||||
const float *weight_scale;
|
const float *weight_scale;
|
||||||
const float *input_row_sum;
|
const float *input_row_sum;
|
||||||
const int * tokens;
|
const int64_t * tokens;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Params {
|
struct Params {
|
||||||
@@ -114,7 +114,7 @@ struct CollectiveMainloopFwd {
|
|||||||
ElementOutput * ptr_C;
|
ElementOutput * ptr_C;
|
||||||
const float *weight_scale;
|
const float *weight_scale;
|
||||||
const float *input_row_sum;
|
const float *input_row_sum;
|
||||||
const int * tokens;
|
const int64_t * tokens;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@@ -153,8 +153,8 @@ struct CollectiveMainloopFwd {
|
|||||||
TiledMma tiled_mma,
|
TiledMma tiled_mma,
|
||||||
const float *input_row_sum,
|
const float *input_row_sum,
|
||||||
const float *weight_scale,
|
const float *weight_scale,
|
||||||
const int tokens,
|
const int64_t tokens,
|
||||||
const int pre_fix_tokens,
|
const int64_t pre_fix_tokens,
|
||||||
const int bidm,
|
const int bidm,
|
||||||
const int bidn,
|
const int bidn,
|
||||||
const int bidb,
|
const int bidb,
|
||||||
|
@@ -19,6 +19,7 @@
|
|||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
#include "w4afp8_gemm_template.h"
|
#include "w4afp8_gemm_template.h"
|
||||||
|
#include "w4afp8_gemm.h"
|
||||||
|
|
||||||
|
|
||||||
void weight_convert(const uint8_t *weight, uint8_t *weight_new, int batch, int M, int K) {
|
void weight_convert(const uint8_t *weight, uint8_t *weight_new, int batch, int M, int K) {
|
||||||
@@ -39,7 +40,22 @@ void weight_convert(const uint8_t *weight, uint8_t *weight_new, int batch, int M
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T> class NVTraits;
|
||||||
|
|
||||||
|
template <> class NVTraits<__nv_fp8_e4m3> {
|
||||||
|
public:
|
||||||
|
typedef cutlass::float_e4m3_t data_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> class NVTraits<__nv_bfloat16>{
|
||||||
|
public:
|
||||||
|
typedef cutlass::bfloat16_t data_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> class NVTraits<half>{
|
||||||
|
public:
|
||||||
|
typedef cutlass::half_t data_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -48,15 +64,15 @@ template <typename OutputType>
|
|||||||
void DisPatchW4AFp8Gemm(
|
void DisPatchW4AFp8Gemm(
|
||||||
const cutlass::float_e4m3_t* input,
|
const cutlass::float_e4m3_t* input,
|
||||||
const cutlass::float_e4m3_t* weight,
|
const cutlass::float_e4m3_t* weight,
|
||||||
const int * tokens,
|
const int64_t * tokens,
|
||||||
const float * input_row_sum,
|
const float * input_row_sum,
|
||||||
const float * weight_scale,
|
const float * weight_scale,
|
||||||
OutputType * out,
|
OutputType * out,
|
||||||
const int token_padding_size,
|
const int64_t token_padding_size,
|
||||||
const int max_tokens,
|
const int64_t max_tokens,
|
||||||
const int batch_size,
|
const int batch_size,
|
||||||
const int M,
|
const int64_t M,
|
||||||
const int K,
|
const int64_t K,
|
||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
|
|
||||||
int kBlockN = (max_tokens + 15) / 16 * 16;
|
int kBlockN = (max_tokens + 15) / 16 * 16;
|
||||||
@@ -87,9 +103,10 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
|
|||||||
const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group
|
const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group
|
||||||
const paddle::Tensor& input_row_sum,
|
const paddle::Tensor& input_row_sum,
|
||||||
const paddle::Tensor& weight_scale,
|
const paddle::Tensor& weight_scale,
|
||||||
const int token_padding_size,
|
const int64_t token_padding_size,
|
||||||
const int max_tokens,
|
const int64_t max_tokens,
|
||||||
const bool is_bflot16) {
|
const bool is_bfloat16) {
|
||||||
|
|
||||||
|
|
||||||
const int batch_size = weight.dims()[0];
|
const int batch_size = weight.dims()[0];
|
||||||
const int M = weight.dims()[1];
|
const int M = weight.dims()[1];
|
||||||
@@ -101,13 +118,13 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
|
|||||||
|
|
||||||
if (token_padding_size == 0) {
|
if (token_padding_size == 0) {
|
||||||
const int all_tokens = input.dims()[0];
|
const int all_tokens = input.dims()[0];
|
||||||
if (is_bflot16) {
|
if (is_bfloat16) {
|
||||||
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::BFLOAT16, input.place());
|
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::BFLOAT16, input.place());
|
||||||
phi::dtype::bfloat16 *out_data = out.data<phi::dtype::bfloat16>();
|
phi::dtype::bfloat16 *out_data = out.data<phi::dtype::bfloat16>();
|
||||||
DisPatchW4AFp8Gemm(
|
DisPatchW4AFp8Gemm(
|
||||||
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
||||||
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
||||||
tokens.data<int>(),
|
tokens.data<int64_t>(),
|
||||||
input_row_sum.data<float>(),
|
input_row_sum.data<float>(),
|
||||||
weight_scale.data<float>(),
|
weight_scale.data<float>(),
|
||||||
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
|
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
|
||||||
@@ -122,13 +139,13 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
|
|||||||
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (is_bflot16) {
|
if (is_bfloat16) {
|
||||||
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place());
|
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place());
|
||||||
phi::dtype::bfloat16 * out_data = out.data<phi::dtype::bfloat16>();
|
phi::dtype::bfloat16 * out_data = out.data<phi::dtype::bfloat16>();
|
||||||
DisPatchW4AFp8Gemm(
|
DisPatchW4AFp8Gemm(
|
||||||
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
||||||
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
||||||
tokens.data<int>(),
|
tokens.data<int64_t>(),
|
||||||
input_row_sum.data<float>(),
|
input_row_sum.data<float>(),
|
||||||
weight_scale.data<float>(),
|
weight_scale.data<float>(),
|
||||||
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
|
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
|
||||||
@@ -145,6 +162,38 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename InputType, typename OutputType>
|
||||||
|
void DisPatchW4AFp8GemmWrapper(
|
||||||
|
const InputType* input,
|
||||||
|
const InputType* weight,
|
||||||
|
const int64_t* total_rows_before_expert,
|
||||||
|
const float* input_row_sum,
|
||||||
|
const float* row_scale,
|
||||||
|
const float* weight_scale,
|
||||||
|
OutputType * out,
|
||||||
|
const int64_t token_padding_size,
|
||||||
|
const int64_t max_tokens,
|
||||||
|
const int num_experts,
|
||||||
|
const int64_t M,
|
||||||
|
const int64_t K,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
using InType = typename NVTraits<InputType>::data_t;
|
||||||
|
using OutType = typename NVTraits<OutputType>::data_t;
|
||||||
|
DisPatchW4AFp8Gemm(
|
||||||
|
reinterpret_cast<const InType*>(input),
|
||||||
|
reinterpret_cast<const InType*>(weight),
|
||||||
|
total_rows_before_expert,
|
||||||
|
input_row_sum,
|
||||||
|
weight_scale,
|
||||||
|
reinterpret_cast<OutType*>(out),
|
||||||
|
token_padding_size,
|
||||||
|
max_tokens,
|
||||||
|
num_experts,
|
||||||
|
M,
|
||||||
|
K,
|
||||||
|
stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
std::vector<paddle::Tensor> W4AFp8GemmWeightConvert(const paddle::Tensor& weight) {
|
std::vector<paddle::Tensor> W4AFp8GemmWeightConvert(const paddle::Tensor& weight) {
|
||||||
const int batch_size = weight.dims()[0];
|
const int batch_size = weight.dims()[0];
|
||||||
@@ -155,6 +204,63 @@ std::vector<paddle::Tensor> W4AFp8GemmWeightConvert(const paddle::Tensor& weight
|
|||||||
return {weight_new};
|
return {weight_new};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, int kPackSize>
|
||||||
|
__global__ void permute_scale_kernel(
|
||||||
|
T* input_data,
|
||||||
|
const int numel) {
|
||||||
|
using LoadT = AlignedVector<T, kPackSize>;
|
||||||
|
LoadT input_vec;
|
||||||
|
LoadT dst_vec;
|
||||||
|
const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize;
|
||||||
|
if (load_idx >= numel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Load<T, kPackSize>(&input_data[load_idx], &input_vec);
|
||||||
|
|
||||||
|
for (int i = 0; i < kPackSize; i+=2) {
|
||||||
|
dst_vec[i] = input_vec[i / 2];
|
||||||
|
dst_vec[i + 1] = input_vec[i / 2 + 8];
|
||||||
|
}
|
||||||
|
|
||||||
|
Store<T, kPackSize>(dst_vec, &input_data[load_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
|
||||||
|
const int row = scale.dims()[0];
|
||||||
|
const int col = scale.dims()[1];
|
||||||
|
if (col % 16 != 0) {
|
||||||
|
PD_THROW("Only supported when col is divisible by 16.");
|
||||||
|
}
|
||||||
|
const int numel = row * col;
|
||||||
|
const int threads = 128;
|
||||||
|
const int kPackSize = 16;
|
||||||
|
const int grid_size = (numel / kPackSize + threads - 1) / threads;
|
||||||
|
|
||||||
|
if (scale.dtype() == paddle::DataType::BFLOAT16) {
|
||||||
|
permute_scale_kernel<phi::dtype::bfloat16, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
|
||||||
|
const_cast<phi::dtype::bfloat16*>(scale.data<phi::dtype::bfloat16>()),
|
||||||
|
numel
|
||||||
|
);
|
||||||
|
} else if (scale.dtype() == paddle::DataType::FLOAT16) {
|
||||||
|
permute_scale_kernel<phi::dtype::float16, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
|
||||||
|
const_cast<phi::dtype::float16*>(scale.data<phi::dtype::float16>()),
|
||||||
|
numel
|
||||||
|
);
|
||||||
|
} else if (scale.dtype() == paddle::DataType::FLOAT32) {
|
||||||
|
permute_scale_kernel<float, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
|
||||||
|
const_cast<float*>(scale.data<float>()),
|
||||||
|
numel
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
PD_BUILD_STATIC_OP(w4afp8_gemm_scale_permute)
|
||||||
|
.Inputs({"weight_scale"})
|
||||||
|
.Outputs({"permute_scale"})
|
||||||
|
.SetInplaceMap({{"weight_scale", "permute_scale"}})
|
||||||
|
.SetKernelFn(PD_KERNEL(W4AFp8GemmScalePermute));
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(w4afp8_gemm)
|
PD_BUILD_STATIC_OP(w4afp8_gemm)
|
||||||
.Inputs({"input",
|
.Inputs({"input",
|
||||||
"weight",
|
"weight",
|
||||||
@@ -162,12 +268,44 @@ PD_BUILD_STATIC_OP(w4afp8_gemm)
|
|||||||
"input_row_sum",
|
"input_row_sum",
|
||||||
"weight_scale"})
|
"weight_scale"})
|
||||||
.Outputs({"out"})
|
.Outputs({"out"})
|
||||||
.Attrs({"token_padding_size: int",
|
.Attrs({"token_padding_size: int64_t",
|
||||||
"max_tokens: int",
|
"max_tokens: int64_t",
|
||||||
"is_bflot16: bool"})
|
"is_bfloat16: bool"})
|
||||||
.SetKernelFn(PD_KERNEL(W4AFp8Gemm));
|
.SetKernelFn(PD_KERNEL(W4AFp8Gemm));
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(w4afp8_gemm_weight_convert)
|
PD_BUILD_STATIC_OP(w4afp8_gemm_weight_convert)
|
||||||
.Inputs({"weight"})
|
.Inputs({"weight"})
|
||||||
.Outputs({"converted_weight"})
|
.Outputs({"converted_weight"})
|
||||||
.SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert));
|
.SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert));
|
||||||
|
|
||||||
|
template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>(
|
||||||
|
const __nv_fp8_e4m3* input,
|
||||||
|
const __nv_fp8_e4m3* weight,
|
||||||
|
const int64_t * tokens,
|
||||||
|
const float * input_row_sum,
|
||||||
|
const float * row_scale,
|
||||||
|
const float * weight_scale,
|
||||||
|
__nv_bfloat16 * out,
|
||||||
|
const int64_t token_padding_size,
|
||||||
|
const int64_t max_tokens,
|
||||||
|
const int num_experts,
|
||||||
|
const int64_t M,
|
||||||
|
const int64_t K,
|
||||||
|
cudaStream_t stream
|
||||||
|
);
|
||||||
|
|
||||||
|
template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>(
|
||||||
|
const __nv_fp8_e4m3* input,
|
||||||
|
const __nv_fp8_e4m3* weight,
|
||||||
|
const int64_t * tokens,
|
||||||
|
const float * input_row_sum,
|
||||||
|
const float * row_scale,
|
||||||
|
const float * weight_scale,
|
||||||
|
half * out,
|
||||||
|
const int64_t token_padding_size,
|
||||||
|
const int64_t max_tokens,
|
||||||
|
const int num_experts,
|
||||||
|
const int64_t M,
|
||||||
|
const int64_t K,
|
||||||
|
cudaStream_t stream
|
||||||
|
);
|
||||||
|
47
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h
Normal file
47
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "helper.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> W4AFp8Gemm(
|
||||||
|
const paddle::Tensor& input,
|
||||||
|
const paddle::Tensor& weight,
|
||||||
|
const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group
|
||||||
|
const paddle::Tensor& input_row_sum,
|
||||||
|
const paddle::Tensor& weight_scale,
|
||||||
|
const int64_t token_padding_size,
|
||||||
|
const int64_t max_tokens,
|
||||||
|
const bool is_bfloat16);
|
||||||
|
|
||||||
|
template <typename InputType, typename OutputType>
|
||||||
|
void DisPatchW4AFp8GemmWrapper(
|
||||||
|
const InputType* input,
|
||||||
|
const InputType* weight,
|
||||||
|
const int64_t * tokens,
|
||||||
|
const float * input_row_sum,
|
||||||
|
const float * row_scale,
|
||||||
|
const float * weight_scale,
|
||||||
|
OutputType * out,
|
||||||
|
const int64_t token_padding_size,
|
||||||
|
const int64_t max_tokens,
|
||||||
|
const int num_experts,
|
||||||
|
const int64_t M,
|
||||||
|
const int64_t K,
|
||||||
|
cudaStream_t stream);
|
@@ -27,7 +27,7 @@
|
|||||||
#include "mainloop_fwd.h"
|
#include "mainloop_fwd.h"
|
||||||
|
|
||||||
template <typename Ktraits>
|
template <typename Ktraits>
|
||||||
void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) w4afp8_geem_kernel(
|
void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) w4afp8_gemm_kernel(
|
||||||
CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {
|
CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {
|
||||||
|
|
||||||
using Element = typename Ktraits::Element;
|
using Element = typename Ktraits::Element;
|
||||||
@@ -87,9 +87,9 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
const int pre_fix_tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb] : 0;
|
const int pre_fix_tokens = TokenPackSize == 0 ? (bidb == 0 ? 0 : mainloop_params.tokens[bidb - 1]) : 0;
|
||||||
|
|
||||||
const int tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb + 1] - pre_fix_tokens : mainloop_params.tokens[bidb];
|
const int tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb] - pre_fix_tokens : mainloop_params.tokens[bidb];
|
||||||
|
|
||||||
|
|
||||||
if (bidn * kBlockN >= tokens) {
|
if (bidn * kBlockN >= tokens) {
|
||||||
@@ -207,7 +207,7 @@ auto get_gmem_layout(const int Rows, const int Cols) {
|
|||||||
|
|
||||||
template <typename InputType, typename OutputType, typename Kernel_traits, int M, int K, int Batch, int TokenPackSize>
|
template <typename InputType, typename OutputType, typename Kernel_traits, int M, int K, int Batch, int TokenPackSize>
|
||||||
void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale,
|
void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale,
|
||||||
const float *input_row_sum, const int * tokens, const int max_tokens, cudaStream_t stream) {
|
const float *input_row_sum, const int64_t * tokens, const int64_t max_tokens, cudaStream_t stream) {
|
||||||
|
|
||||||
using ElementOutput = typename Kernel_traits::ElementOutput;
|
using ElementOutput = typename Kernel_traits::ElementOutput;
|
||||||
using Element = typename Kernel_traits::Element;
|
using Element = typename Kernel_traits::Element;
|
||||||
@@ -231,7 +231,7 @@ void run_gemm(const InputType * A, const InputType * B, OutputType * C, const fl
|
|||||||
});
|
});
|
||||||
|
|
||||||
void *kernel;
|
void *kernel;
|
||||||
kernel = (void *)w4afp8_geem_kernel<Kernel_traits>;
|
kernel = (void *)w4afp8_gemm_kernel<Kernel_traits>;
|
||||||
|
|
||||||
int smem_size = sizeof(typename Kernel_traits::SharedStorage) + sizeof(float) * Kernel_traits::kBlockN;
|
int smem_size = sizeof(typename Kernel_traits::SharedStorage) + sizeof(float) * Kernel_traits::kBlockN;
|
||||||
|
|
||||||
|
@@ -36,8 +36,8 @@ void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
|
|||||||
{cutlass_type} * out,
|
{cutlass_type} * out,
|
||||||
const float *weight_scale,
|
const float *weight_scale,
|
||||||
const float *input_row_sum,
|
const float *input_row_sum,
|
||||||
const int *tokens,
|
const int64_t *tokens,
|
||||||
const int max_tokens,
|
const int64_t max_tokens,
|
||||||
cudaStream_t stream);
|
cudaStream_t stream);
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -54,8 +54,8 @@ void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
|
|||||||
{cutlass_type} * out,
|
{cutlass_type} * out,
|
||||||
const float *weight_scale,
|
const float *weight_scale,
|
||||||
const float *input_row_sum,
|
const float *input_row_sum,
|
||||||
const int *tokens,
|
const int64_t *tokens,
|
||||||
const int max_tokens,
|
const int64_t max_tokens,
|
||||||
cudaStream_t stream) {{
|
cudaStream_t stream) {{
|
||||||
|
|
||||||
constexpr static int M = {M};
|
constexpr static int M = {M};
|
||||||
|
@@ -12,13 +12,18 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .fused_moe_cutlass_backend import CutlassW4A8MoEMethod, CutlassWeightOnlyMoEMethod
|
from .fused_moe_cutlass_backend import (
|
||||||
|
CutlassW4A8MoEMethod,
|
||||||
|
CutlassW4AFP8MoEMethod,
|
||||||
|
CutlassWeightOnlyMoEMethod,
|
||||||
|
)
|
||||||
from .fused_moe_triton_backend import TritonWeightOnlyMoEMethod
|
from .fused_moe_triton_backend import TritonWeightOnlyMoEMethod
|
||||||
from .moe import FusedMoE
|
from .moe import FusedMoE
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
CutlassWeightOnlyMoEMethod,
|
CutlassWeightOnlyMoEMethod,
|
||||||
CutlassW4A8MoEMethod,
|
CutlassW4A8MoEMethod,
|
||||||
|
CutlassW4AFP8MoEMethod,
|
||||||
FusedMoE,
|
FusedMoE,
|
||||||
TritonWeightOnlyMoEMethod,
|
TritonWeightOnlyMoEMethod,
|
||||||
]
|
]
|
||||||
|
@@ -389,7 +389,7 @@ class EPPrefillRunner(EPRunner):
|
|||||||
):
|
):
|
||||||
(
|
(
|
||||||
num_tokens_per_rank,
|
num_tokens_per_rank,
|
||||||
_,
|
num_tokens_per_rdma_rank,
|
||||||
num_tokens_per_expert,
|
num_tokens_per_expert,
|
||||||
is_token_in_rank,
|
is_token_in_rank,
|
||||||
_,
|
_,
|
||||||
@@ -399,6 +399,7 @@ class EPPrefillRunner(EPRunner):
|
|||||||
dispatch_args = {
|
dispatch_args = {
|
||||||
"x": (x, x_scale_tensor) if x_scale_tensor is not None else x,
|
"x": (x, x_scale_tensor) if x_scale_tensor is not None else x,
|
||||||
"num_tokens_per_rank": num_tokens_per_rank,
|
"num_tokens_per_rank": num_tokens_per_rank,
|
||||||
|
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
|
||||||
"is_token_in_rank": is_token_in_rank,
|
"is_token_in_rank": is_token_in_rank,
|
||||||
"num_tokens_per_expert": num_tokens_per_expert,
|
"num_tokens_per_expert": num_tokens_per_expert,
|
||||||
"config": self.ep_engine.ep_config,
|
"config": self.ep_engine.ep_config,
|
||||||
|
@@ -31,6 +31,7 @@ if current_platform.is_cuda():
|
|||||||
moe_expert_dispatch,
|
moe_expert_dispatch,
|
||||||
moe_expert_reduce,
|
moe_expert_reduce,
|
||||||
noaux_tc,
|
noaux_tc,
|
||||||
|
w4afp8_gemm_scale_permute,
|
||||||
)
|
)
|
||||||
elif current_platform.is_iluvatar():
|
elif current_platform.is_iluvatar():
|
||||||
from fastdeploy.model_executor.ops.iluvatar import (
|
from fastdeploy.model_executor.ops.iluvatar import (
|
||||||
@@ -87,6 +88,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
token_nums_per_expert: paddle.Tensor,
|
token_nums_per_expert: paddle.Tensor,
|
||||||
expert_idx_per_token: paddle.Tensor,
|
expert_idx_per_token: paddle.Tensor,
|
||||||
used_in_ep_low_latency: bool = False,
|
used_in_ep_low_latency: bool = False,
|
||||||
|
estimate_total_token_nums: int = -1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Paddle Cutlass compute Fused MoE.
|
Paddle Cutlass compute Fused MoE.
|
||||||
@@ -104,6 +106,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
expert_idx_per_token,
|
expert_idx_per_token,
|
||||||
self.moe_quant_type,
|
self.moe_quant_type,
|
||||||
used_in_ep_low_latency,
|
used_in_ep_low_latency,
|
||||||
|
estimate_total_token_nums,
|
||||||
)
|
)
|
||||||
return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
|
return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
|
||||||
permute_input,
|
permute_input,
|
||||||
@@ -117,6 +120,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
expert_idx_per_token,
|
expert_idx_per_token,
|
||||||
self.moe_quant_type,
|
self.moe_quant_type,
|
||||||
used_in_ep_low_latency,
|
used_in_ep_low_latency,
|
||||||
|
estimate_total_token_nums,
|
||||||
)
|
)
|
||||||
|
|
||||||
def apply_ep_prefill(
|
def apply_ep_prefill(
|
||||||
@@ -157,13 +161,13 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
recv_x,
|
recv_x,
|
||||||
recv_topk_idx,
|
recv_topk_idx,
|
||||||
recv_topk_weights,
|
recv_topk_weights,
|
||||||
(self.up_gate_proj_in_scale if hasattr(self, "up_gate_proj_in_scale") else None),
|
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
|
||||||
recv_num_tokens_per_expert_list,
|
recv_num_tokens_per_expert_list,
|
||||||
token_all_num,
|
token_all_num,
|
||||||
self.moe_quant_type,
|
self.moe_quant_type,
|
||||||
)
|
)
|
||||||
if self.moe_quant_type != "w4a8":
|
if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
|
||||||
# only w4a8 need expert_idx_per_token
|
# only w4a8 and w4afp8 need expert_idx_per_token
|
||||||
# Other need not this tensor, so we make it None.
|
# Other need not this tensor, so we make it None.
|
||||||
expert_idx_per_token = None
|
expert_idx_per_token = None
|
||||||
else:
|
else:
|
||||||
@@ -202,18 +206,19 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
Apply the EP decoder method.
|
Apply the EP decoder method.
|
||||||
"""
|
"""
|
||||||
gate_out = gate(x.cast("float32"))
|
gate_out = gate(x.cast("float32"))
|
||||||
|
estimate_total_token_nums = gate_out.shape[0] * layer.top_k
|
||||||
# 1. Select topk experts and weights
|
# 1. Select topk experts and weights
|
||||||
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
|
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
|
||||||
expertwise_scale = None
|
expertwise_scale = None
|
||||||
if hasattr(layer, "up_gate_proj_in_scale_all_experts"): # only use in w4a8
|
if hasattr(layer, "up_gate_proj_in_scale_all_experts"): # only use in w4a8
|
||||||
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
|
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
|
||||||
|
use_fp8 = self.moe_quant_type == "w4afp8"
|
||||||
# 2. EP Dispatch
|
# 2. EP Dispatch
|
||||||
permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
|
permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
|
||||||
x, topk_idx, topk_weights, expertwise_scale=expertwise_scale
|
x, topk_idx, topk_weights, expertwise_scale=expertwise_scale, use_fp8=use_fp8
|
||||||
)
|
)
|
||||||
# 3. Compute ffn
|
# 3. Compute ffn
|
||||||
if self.moe_quant_type == "w4a8":
|
if self.moe_quant_type == "w4a8" or self.moe_quant_type == "w4afp8":
|
||||||
num_local_experts, max_num, _ = permute_input.shape
|
num_local_experts, max_num, _ = permute_input.shape
|
||||||
expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile([1, max_num])
|
expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile([1, max_num])
|
||||||
elif self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]:
|
elif self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]:
|
||||||
@@ -227,6 +232,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
token_nums_per_expert.cast("int64"),
|
token_nums_per_expert.cast("int64"),
|
||||||
expert_idx_per_token,
|
expert_idx_per_token,
|
||||||
True,
|
True,
|
||||||
|
estimate_total_token_nums,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. EP combine
|
# 4. EP combine
|
||||||
@@ -290,7 +296,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
topk_only_mode=False,
|
topk_only_mode=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.moe_quant_type != "w4a8":
|
if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
|
||||||
# only w4a8 need expert_idx_per_token
|
# only w4a8 need expert_idx_per_token
|
||||||
# Other need not this tensor, so we make it None.
|
# Other need not this tensor, so we make it None.
|
||||||
expert_idx_per_token = None
|
expert_idx_per_token = None
|
||||||
@@ -373,9 +379,9 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
|||||||
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
|
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
|
||||||
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
|
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
|
||||||
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
|
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
|
||||||
up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0)
|
up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).unsqueeze()
|
||||||
up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0)
|
up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).unsqueeze()
|
||||||
down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0)
|
down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).unsqueeze()
|
||||||
|
|
||||||
name_tensor_map = {
|
name_tensor_map = {
|
||||||
"up_gate_proj_weight": up_gate_proj_weight,
|
"up_gate_proj_weight": up_gate_proj_weight,
|
||||||
@@ -448,7 +454,6 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
|||||||
Args:
|
Args:
|
||||||
layer (nn.Layer): The layer to add parameters to.
|
layer (nn.Layer): The layer to add parameters to.
|
||||||
weight_key_map (dict): The weight key map.
|
weight_key_map (dict): The weight key map.
|
||||||
state_dict (dict): The state dict.
|
|
||||||
"""
|
"""
|
||||||
self.default_dtype = layer._helper.get_default_dtype()
|
self.default_dtype = layer._helper.get_default_dtype()
|
||||||
if layer.ep_size > 1:
|
if layer.ep_size > 1:
|
||||||
@@ -572,6 +577,263 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||||
|
"""
|
||||||
|
w4a8 MoE Method
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, quant_config):
|
||||||
|
super().__init__(quant_config)
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.moe_quant_type = "w4afp8"
|
||||||
|
self.pack_num = 2
|
||||||
|
|
||||||
|
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
|
||||||
|
"""
|
||||||
|
Paddle cutlass process prequanted weights.
|
||||||
|
"""
|
||||||
|
up_gate_proj_expert_weight_key = layer.weight_key_map.get("up_gate_proj_expert_weight_key", None)
|
||||||
|
down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None)
|
||||||
|
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
||||||
|
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
||||||
|
up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
|
||||||
|
down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
|
||||||
|
|
||||||
|
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
|
||||||
|
layer.load_experts_weight(
|
||||||
|
state_dict,
|
||||||
|
up_gate_proj_expert_weight_key,
|
||||||
|
down_proj_expert_weight_key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
up_gate_proj_weight_scale = []
|
||||||
|
down_proj_weight_scale = []
|
||||||
|
up_gate_proj_in_scale_all_experts = []
|
||||||
|
up_gate_proj_in_scale = []
|
||||||
|
down_proj_in_scale = []
|
||||||
|
|
||||||
|
if layer.ep_size > 1:
|
||||||
|
for expert_idx in ep_rank_to_expert_id_list:
|
||||||
|
scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
|
||||||
|
up_gate_proj_in_scale_all_experts.append(scale_tensor)
|
||||||
|
|
||||||
|
for expert_idx in logical_expert_ids:
|
||||||
|
up_gate_proj_weight_scale.append(
|
||||||
|
get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
|
||||||
|
)
|
||||||
|
down_proj_weight_scale.append(
|
||||||
|
get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
|
||||||
|
)
|
||||||
|
up_gate_proj_in_scale.append(
|
||||||
|
get_tensor(state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx)))
|
||||||
|
)
|
||||||
|
down_proj_in_scale.append(get_tensor(state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))))
|
||||||
|
|
||||||
|
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
|
||||||
|
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
|
||||||
|
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
|
||||||
|
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
|
||||||
|
up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).squeeze()
|
||||||
|
up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).squeeze()
|
||||||
|
down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).squeeze()
|
||||||
|
|
||||||
|
name_tensor_map = {
|
||||||
|
"up_gate_proj_weight": up_gate_proj_weight,
|
||||||
|
"down_proj_weight": down_proj_weight,
|
||||||
|
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
|
||||||
|
"down_proj_weight_scale": down_proj_weight_scale,
|
||||||
|
"up_gate_proj_in_scale_all_experts": up_gate_proj_in_scale_all_experts,
|
||||||
|
"up_gate_proj_in_scale": up_gate_proj_in_scale,
|
||||||
|
"down_proj_in_scale": down_proj_in_scale,
|
||||||
|
}
|
||||||
|
for name, tensor in name_tensor_map.items():
|
||||||
|
getattr(layer, name).set_value(tensor)
|
||||||
|
|
||||||
|
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
|
||||||
|
"""
|
||||||
|
Paddle cutlass create weight process.
|
||||||
|
"""
|
||||||
|
self.weight_dtype = "int8"
|
||||||
|
self.ffn1_weight_shape = [
|
||||||
|
layer.num_local_experts,
|
||||||
|
layer.hidden_size // 2,
|
||||||
|
layer.moe_intermediate_size * 2,
|
||||||
|
]
|
||||||
|
self.ffn2_weight_shape = [
|
||||||
|
layer.num_local_experts,
|
||||||
|
layer.moe_intermediate_size // 2,
|
||||||
|
layer.hidden_size,
|
||||||
|
]
|
||||||
|
setattr(
|
||||||
|
layer,
|
||||||
|
self.added_weight_attrs[0],
|
||||||
|
layer.create_parameter(
|
||||||
|
shape=self.ffn1_weight_shape,
|
||||||
|
dtype=self.weight_dtype,
|
||||||
|
default_initializer=paddle.nn.initializer.Constant(0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
setattr(
|
||||||
|
layer,
|
||||||
|
self.added_weight_attrs[1],
|
||||||
|
layer.create_parameter(
|
||||||
|
shape=self.ffn2_weight_shape,
|
||||||
|
dtype=self.weight_dtype,
|
||||||
|
default_initializer=paddle.nn.initializer.Constant(0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.create_w4afp8_scale_weights(layer, layer.weight_key_map)
|
||||||
|
|
||||||
|
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||||
|
"""
|
||||||
|
Paddle cutlass load weight process.
|
||||||
|
"""
|
||||||
|
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||||
|
self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||||
|
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
|
||||||
|
weight_name = self.added_weight_attrs[idx]
|
||||||
|
weight_list = []
|
||||||
|
for i in range(layer.num_local_experts):
|
||||||
|
quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80)
|
||||||
|
weight_list.append(quant_weight)
|
||||||
|
quanted_weight = paddle.stack(weight_list, axis=0)
|
||||||
|
getattr(layer, weight_name).set_value(quanted_weight)
|
||||||
|
|
||||||
|
self.load_w4afp8_scale_weights(layer, layer.weight_key_map, state_dict)
|
||||||
|
|
||||||
|
def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
|
||||||
|
"""
|
||||||
|
Get w4afp8 weights from state dict and process them.
|
||||||
|
Args:
|
||||||
|
layer (nn.Layer): The layer to add parameters to.
|
||||||
|
weight_key_map (dict): The weight key map.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.default_dtype = layer._helper.get_default_dtype()
|
||||||
|
if layer.ep_size > 1:
|
||||||
|
setattr(
|
||||||
|
layer,
|
||||||
|
"up_gate_proj_in_scale_all_experts",
|
||||||
|
layer.create_parameter(
|
||||||
|
shape=[layer.num_experts],
|
||||||
|
dtype="float32",
|
||||||
|
default_initializer=paddle.nn.initializer.Constant(0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# in_scales
|
||||||
|
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
|
||||||
|
setattr(
|
||||||
|
layer,
|
||||||
|
in_scale_name,
|
||||||
|
layer.create_parameter(
|
||||||
|
shape=[layer.num_local_experts],
|
||||||
|
dtype="float32",
|
||||||
|
default_initializer=paddle.nn.initializer.Constant(0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# weight_scales
|
||||||
|
setattr(
|
||||||
|
layer,
|
||||||
|
"up_gate_proj_weight_scale",
|
||||||
|
layer.create_parameter(
|
||||||
|
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
|
||||||
|
dtype="float32",
|
||||||
|
default_initializer=paddle.nn.initializer.Constant(0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
setattr(
|
||||||
|
layer,
|
||||||
|
"down_proj_weight_scale",
|
||||||
|
layer.create_parameter(
|
||||||
|
shape=[layer.num_local_experts, layer.hidden_size],
|
||||||
|
dtype="float32",
|
||||||
|
default_initializer=paddle.nn.initializer.Constant(0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict):
|
||||||
|
"""
|
||||||
|
Get w4afp8 weights from state dict and process them.
|
||||||
|
Args:
|
||||||
|
layer (nn.Layer): The layer to add parameters to.
|
||||||
|
weight_key_map (dict): The weight key map.
|
||||||
|
state_dict (dict): The state dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _extract_scale_tensor(state_dict, key_template, expert_idx):
|
||||||
|
return get_tensor(state_dict.pop(key_template.format(expert_idx)))
|
||||||
|
|
||||||
|
def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
|
||||||
|
processed_in_scale = 1 / paddle.concat(in_scales)
|
||||||
|
getattr(layer, name).set_value(processed_in_scale)
|
||||||
|
return processed_in_scale
|
||||||
|
|
||||||
|
def _permute_weight_scale(weight_scale: paddle.Tensor):
|
||||||
|
weight_scale = w4afp8_gemm_scale_permute(weight_scale)
|
||||||
|
return weight_scale
|
||||||
|
|
||||||
|
def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], processed_in_scale: paddle.Tensor):
|
||||||
|
processed_weight_scale = (
|
||||||
|
paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None]
|
||||||
|
)
|
||||||
|
processed_weight_scale = _permute_weight_scale(processed_weight_scale)
|
||||||
|
getattr(layer, name).set_value(processed_weight_scale)
|
||||||
|
|
||||||
|
# 1. Init scale containers and maps
|
||||||
|
up_gate_proj_weight_scales = []
|
||||||
|
down_proj_weight_scales = []
|
||||||
|
up_gate_proj_in_scales_all_experts = []
|
||||||
|
up_gate_proj_in_scales = []
|
||||||
|
down_proj_in_scales = []
|
||||||
|
|
||||||
|
scale_weight_map = {
|
||||||
|
"up_gate_proj_weight_scale": up_gate_proj_weight_scales,
|
||||||
|
"down_proj_weight_scale": down_proj_weight_scales,
|
||||||
|
"up_gate_proj_in_scale": up_gate_proj_in_scales,
|
||||||
|
"down_proj_in_scale": down_proj_in_scales,
|
||||||
|
}
|
||||||
|
scale_key_map = {
|
||||||
|
"up_gate_proj_weight_scale": weight_key_map.get("up_gate_proj_expert_weight_scale_key", None),
|
||||||
|
"down_proj_weight_scale": weight_key_map.get("down_proj_expert_weight_scale_key", None),
|
||||||
|
"up_gate_proj_in_scale": weight_key_map.get("up_gate_proj_expert_in_scale_key", None),
|
||||||
|
"down_proj_in_scale": weight_key_map.get("down_proj_expert_in_scale_key", None),
|
||||||
|
}
|
||||||
|
for name, value in scale_key_map.items():
|
||||||
|
if value is None:
|
||||||
|
raise ValueError(f"scale {name} should not be none in w4a8 mode.")
|
||||||
|
|
||||||
|
# 2. Extract scale tensor from state dict
|
||||||
|
if layer.ep_size > 1:
|
||||||
|
for expert_idx in range(layer.num_experts):
|
||||||
|
scale_tensor = get_tensor(state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)])
|
||||||
|
up_gate_proj_in_scales_all_experts.append(1 / scale_tensor)
|
||||||
|
getattr(layer, "up_gate_proj_in_scale_all_experts").set_value(
|
||||||
|
paddle.concat(up_gate_proj_in_scales_all_experts)
|
||||||
|
)
|
||||||
|
|
||||||
|
for local_expert_idx in range(layer.num_local_experts):
|
||||||
|
expert_idx = local_expert_idx + layer.expert_id_offset
|
||||||
|
for name, scale_key_template in scale_key_map.items():
|
||||||
|
scale_tensor = _extract_scale_tensor(state_dict, scale_key_template, expert_idx)
|
||||||
|
scale_weight_map[name].append(scale_tensor)
|
||||||
|
|
||||||
|
# 3. Process scale tensor and set to layer
|
||||||
|
in_scales = []
|
||||||
|
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
|
||||||
|
in_scales.append(_process_in_scale(in_scale_name, scale_weight_map[in_scale_name]))
|
||||||
|
|
||||||
|
for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
|
||||||
|
_process_weight_scale(
|
||||||
|
weight_scale_name,
|
||||||
|
scale_weight_map[weight_scale_name],
|
||||||
|
in_scales[i],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
|
class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
|
||||||
"""
|
"""
|
||||||
weight only for moe
|
weight only for moe
|
||||||
|
@@ -20,6 +20,7 @@ import paddle
|
|||||||
|
|
||||||
import fastdeploy
|
import fastdeploy
|
||||||
|
|
||||||
|
from ..moe import FusedMoE
|
||||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||||
|
|
||||||
QUANT_SCALING_FACTOR = 448
|
QUANT_SCALING_FACTOR = 448
|
||||||
@@ -30,24 +31,32 @@ class W4AFP8Config(QuantConfigBase):
|
|||||||
quantization config for weight 4bits and activation fp8
|
quantization config for weight 4bits and activation fp8
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, weight_scale_dict, act_scale_dict) -> None:
|
def __init__(self, weight_scale_dict, act_scale_dict, is_permuted) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight_scale_dict = weight_scale_dict
|
self.weight_scale_dict = weight_scale_dict
|
||||||
self.act_scale_dict = act_scale_dict
|
self.act_scale_dict = act_scale_dict
|
||||||
self.quant_max_bound = 448
|
self.quant_max_bound = 448
|
||||||
self.quant_min_bound = -448
|
self.quant_min_bound = -448
|
||||||
self.quant_round_type = 1
|
self.quant_round_type = 1
|
||||||
|
self.is_permuted = is_permuted
|
||||||
|
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "w4afp8"
|
return "w4afp8"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: dict) -> "W4AFP8Config":
|
def from_config(cls, config: dict) -> "W4AFP8Config":
|
||||||
weight_scale_dict = config["weight_scale_dict"]
|
weight_scale_dict = config.get("weight_scale_dict", None)
|
||||||
act_scale_dict = config["act_scale_dict"]
|
act_scale_dict = config.get("act_scale_dict", None)
|
||||||
return cls(weight_scale_dict, act_scale_dict)
|
is_permuted = config.get("is_permuted", True)
|
||||||
|
return cls(weight_scale_dict, act_scale_dict, is_permuted)
|
||||||
|
|
||||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||||
|
if isinstance(layer, FusedMoE):
|
||||||
|
from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
|
||||||
|
CutlassW4AFP8MoEMethod,
|
||||||
|
)
|
||||||
|
|
||||||
|
return CutlassW4AFP8MoEMethod(self)
|
||||||
return W4AFP8LinearMethod(self)
|
return W4AFP8LinearMethod(self)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -103,7 +103,7 @@ class Ernie4_5_MoE(nn.Layer):
|
|||||||
if hasattr(fd_config.quant_config, "moe_quant_type"):
|
if hasattr(fd_config.quant_config, "moe_quant_type"):
|
||||||
moe_quant_type = fd_config.quant_config.moe_quant_type
|
moe_quant_type = fd_config.quant_config.moe_quant_type
|
||||||
|
|
||||||
if moe_quant_type == "w4a8":
|
if moe_quant_type == "w4a8" or moe_quant_type == "w4afp8":
|
||||||
weight_key_map = {
|
weight_key_map = {
|
||||||
"gate_weight_key": f"{prefix}.gate.weight",
|
"gate_weight_key": f"{prefix}.gate.weight",
|
||||||
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
|
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
|
||||||
|
@@ -31,7 +31,7 @@ def w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BA
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def peruate_scale(weight_scale):
|
def permute_scale(weight_scale):
|
||||||
weight_scale = weight_scale.reshape([BATCH, N])
|
weight_scale = weight_scale.reshape([BATCH, N])
|
||||||
temp = paddle.zeros([16])
|
temp = paddle.zeros([16])
|
||||||
for b in range(BATCH):
|
for b in range(BATCH):
|
||||||
@@ -52,10 +52,10 @@ TokenPadding = 0
|
|||||||
|
|
||||||
tokens = [tokens_per_group] * BATCH
|
tokens = [tokens_per_group] * BATCH
|
||||||
tokens_perfix_sum = np.cumsum(tokens)
|
tokens_perfix_sum = np.cumsum(tokens)
|
||||||
tokens_perfix_sum = np.insert(tokens_perfix_sum, 0, 0)
|
|
||||||
|
|
||||||
tokens = paddle.to_tensor(tokens, dtype="int32")
|
|
||||||
tokens_perfix_sum = paddle.to_tensor(tokens_perfix_sum, dtype="int32")
|
tokens = paddle.to_tensor(tokens, dtype="int64")
|
||||||
|
tokens_perfix_sum = paddle.to_tensor(tokens_perfix_sum, dtype="int64")
|
||||||
|
|
||||||
all_tokens = int(tokens.sum())
|
all_tokens = int(tokens.sum())
|
||||||
|
|
||||||
@@ -72,7 +72,7 @@ input_row_sum = input_bf16.sum(axis=1) * -7 / 512
|
|||||||
max_tokens = int(tokens.max())
|
max_tokens = int(tokens.max())
|
||||||
|
|
||||||
out_naive = w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N)
|
out_naive = w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N)
|
||||||
weight_dequant_scale = paddle.to_tensor(peruate_scale(weight_dequant_scale) * 512)
|
weight_dequant_scale = paddle.to_tensor(permute_scale(weight_dequant_scale) * 512)
|
||||||
|
|
||||||
weight_int4 = w4afp8_gemm_weight_convert(weight_quant.astype("uint8").cpu())
|
weight_int4 = w4afp8_gemm_weight_convert(weight_quant.astype("uint8").cpu())
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user