mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 08:16:42 +08:00)

Compare commits: remove_use...feature/on (26 commits)
Commits:
9f1882d9a8
5223065d59
c753f1fc9e
62659a7a73
4f17f9aa6e
7642611b12
2513cd929b
4dbaa3d74c
44043e0c88
7b8db880b7
c7993d35cb
c7cb31051b
5e7ab3dfe3
abed681444
548f53e433
ee742f55f1
794ab9705f
0e0891ad12
9e87f3341b
869626b0f4
1b1287e145
9307f2619b
fbe03866d1
89ad20bea2
02398135a8
d65a0a6a2c
.gitignore (vendored): 4 additions
@@ -164,3 +164,7 @@ build
 .ccls-cache
 
 third_party
+
+custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu
+custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h
+
@@ -286,6 +286,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         cache_v_zp,
         cache_quant_type_str,
         use_neox_rotary_style,
+        rope_3d,
         max_input_length,
         exec_stream,
         &qkv_out,

@@ -309,6 +310,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         cache_v_zp,
         cache_quant_type_str,
         use_neox_rotary_style,
+        rope_3d,
         max_input_length,
         exec_stream,
         &qkv_out,
@@ -199,8 +199,9 @@ __global__ void append_decode_cache_T_rope_kernel(
   if (hi < num_heads + kv_num_heads) {
     // q k rope
     const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
-    Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-    Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+    uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
+    Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+    Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
   }
 #pragma unroll
   for (int i = 0; i < HalfVecSize; i++) {
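Every rope_3d hunk in this file group makes the same change: the cos/sin rotary tables stop being implicitly shared across the batch and, when rope_3d is set, gain a leading batch dimension. A minimal host-side sketch of the index arithmetic (hypothetical helper, not part of the diff; it assumes table layouts of [max_seq_len, head_size] for rope_3d == false and [bsz, max_seq_len, head_size] for rope_3d == true, with one cos/sin entry covering two adjacent head elements, hence h_bias / 2):

#include <cstdint>

inline uint32_t rope_emb_index(bool rope_3d,
                               uint32_t write_seq_id,    // token position
                               uint32_t half_head_size,  // head_size / 2
                               uint32_t h_bias,          // element offset in head
                               uint32_t batch_id,        // ori_bi / bid in the diff
                               uint32_t max_seq_len,
                               uint32_t head_size) {
  // Shared 2-D table: index depends only on position and rotary pair.
  const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
  // Per-sequence 3-D table: skip batch_id tables of max_seq_len * head_size floats.
  return rope_3d ? emb_idx + batch_id * max_seq_len * head_size : emb_idx;
}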
@@ -512,7 +513,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
     const int block_size,
     const float max_bound,
     const float min_bound,
-    const int kv_num_heads) {
+    const int kv_num_heads,
+    const bool rope_3d) {
   static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
   static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
   constexpr int NUM_WARPS = 4;

@@ -555,8 +557,9 @@ __global__ void append_decode_cache_int8_rope_kernel(
 
     // q rope
     const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-    Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-    Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+    Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+    Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
 #pragma unroll
     for (int i = 0; i < HalfVecSize; i++) {
       // dequant + add_bias + rope

@@ -633,10 +636,11 @@ __global__ void append_decode_cache_int8_rope_kernel(
       const T *cache_v_scale_cur = cache_v_scale + v_head_idx * HeadDim + head_bias;
     if (head_idx < num_heads + kv_num_heads) {
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-      Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
-      Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
-      Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
-      Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
+      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
+      Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
+      Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
+      Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
       if constexpr (!is_scale_channel_wise) {
         scale = __ldg(&cache_k_scale[kv_head_idx]);
       }

@@ -763,7 +767,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
     const int block_size,
     const float max_bound,
     const float min_bound,
-    const int kv_num_heads) {
+    const int kv_num_heads,
+    const bool rope_3d) {
   static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
   static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
   constexpr int NUM_WARPS = 4;

@@ -813,9 +818,9 @@ __global__ void append_decode_cache_int8_rope_kernel(
 
       // q rope
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-      Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-
-      Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
 
 #pragma unroll
       for (int i = 0; i < HalfVecSize; i++) {

@@ -908,6 +913,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
       const T *cache_v_scale_cur = cache_v_scales + v_head_idx * HeadDim + head_bias;
       if (head_idx < num_heads + kv_num_heads) {
         const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
+        uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
         Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
         Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
         Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
@@ -248,7 +248,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
             block_size,
             127.0f,
             -127.0f,
-            kv_num_heads);
+            kv_num_heads,
+            rope_3d);
   } else {
     append_decode_cache_int8_rope_kernel<T, 4, 0, 128, is_scale_channel_wise, IsFP8>
         <<<grids, num_warps * 32, 0, stream>>>(

@@ -271,7 +272,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
             block_size,
             127.0f,
             -127.0f,
-            kv_num_heads);
+            kv_num_heads,
+            rope_3d);
   }
 }
@@ -37,7 +37,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     const int q_num_head,
     const int kv_num_head,
     const int seq_len,
-    const int last_dim) {
+    const int last_dim,
+    const bool rope_3d) {
   using LoadT = AlignedVector<T, VecSize>;
   constexpr int HalfVecSize = VecSize / 2;
   using LoadEmbT = AlignedVector<float, HalfVecSize>;

@@ -62,6 +63,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
 
     const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
+    int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
     const int64_t base_idx =
         token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
         h_bias;

@@ -80,8 +82,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     Load<T, VecSize>(&qkv[base_idx], &src_vec);
     // do rope
     if (hi < q_num_head + kv_num_head) {
-      Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-      Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
 #pragma unroll
       for (int i = 0; i < HalfVecSize; i++) {
         const float input_left = static_cast<float>(src_vec[2 * i]);
@@ -118,6 +120,7 @@ void gqa_rotary_qk_split_variable(
     const int seq_len,
     const int input_output_len,
     const int dim_head,
+    const bool rope_3d,
     const cudaStream_t &stream) {
   int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
   constexpr int PackSize = 16 / sizeof(T);

@@ -146,7 +149,8 @@ void gqa_rotary_qk_split_variable(
       num_heads,
       kv_num_heads,
       seq_len,
-      dim_head);
+      dim_head,
+      rope_3d);
 }

 template <typename T,
@@ -890,7 +894,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::optional<paddle::Tensor>& kv_signal_data,
     const int kv_token_num,
     const int max_seq_len,
-    const std::string& cache_quant_type) {
+    const std::string& cache_quant_type,
+    const bool rope_3d) {
   typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;

@@ -953,8 +958,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
       num_heads,
       kv_num_heads,
       max_seq_len,
-      rotary_embs.dims()[2],
+      rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
       head_dim,
+      rope_3d,
       stream);

   if (token_num < kv_token_num) {
|
@@ -193,7 +193,8 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -253,8 +254,9 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
@@ -476,7 +478,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
     const int block_size,
     const float max_bound,
     const float min_bound,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
   static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
   constexpr int NUM_WARPS = 4;

@@ -522,8 +525,9 @@ __global__ void append_speculate_cache_int8_rope_kernel(
 
     // q rope
     const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-    Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-    Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+    Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+    Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
     if (qkv_out_scales) {
       Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
     }

@@ -583,10 +587,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
     T scale;
     if (head_idx < num_heads + gqa_group_size) {
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-      Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
-      Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
-      Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
-      Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
+      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
+      Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
+      Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
+      Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
       scale = __ldg(&cache_k_scales[kv_head_idx]);
     } else {
       scale = __ldg(&cache_v_scales[kv_head_idx]);
@@ -39,7 +39,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
     const int bsz,
     const int token_num,
     const cudaStream_t& stream,
-    const bool use_neox_style) {
+    const bool use_neox_style,
+    const bool rope_3d) {
   int output_inner_dim = num_heads + 2 * kv_num_heads;
 
   const uint32_t elem_nums =

@@ -96,7 +97,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
             dim_head,
             block_size,
             elem_nums,
-            kv_num_heads);
+            kv_num_heads,
+            rope_3d);
   }
 }

@@ -125,7 +127,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
     const int bsz,
     const int token_num,
     const cudaStream_t& stream,
-    const bool use_neox_style) {
+    const bool use_neox_style,
+    const bool rope_3d) {
   constexpr int num_warps = 4;
   const int all_warps =
       ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;

@@ -191,7 +194,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
             block_size,
             127.0f,
             -127.0f,
-            kv_num_heads);
+            kv_num_heads,
+            rope_3d);
   }
 }
@@ -313,6 +317,7 @@ void SpeculateWriteCacheWithRoPEKernel(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,

@@ -368,7 +373,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else if (cache_quant_type_str == "cache_int8") {
     append_speculate_cache_int8_rope(
         reinterpret_cast<const QKV_TYPE*>(qkv_ptr),

@@ -401,7 +407,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else if (cache_quant_type_str == "cache_fp8") {
     append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
         reinterpret_cast<const QKV_TYPE*>(qkv_ptr),

@@ -434,7 +441,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else if (cache_quant_type_str == "cache_int4_zp") {
     append_speculate_cache_int4_rope(
         reinterpret_cast<const QKV_TYPE*>(qkv_ptr),

@@ -500,6 +508,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,

@@ -526,6 +535,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,

@@ -551,6 +561,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,

@@ -578,6 +589,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -35,6 +35,7 @@ void SpeculateWriteCacheWithRoPEKernel(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -107,7 +107,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::optional<paddle::Tensor> &cache_v_zp,
     const paddle::optional<paddle::Tensor> &kv_signal_data,
     const int kv_token_num, const int max_seq_len,
-    const std::string &cache_quant_type);
+    const std::string &cache_quant_type,
+    const bool rope_3d);

 std::vector<paddle::Tensor>
 PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
@@ -188,7 +189,8 @@ paddle::Tensor MoeExpertFFNFunc(
     const paddle::optional<paddle::Tensor>& down_proj_scale,
     const paddle::optional<paddle::Tensor>& down_proj_in_scale,
     const paddle::optional<paddle::Tensor>& expert_idx_per_token,
-    const std::string& quant_method, const bool used_in_ep_low_latency);
+    const std::string& quant_method, const bool used_in_ep_low_latency,
+    const int estimate_total_token_nums);

 paddle::Tensor MoeExpertFFNWint2Func(
     const paddle::Tensor& permute_input,
@@ -334,6 +336,19 @@ void TextImageIndexOut(const paddle::Tensor &token_type_ids,
                        const paddle::Tensor &text_input,
                        const paddle::Tensor &image_input);

+void LimitContentLen(const paddle::Tensor& next_tokens,
+                     const paddle::Tensor& end_thinking_tokens,
+                     const paddle::Tensor& max_content_len,
+                     const paddle::Tensor& max_think_len,
+                     const paddle::Tensor& step_idx,
+                     const paddle::Tensor& eos_token_ids,
+                     const paddle::Tensor& max_dec_len,
+                     const paddle::Tensor& limit_content_status,
+                     const paddle::Tensor& enable_thinking,
+                     const paddle::Tensor& accept_num,
+                     const paddle::Tensor& seq_lens_decoder,
+                     const paddle::Tensor& stop_flags);
+
 void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
                             paddle::Tensor &image_input,
                             paddle::Tensor &token_type_ids,
@@ -603,7 +618,7 @@ void SpeculateVerify(
     const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
     int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);

-void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
+void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
                        const paddle::Tensor &seq_lens_decoder,
                        const paddle::Tensor &not_need_stop,
                        const paddle::Tensor &draft_tokens,
@@ -648,6 +663,20 @@ void NgramMatch(const paddle::Tensor &input_ids,
                 const int max_draft_tokens);


+void HybridMtpNgram(const paddle::Tensor &input_ids,
+                    const paddle::Tensor &input_ids_len,
+                    const paddle::Tensor &pre_ids,
+                    const paddle::Tensor &step_idx,
+                    const paddle::Tensor &draft_token_num,
+                    const paddle::Tensor &draft_tokens,
+                    const paddle::Tensor &seq_lens_this_time,
+                    const paddle::Tensor &seq_lens_decoder,
+                    const paddle::Tensor &max_dec_len,
+                    const int max_ngram_size,
+                    const int min_ngram_size,
+                    const int max_draft_tokens);
+
+
 // MTP
 void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
                            const paddle::Tensor& base_model_seq_lens_this_time,
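The declaration above gives only the interface. The general idea behind n-gram drafting is to match the most recent tokens of the running sequence against earlier context (input_ids plus pre_ids here) and propose the tokens that followed the match as draft tokens for verification. A generic, heavily simplified sketch of that matching — not FastDeploy's implementation, just an illustration of the technique:

#include <algorithm>
#include <cstdint>
#include <vector>

// Find the latest earlier occurrence of the last `ngram_size` tokens and
// copy up to `max_draft_tokens` tokens that followed it as the draft.
std::vector<int64_t> propose_ngram_draft(const std::vector<int64_t>& context,
                                         int ngram_size, int max_draft_tokens) {
  const int64_t n = static_cast<int64_t>(context.size());
  if (n < ngram_size + 1) return {};
  for (int64_t start = n - ngram_size - 1; start >= 0; --start) {
    bool match = true;
    for (int k = 0; k < ngram_size; ++k) {
      if (context[start + k] != context[n - ngram_size + k]) { match = false; break; }
    }
    if (match) {
      const int64_t from = start + ngram_size;
      const int64_t count = std::min<int64_t>(max_draft_tokens, n - from);
      return {context.begin() + from, context.begin() + from + count};
    }
  }
  return {};  // no match: fall back to the model's own draft (e.g. MTP)
}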
@@ -664,8 +693,10 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                           const paddle::Tensor& step_idx,
                           const paddle::Tensor& not_need_stop,
                           const paddle::Tensor& batch_drop,
+                          const paddle::Tensor& pre_ids,
                           const paddle::Tensor& accept_tokens,
                           const paddle::Tensor& accept_num,
                           const paddle::Tensor& base_model_seq_lens_this_time,
                           const paddle::Tensor& base_model_seq_lens_encoder,
                           const paddle::Tensor& base_model_seq_lens_decoder,
                           const paddle::Tensor& base_model_step_idx,
@@ -1082,7 +1113,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

   m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");

-  m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
+  m.def("speculate_update",&SpeculateUpdate, "Speculate Update Kernel");

   m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");

@@ -1092,6 +1123,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

   m.def("ngram_match", &NgramMatch, "ngram_match function");

+  m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
+
   m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");

   m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");
@@ -193,6 +193,12 @@ public:
   typedef uint8_t data_t;
 };

+template <> class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
+public:
+  typedef __nv_fp8_e4m3 DataType;
+  typedef paddle::float8_e4m3fn data_t;
+};
+
 template <typename T, int Size> struct alignas(sizeof(T) * Size) AlignedVector {
   T val[Size];
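The new specialization lets the rest of the codebase map paddle::DataType::FLOAT8_E4M3FN to the CUDA-native __nv_fp8_e4m3 the same way the existing traits handle other dtypes. A minimal usage sketch (hypothetical function, mirroring how the MoE kernels in this compare consume the traits):

#include "helper.h"  // assumed location of PDTraits in this codebase

template <paddle::DataType DT>
void example_dispatch(const paddle::Tensor& t) {
  typedef PDTraits<DT> traits_;
  typedef typename traits_::DataType DeviceT;  // e.g. __nv_fp8_e4m3
  typedef typename traits_::data_t data_t;     // e.g. paddle::float8_e4m3fn
  // Reinterpret the framework-side pointer as the device-side type before
  // handing it to a kernel that expects the CUDA-native representation.
  const DeviceT* ptr = reinterpret_cast<const DeviceT*>(t.data<data_t>());
  (void)ptr;
}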
custom_ops/gpu_ops/limit_content_len.cu (new file, 186 lines)

@@ -0,0 +1,186 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

__global__ void limit_content_len(
    int64_t* next_tokens,
    const int64_t* end_thinking_tokens,
    int* max_content_lens,
    const int* max_think_lens,
    int64_t* step_idx,
    const int64_t* eos_token_ids,
    int64_t* max_dec_lens,
    int* limit_content_status,
    const bool* enable_thinking,
    int* accept_num,
    int* seq_lens_decoder,
    bool* stop_flags,
    const int tokens_per_step,
    const int bs,
    const int end_thinking_token_num,
    const int eos_token_id_len) {

  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= bs) return;

  if (!enable_thinking[idx]) return;

  const int original_accept_num = accept_num[idx];
  if (original_accept_num <= 0) return;

  int current_limit_content_status = limit_content_status[idx];
  // If already in the response phase and the stop flag is set, return
  // immediately; there is nothing left to do.
  if (current_limit_content_status == 2 && stop_flags[idx]) {
    return;
  }

  const int max_think_len_reg = max_think_lens[idx];

  const int64_t end_thinking_token_reg = end_thinking_tokens[0];

  int64_t current_max_dec_len = max_dec_lens[idx];
  int new_accept_num = original_accept_num;

  const int64_t current_base_step = step_idx[idx] - original_accept_num + 1;

  for (int token_offset = 0; token_offset < original_accept_num; token_offset++) {
    const int token_idx = idx * tokens_per_step + token_offset;
    int64_t next_token_reg = next_tokens[token_idx];
    const int64_t current_step = current_base_step + token_offset;

    bool condition_triggered = false;
    bool is_eos = false;

    // ======================= Thinking-phase control =======================
    // Phase 1: still thinking (status == 0); check whether thinking must be
    // force-terminated.
    if (current_limit_content_status < 1) {
      bool should_transform = false;

      // When a thinking-length budget is set, check whether it is exhausted.
      if (max_think_len_reg > 0 && current_step >= max_think_len_reg) {
        should_transform = true;
      } else {
        // Check whether an EOS token was generated.
        for (int j = 0; j < eos_token_id_len; j++) {
          if (eos_token_ids[j] == next_token_reg) {
            is_eos = true;
            should_transform = true;
            break;
          }
        }
      }

      if (should_transform) {
        // Force-replace the current token with the end-of-thinking token.
        next_token_reg = end_thinking_token_reg;
        // Advance the status to 1: "ending thinking now".
        current_limit_content_status = 1;
        condition_triggered = true;  // the token was modified, so truncate here
        // Clear stop_flags only when the trigger was an EOS token.
        if (is_eos && stop_flags[idx]) {
          stop_flags[idx] = false;
        }
      }
    }

    // ======================= End-of-thinking handling =======================
    // Phase 2: check whether the end-of-thinking condition is met (status < 2).
    // This covers two scenarios:
    //   1. status == 0: the model generated end_thinking_token on its own.
    //   2. status == 1: the previous phase force-injected end_thinking_token.
    if (current_limit_content_status < 2) {
      if (next_token_reg == end_thinking_token_reg) {
        // Thinking is confirmed finished; advance to status 2 (response phase).
        current_limit_content_status = 2;
      }
    }

    next_tokens[token_idx] = next_token_reg;

    if (condition_triggered) {
      new_accept_num = token_offset + 1;
      break;
    }
  }

  // Update global state.
  int discarded_tokens = original_accept_num - new_accept_num;
  if (discarded_tokens > 0) {
    step_idx[idx] -= discarded_tokens;
    seq_lens_decoder[idx] -= discarded_tokens;
  }

  accept_num[idx] = new_accept_num;
  limit_content_status[idx] = current_limit_content_status;
  max_dec_lens[idx] = current_max_dec_len;
}

void LimitContentLen(const paddle::Tensor& next_tokens,
                     const paddle::Tensor& end_thinking_tokens,
                     const paddle::Tensor& max_content_len,
                     const paddle::Tensor& max_think_len,
                     const paddle::Tensor& step_idx,
                     const paddle::Tensor& eos_token_ids,
                     const paddle::Tensor& max_dec_len,
                     const paddle::Tensor& limit_content_status,
                     const paddle::Tensor& enable_thinking,
                     const paddle::Tensor& accept_num,
                     const paddle::Tensor& seq_lens_decoder,
                     const paddle::Tensor& stop_flags) {

  const int batch_size = next_tokens.shape()[0];
  const int tokens_per_step = next_tokens.shape()[1];
  const int end_thinking_token_num = end_thinking_tokens.shape()[0];
  const int end_length = eos_token_ids.shape()[0];
  PD_CHECK(end_thinking_token_num == 1, "limit_content_len only support end_thinking_token_num = 1 for now.");

  dim3 grid(1);
  dim3 block(1024);

  limit_content_len<<<grid, block>>>(
      const_cast<int64_t *>(next_tokens.data<int64_t>()),
      end_thinking_tokens.data<int64_t>(),
      const_cast<int *>(max_content_len.data<int>()),
      max_think_len.data<int>(),
      const_cast<int64_t *>(step_idx.data<int64_t>()),
      eos_token_ids.data<int64_t>(),
      const_cast<int64_t *>(max_dec_len.data<int64_t>()),
      const_cast<int *>(limit_content_status.data<int>()),
      enable_thinking.data<bool>(),
      const_cast<int *>(accept_num.data<int>()),
      const_cast<int *>(seq_lens_decoder.data<int>()),
      const_cast<bool *>(stop_flags.data<bool>()),
      tokens_per_step,
      batch_size,
      end_thinking_token_num,
      end_length);
}

PD_BUILD_STATIC_OP(limit_content_len)
    .Inputs({"next_tokens",
             "end_thinking_tokens",
             "max_content_len",
             "max_think_len",
             "step_idx",
             "eos_token_ids",
             "max_dec_len",
             "limit_content_status",
             "enable_thinking",
             "accept_num",
             "seq_lens_decoder",
             "stop_flags"})
    .Outputs({"next_tokens_out", "max_dec_len_out"})
    .SetInplaceMap({{"next_tokens", "next_tokens_out"},
                    {"max_dec_len", "max_dec_len_out"}})
    .SetKernelFn(PD_KERNEL(LimitContentLen));
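The kernel above runs a per-request three-state machine over the speculative accept window. A standalone scalar sketch of the same control flow (hypothetical names, single token stream, no CUDA) may make the transitions easier to follow:

#include <cstdint>
#include <vector>

// States of limit_content_status in the kernel above:
//   0 = still thinking, 1 = end-of-thinking token being injected,
//   2 = response phase.
enum class ThinkState { Thinking = 0, EndingThinking = 1, Responding = 2 };

// Per-token decision, restated sequentially. `token` may be rewritten to the
// forced end-of-thinking token, exactly as the kernel does in-place.
ThinkState step(ThinkState s, int64_t& token, int64_t step_idx,
                int max_think_len, int64_t end_thinking_token,
                const std::vector<int64_t>& eos_ids) {
  if (s == ThinkState::Thinking) {
    const bool budget_spent = max_think_len > 0 && step_idx >= max_think_len;
    bool is_eos = false;
    for (int64_t e : eos_ids) is_eos |= (e == token);
    if (budget_spent || is_eos) {
      token = end_thinking_token;      // force-close the thinking segment
      s = ThinkState::EndingThinking;
    }
  }
  if (s != ThinkState::Responding && token == end_thinking_token) {
    s = ThinkState::Responding;        // end token observed or injected
  }
  return s;
}

In the CUDA version the loop additionally truncates the accept window (new_accept_num) at the first rewritten token and rolls step_idx and seq_lens_decoder back by the number of discarded draft tokens.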
@@ -269,7 +269,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
 }


-template <typename T, typename OutT, int NUM_EXPERTS_PER_RANK = 8, int RoundType = 1>
+template <typename T, typename OutT, int NUM_EXPERTS_PER_RANK = 8, int Kthread = 512, int RoundType = 1>
 __global__ void permute_x_kernel(const T *src_x,
                                  const int64_t *topk_idx,
                                  const float *topk_weights,
@@ -285,9 +285,9 @@ __global__ void permute_x_kernel(const T *src_x,
                                  int *dst_indices,
                                  int *cumsum_idx_gpu,
                                  int64_t *token_nums_per_expert_cumsum,
-                                 int64_t *expert_idx_per_token,
+                                 int64_t *expert_idx_per_token, // [num_rows, moe_topk]
                                  float max_bound = 127.0,
-                                 float min_bound = -127.0) { // [num_rows, moe_topk]
+                                 float min_bound = -127.0) {
   const int src_token_idx = blockIdx.x;
   const int tid = threadIdx.x;
   constexpr int vec_size = sizeof(int4) / sizeof(T);
@@ -330,10 +330,17 @@ __global__ void permute_x_kernel(const T *src_x,
         if (up_gate_proj_in_scale) {
           for (int i = 0; i < vec_size; i++) {
             float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast<float>(src_vec[i]);
-            if (RoundType == 0) {
-              res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
+            if constexpr (std::is_same<OutT, int8_t>::value) {
+              // w4aint8
+              if (RoundType == 0) {
+                res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
+              } else {
+                res_vec[i] = static_cast<OutT>(ClipFunc<float>(round(quant_value), min_bound, max_bound));
+              }
             } else {
-              res_vec[i] = static_cast<OutT>(round(quant_value));
+              // w4afp8
+              float value = ClipFunc<float>(quant_value, min_bound, max_bound);
+              res_vec[i] = static_cast<OutT>(value);
             }
           }
         } else {
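The branch above separates the two activation formats: an int8 target must be explicitly rounded to an integer and clipped to [-127, 127], while an fp8 target only needs saturation to its dynamic range, since the cast to __nv_fp8_e4m3 performs the rounding itself. A hedged standalone sketch of the two paths:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar restatement of the quantization branch above; clip() stands in for
// the kernel's ClipFunc<float>.
static float clip(float v, float lo, float hi) { return std::min(std::max(v, lo), hi); }

int8_t quant_int8(float v, int round_type) {
  // w4a8 path: explicit round (rint vs round selected by RoundType), then clip.
  const float r = (round_type == 0) ? std::rint(v) : std::round(v);
  return static_cast<int8_t>(clip(r, -127.0f, 127.0f));
}

float quant_fp8_surrogate(float v) {
  // w4afp8 path: saturate to the e4m3 range (+-448); the actual conversion
  // to __nv_fp8_e4m3 (elided here) rounds during the cast.
  return clip(v, -448.0f, 448.0f);
}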
@@ -373,6 +380,10 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;

+  typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
+  typedef typename traits_fp8::DataType DataType_fp8;
+  typedef typename traits_fp8::data_t data_t_fp8;
+
   auto stream = input.stream();
   auto place = input.place();
   const int gridx = min(132 * 8, num_rows);
@@ -420,6 +431,50 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
           -127.0
       );
     }
+  } else if (moe_quant_type == "w4afp8") {
+    if (num_experts_per_rank == 8) {
+      permute_x_kernel<data_t, data_t_fp8, 8, 512><<<gridx, 512, 0, stream>>>(
+          input.data<data_t>(),
+          topk_ids.data<int64_t>(),
+          topk_weights.data<float>(),
+          token_nums_per_expert.data<int>(),
+          up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
+          moe_topk,
+          num_rows,
+          token_nums_this_rank,
+          hidden_size,
+          permute_input->data<data_t_fp8>(),
+          permute_indices_per_token->data<int>(),
+          dst_weights->data<float>(),
+          dst_indices->data<int>(),
+          cumsum_idx_gpu->data<int>(),
+          token_nums_per_expert_cumsum->data<int64_t>(),
+          expert_idx_per_token->data<int64_t>(),
+          448.0f,
+          -448.0f
+      );
+    } else if (num_experts_per_rank == 16) {
+      permute_x_kernel<data_t, data_t_fp8, 16, 512><<<gridx, 512, 0, stream>>>(
+          input.data<data_t>(),
+          topk_ids.data<int64_t>(),
+          topk_weights.data<float>(),
+          token_nums_per_expert.data<int>(),
+          up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
+          moe_topk,
+          num_rows,
+          token_nums_this_rank,
+          hidden_size,
+          permute_input->data<data_t_fp8>(),
+          permute_indices_per_token->data<int>(),
+          dst_weights->data<float>(),
+          dst_indices->data<int>(),
+          cumsum_idx_gpu->data<int>(),
+          token_nums_per_expert_cumsum->data<int64_t>(),
+          expert_idx_per_token->data<int64_t>(),
+          448.0f,
+          -448.0f
+      );
+    }
   } else {
     if (num_experts_per_rank == 8) {
       permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
@@ -493,7 +548,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(

   auto permute_input = GetEmptyTensor(
       {token_nums_this_rank, hidden_size},
-      moe_quant_type == "w4a8" ? paddle::DataType::INT8 : input_type,
+      moe_quant_type == "w4a8" ? paddle::DataType::INT8 : moe_quant_type == "w4afp8" ? paddle::DataType::FLOAT8_E4M3FN : input_type,
       place);
   auto num_experts_per_rank_tensor = GetEmptyTensor(
       {num_experts_per_rank},
@@ -88,7 +88,7 @@ struct nv_type_traits<int8_t> {
     constexpr int kLogN = 7;                                                   \
     __VA_ARGS__                                                                \
   } else {                                                                     \
-    PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupport!", logN)); \
+    PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupported!", logN)); \
   }

 #define DISPATCH_SP_VS(vec_size, VEC_SIZE, ...)                                \

@@ -108,7 +108,7 @@ struct nv_type_traits<int8_t> {
     constexpr int VEC_SIZE = 1;                                                \
     __VA_ARGS__                                                                \
   } else {                                                                     \
-    PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupport!", vec_size)); \
+    PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupported!", vec_size)); \
   }

 #define DISPATCH_logN(logN, kLogN, ...)                                        \
@@ -605,26 +605,6 @@ void moe_fast_hardamard_kernel(const T *x,
       exchange_smem_pre<kNChunks, kChunksPerSmemSize, VecSize, kWarpSize, kNWarps, false, vec_t>(x_vals, smem_exchange);
     }
     if constexpr (kNChunks > 1) {
-      // T x_vals_transposed[VecSize][kNChunks] = {init_value};
-      // #pragma unroll
-      // for (int c = 0; c < kNChunks; ++c) {
-      //   #pragma unroll
-      //   for (int i = 0; i < VecSize; ++i) { x_vals_transposed[i][c] = x_vals[c][i]; }
-      // }
-      // if constexpr (kNChunks == 28) {
-      //   hadamard_mult_thread_chunk_28<VecSize>(x_vals_transposed);
-      // } else if constexpr (kNChunks == 36) {
-      //   hadamard_mult_thread_chunk_36<VecSize>(x_vals_transposed);
-      // } else {
-      //   constexpr int kLogNChunks = cilog2(kNChunks);
-      //   static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
-      //   hadamard_mult_thread<kLogNChunks, VecSize>(x_vals_transposed);
-      // }
-      // #pragma unroll
-      // for (int c = 0; c < kNChunks; ++c) {
-      //   #pragma unroll
-      //   for (int i = 0; i < VecSize; ++i) { x_vals[c][i] = x_vals_transposed[i][c]; }
-      // }
       if constexpr (kNChunks == 28) {
         hadamard_mult_thread_28_transpose<T, VecSize>(x_vals);
       } else if constexpr (kNChunks == 36) {
@@ -72,6 +72,285 @@ __host__ __device__ constexpr static U arrayConvert(T const& input)
     return u;
 }

+struct uint8 {
+    uint4 u;
+    uint4 v;
+};
+
+template<int BYTES> struct BytesToType {};
+
+template<>
+struct BytesToType<32> {
+    using Type = uint8;
+    static_assert(sizeof(Type) == 32);
+};
+
+template<> struct BytesToType<16> {
+    using Type = uint4;
+    static_assert(sizeof(Type) == 16);
+};
+
+template<> struct BytesToType<8> {
+    using Type = uint64_t;
+    static_assert(sizeof(Type) == 8);
+};
+
+template<> struct BytesToType<4> {
+    using Type = uint32_t;
+    static_assert(sizeof(Type) == 4);
+};
+
+template<> struct BytesToType<2> {
+    using Type = uint16_t;
+    static_assert(sizeof(Type) == 2);
+};
+
+template<> struct BytesToType<1> {
+    using Type = uint8_t;
+    static_assert(sizeof(Type) == 1);
+};
+
+template <template <typename> class ReductionOp, typename T, int block_size>
+__inline__ __device__ T BlockAllReduce(T val) {
+    typedef cub::BlockReduce<T, block_size> BlockReduce;
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+    __shared__ T result_broadcast;
+    T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
+    if (threadIdx.x == 0) {
+        result_broadcast = result;
+    }
+    __syncthreads();
+    return result_broadcast;
+}
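cub::BlockReduce only guarantees a valid result in thread 0, so BlockAllReduce stages the result through shared memory and synchronizes, letting every thread in the block read the total. A minimal usage sketch (hypothetical kernel; it uses the SumOp functor defined just below, and the launch block size must equal the kBlockSize template argument):

template <int kBlockSize>
__global__ void row_mean_kernel(const float* row, int n, float* mean_out) {
  // Each thread accumulates a strided partial sum.
  float partial = 0.0f;
  for (int i = threadIdx.x; i < n; i += kBlockSize) partial += row[i];
  // All threads receive the block-wide total, not just thread 0.
  float total = BlockAllReduce<SumOp, float, kBlockSize>(partial);
  if (threadIdx.x == 0) *mean_out = total / n;
}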
+template <typename T>
+struct SumOp {
+    __device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
+};
+
+template <typename InType, typename OutType>
+__forceinline__ __device__ OutType QuantHelperFunc(const InType input,
+                                                   const float scale,
+                                                   const float max_bound,
+                                                   const float min_bound) {
+    float quant_value = max_bound * scale * static_cast<float>(input);
+    return static_cast<OutType>(ClipFunc<float>(quant_value, min_bound, max_bound));
+}
+
+template <typename T, typename OutT, int VecSize, int Kthread>
+__global__ void masked_quantize_moe_input_kernel(const T* permuted_inputs,
+                                                 const int64_t* expert_idx_per_token,
+                                                 const float* quant_scales,
+                                                 const float quant_max_bound,
+                                                 const float quant_min_bound,
+                                                 const int64_t token_num,
+                                                 const int64_t dim,
+                                                 float* permuted_input_row_sum,
+                                                 const int64_t* recv_expert_count,
+                                                 const int num_max_tokens_per_expert,
+                                                 OutT* out) {
+    using LoadT = AlignedVector<T, VecSize>;
+    using LoadOutT = AlignedVector<OutT, VecSize>;
+    LoadT input_vec;
+    LoadOutT output_vec;
+    float scale_factor = -7.0f / 512.0f;
+    using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type;
+    for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
+        const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert;
+        const auto expert_id = token_idx / num_max_tokens_per_expert;
+        if (token_idx_in_expert >= recv_expert_count[expert_id]) {
+            auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
+            auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x;
+            token_idx += num_iters_to_next_expert * gridDim.x;
+            continue;
+        }
+        int64_t expert_idx = expert_idx_per_token[token_idx];
+        float quant_scale = quant_scales[expert_idx];
+        float thread_row_sum = 0.0f;
+        for (int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
+            int64_t offset = token_idx * dim + idx * VecSize;
+            Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
+#pragma unroll
+            for (int i = 0; i < VecSize; i++) {
+                output_vec[i] = QuantHelperFunc<T, OutT>(input_vec[i], quant_scale, quant_max_bound, quant_min_bound);
+                thread_row_sum += static_cast<float>(output_vec[i]);
+            }
+            *(reinterpret_cast<vec_t*>(&out[offset])) = *(reinterpret_cast<const vec_t*>(&output_vec));
+        }
+        float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
+        permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
+    }
+}
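In the masked (low-latency EP) variant above, tokens are laid out expert-by-expert in fixed slots of num_max_tokens_per_expert, and only the first recv_expert_count[e] slots of each expert hold real data. The skip-ahead arithmetic advances the grid-stride loop past the padded tail of an expert without re-testing every slot. A scalar restatement of that index computation (hypothetical helper, same math as the kernel):

#include <cstdint>

// Given a slot that falls past the valid tokens of its expert, return how
// many extra grid strides the loop may safely skip before the next expert
// region begins; 0 means the slot holds a real token.
int64_t strides_to_skip(int64_t token_idx, int64_t num_max_tokens_per_expert,
                        const int64_t* recv_expert_count, int64_t grid_stride) {
  const int64_t slot = token_idx % num_max_tokens_per_expert;
  const int64_t expert = token_idx / num_max_tokens_per_expert;
  if (slot < recv_expert_count[expert]) return 0;  // valid token, no skip
  const int64_t next_start = (expert + 1) * num_max_tokens_per_expert;
  // Largest k with token_idx + k * grid_stride < next_start.
  return (next_start - token_idx - 1) / grid_stride;
}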
+template <typename T, typename OutT, int VecSize, int Kthread>
+__global__ void quantize_moe_input_kernel(const T* permuted_inputs,
+                                          const int64_t* expert_idx_per_token,
+                                          const float* quant_scales,
+                                          const float quant_max_bound,
+                                          const float quant_min_bound,
+                                          const int64_t token_num,
+                                          const int64_t dim,
+                                          float* permuted_input_row_sum,
+                                          const int64_t* recv_expert_count,
+                                          const int num_max_tokens_per_expert,
+                                          OutT* out) {
+    using LoadT = AlignedVector<T, VecSize>;
+    using LoadOutT = AlignedVector<OutT, VecSize>;
+    LoadT input_vec;
+    LoadOutT output_vec;
+    using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type;
+    float scale_factor = -7.0f / 512.0f;
+    for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
+        int64_t expert_idx = expert_idx_per_token[token_idx];
+        float quant_scale = quant_scales[expert_idx];
+        float thread_row_sum = 0.0f;
+        for (int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
+            int64_t offset = token_idx * dim + idx * VecSize;
+            Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
+#pragma unroll
+            for (int i = 0; i < VecSize; i++) {
+                output_vec[i] = QuantHelperFunc<T, OutT>(input_vec[i], quant_scale, quant_max_bound, quant_min_bound);
+                thread_row_sum += static_cast<float>(output_vec[i]);
+            }
+            *(reinterpret_cast<vec_t*>(&out[offset])) = *(reinterpret_cast<const vec_t*>(&output_vec));
+        }
+        float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
+        permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
+    }
+}
+
+template <typename T, typename OutT>
+void quantize_moe_input(
+    const T* permuted_inputs,
+    const int64_t* expert_idx_per_token,
+    const float* quant_scales,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const int64_t token_num,
+    const int64_t dim,
+    float* permuted_input_row_sum,
+    const int64_t* recv_expert_count,
+    const int num_max_tokens_per_expert,
+    bool used_in_ep_low_latency,
+    OutT* out,
+    cudaStream_t stream) {
+    constexpr int VecSize = 16 / sizeof(T);
+    constexpr int threads_per_block = 128;
+    const int dev_id = 0;
+    int sm_count;
+    int act_blocks_per_sm;
+    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
+    auto kernel = used_in_ep_low_latency ? masked_quantize_moe_input_kernel<T, OutT, VecSize, threads_per_block> : quantize_moe_input_kernel<T, OutT, VecSize, threads_per_block>;
+    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &act_blocks_per_sm, kernel, threads_per_block, 0);
+    const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
+    dim3 grid;
+    grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
+    kernel<<<grid, threads_per_block, 0, stream>>>(
+        permuted_inputs,
+        expert_idx_per_token,
+        quant_scales,
+        quant_max_bound,
+        quant_min_bound,
+        token_num,
+        dim,
+        permuted_input_row_sum,
+        recv_expert_count,
+        num_max_tokens_per_expert,
+        out);
+}
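Both this wrapper and compute_row_sum below size the grid the same way: query how many blocks of threads_per_block can be resident per SM, multiply by the SM count to get one full wave, and cap by the token count so the grid-stride loops never launch idle blocks. A minimal standalone sketch of the pattern (assuming device 0, as the wrappers do):

#include <algorithm>
#include <cuda_runtime.h>

// Occupancy-bounded grid sizing for a grid-stride kernel.
template <typename KernelT>
int one_wave_grid_size(KernelT kernel, int threads_per_block, long long work_items) {
  int sm_count = 0, blocks_per_sm = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, /*device=*/0);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &blocks_per_sm, kernel, threads_per_block, /*dynamicSmemBytes=*/0);
  const long long wave = 1LL * sm_count * blocks_per_sm;  // one full wave
  return static_cast<int>(std::min(wave, work_items));    // never exceed the work
}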
+template <typename T, int VecSize, int Kthread>
+__global__ void masked_compute_row_sum_kernel(
+    const T* permuted_inputs,
+    const int64_t token_num,
+    const int64_t dim,
+    float* permuted_input_row_sum,
+    const int64_t* recv_expert_count,
+    const int num_max_tokens_per_expert) {
+    using LoadT = AlignedVector<T, VecSize>;
+    LoadT input_vec;
+    float scale_factor = -7.0f / 512.0f;
+    for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
+        const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert;
+        const auto expert_id = token_idx / num_max_tokens_per_expert;
+        if (token_idx_in_expert >= recv_expert_count[expert_id]) {
+            auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
+            auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x;
+            token_idx += num_iters_to_next_expert * gridDim.x;
+            continue;
+        }
+        float thread_row_sum = 0.0f;
+        for (int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
+            int64_t offset = token_idx * dim + idx * VecSize;
+            Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
+#pragma unroll
+            for (int i = 0; i < VecSize; i++) {
+                thread_row_sum += static_cast<float>(input_vec[i]);
+            }
+        }
+        float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
+        permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
+    }
+}
+
+template <typename T, int VecSize, int Kthread>
+__global__ void compute_row_sum_kernel(
+    const T* permuted_inputs,
+    const int64_t token_num,
+    const int64_t dim,
+    float* permuted_input_row_sum,
+    const int64_t* recv_expert_count,
+    const int num_max_tokens_per_expert) {
+    using LoadT = AlignedVector<T, VecSize>;
+    LoadT input_vec;
+    float scale_factor = -7.0f / 512.0f;
+    for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
+        float thread_row_sum = 0.0f;
+        for (int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
+            int64_t offset = token_idx * dim + idx * VecSize;
+            Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
+#pragma unroll
+            for (int i = 0; i < VecSize; i++) {
+                thread_row_sum += static_cast<float>(input_vec[i]);
+            }
+        }
+        float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
+        permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
+    }
+}
+
+template <typename T>
+void compute_row_sum(
+    const T* permuted_inputs,
+    const int64_t token_num,
+    const int64_t dim,
+    float* permuted_input_row_sum,
+    const int64_t* recv_expert_count,
+    const int num_max_tokens_per_expert,
+    bool used_in_ep_low_latency,
+    cudaStream_t stream) {
+    constexpr int VecSize = 16 / sizeof(T);
+    constexpr int threads_per_block = 128;
+    const int dev_id = 0;
+    int sm_count;
+    int act_blocks_per_sm;
+    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
+    auto kernel = used_in_ep_low_latency ? masked_compute_row_sum_kernel<T, VecSize, threads_per_block> : compute_row_sum_kernel<T, VecSize, threads_per_block>;
+    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &act_blocks_per_sm, kernel, threads_per_block, 0);
+    const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
+    dim3 grid;
+    grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
+    kernel<<<grid, threads_per_block, 0, stream>>>(
+        permuted_inputs,
+        token_num,
+        dim,
+        permuted_input_row_sum,
+        recv_expert_count,
+        num_max_tokens_per_expert);
+}
+
 // ====================== Softmax things ===============================
 // We have our own implementation of softmax here so we can support transposing
 // the output in the softmax kernel when we extend this module to support
@@ -20,6 +20,7 @@
 #include "helper.h"
 #include "moe/fast_hardamard_kernel.h"
 #include "moe/fused_moe_helper.h"
+#include "w4afp8_gemm/w4afp8_gemm.h"

 template <paddle::DataType T>
 void MoeFFNKernel(const paddle::Tensor& permute_input,

@@ -33,7 +34,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
     const paddle::optional<paddle::Tensor>& expert_idx_per_token,
     const std::string& quant_method,
     paddle::Tensor ffn_out,
-    bool used_in_ep_low_latency) {
+    bool used_in_ep_low_latency,
+    const int estimate_total_token_nums) {
   using namespace phi;
   typedef PDTraits<T> traits_;
   typedef typename traits_::DataType DataType_;

@@ -60,19 +62,22 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
   constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024; // for nf4 stream-k
   Allocator* allocator = paddle::GetAllocator(place);
   Allocator::AllocationPtr workspace;
-  if (quant_method == "weight_only_int4" || quant_method == "w4a8") {
+  if (quant_method == "weight_only_int4" || quant_method == "w4a8" || quant_method == "w4afp8") {
     inter_dim = inter_dim * 2;
   }
-  if (quant_method == "w4a8") {
+  if (quant_method == "w4a8" || quant_method == "w4afp8") {
     workspace = allocator->Allocate(
         SizeOf(paddle::DataType::INT8) * workspace_size);
   }

   const int64_t inter_size = inter_dim;

+  typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
+  typedef typename traits_fp8::DataType DataType_fp8;
+  typedef typename traits_fp8::data_t data_t_fp8;
   int num_experts_ = num_experts;
-  int num_max_tokens_per_expert;
+  int num_max_tokens_per_expert = 256;
   int expanded_active_expert_rows;

   paddle::Tensor fc1_out_tensor;

@@ -161,13 +166,49 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         reinterpret_cast<NvType *>(fc1_out),
         const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
         total_rows_in_ll_else_minus1,
-        tune_total_rows,
+        used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
         inter_size,
         hidden_size,
         reinterpret_cast<char*>(workspace->ptr()),
         workspace_size,
         num_experts,
         stream);
+  } else if (quant_method == "w4afp8") {
+    typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
+    typedef typename traits_fp8::DataType DataType_fp8;
+    typedef typename traits_fp8::data_t data_t_fp8;
+
+    Allocator::AllocationPtr ffn1_input_row_sum;
+    ffn1_input_row_sum = allocator->Allocate(
+        sizeof(float) * expanded_active_expert_rows);
+
+    compute_row_sum(
+        permute_input.data<data_t_fp8>(),
+        expanded_active_expert_rows,
+        hidden_size,
+        reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
+        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
+        num_max_tokens_per_expert,
+        used_in_ep_low_latency,
+        stream);
+
+    float* row_scale = nullptr;
+    DisPatchW4AFp8GemmWrapper(
+        reinterpret_cast<const DataType_fp8 *>(permute_input.data<data_t_fp8>()),
+        reinterpret_cast<const DataType_fp8 *>(up_gate_proj_weight.data<int8_t>()),
+        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
+        reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
+        row_scale,
+        const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
+            ->data<float>(),
+        reinterpret_cast<NvType *>(fc1_out),
+        used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
+        num_max_tokens_per_expert,
+        num_experts,
+        inter_size,
+        hidden_size,
+        stream);
   } else {
     typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
     fp16_moe_gemm_runner.moe_gemm_bias_act(
@@ -194,7 +235,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
     act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
   }
   auto act_out = act_out_tensor.data<data_t>();
-
   if (quant_method == "weight_only_int8") {
     typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
     int8_moe_gemm_runner.moe_gemm(
@@ -267,13 +307,73 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         reinterpret_cast<NvType *>(ffn_out_data),
         const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
         total_rows_in_ll_else_minus1,
-        tune_total_rows,
+        used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
         hidden_size,
         inter_size / 2,
         reinterpret_cast<char*>(workspace->ptr()),
         workspace_size,
         num_experts,
         stream);
+  } else if (quant_method == "w4afp8") {
+    data_t *ffn2_shift = nullptr;
+    data_t *ffn2_smooth = nullptr;
+    float* row_scale = nullptr;
+    Allocator::AllocationPtr fp8_act_out;
+    fp8_act_out = allocator->Allocate(
+        SizeOf(paddle::DataType::INT8) * act_out_tensor.numel());
+    Allocator::AllocationPtr ffn2_input_row_sum;
+    ffn2_input_row_sum = allocator->Allocate(
+        sizeof(float) * expanded_active_expert_rows);
+
+    // note(yuanxiaolan): optimize this
+    MoeFastHardamardWrapper<data_t, data_t>(
+        act_out_tensor.data<data_t>(),
+        expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
+        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
+        ffn2_shift,   // ffn2_shift->data<T>(),
+        ffn2_smooth,  // ffn2_smooth->data<T>(),
+        nullptr,
+        1,
+        448.0f,
+        -448.0f,
+        expanded_active_expert_rows,
+        inter_size / 2,
+        num_max_tokens_per_expert,
+        used_in_ep_low_latency,
+        act_out_tensor.data<data_t>(),
+        stream
+    );
+
+    quantize_moe_input<data_t, data_t_fp8>(act_out_tensor.data<data_t>(),
+        expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
+        down_proj_in_scale ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())->data<float>() : nullptr,
+        448.0f,
+        -448.0f,
+        expanded_active_expert_rows,
+        inter_size / 2,
+        reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
+        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
+        num_max_tokens_per_expert,
+        used_in_ep_low_latency,
+        reinterpret_cast<data_t_fp8 *>(fp8_act_out->ptr()),
+        stream
+    );
+
+    DisPatchW4AFp8GemmWrapper(
+        reinterpret_cast<const DataType_fp8 *>(fp8_act_out->ptr()),
+        reinterpret_cast<const DataType_fp8 *>(down_proj_weight.data<int8_t>()),
+        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
+        reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
+        row_scale,
+        const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
+            ->data<float>(),
+        reinterpret_cast<NvType*>(ffn_out_data),
+        used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
+        num_max_tokens_per_expert,
+        num_experts,
+        hidden_size,
+        inter_size / 2,
+        stream);
   } else {
     typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
     fp16_moe_gemm_runner.moe_gemm(
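Taken together, the new w4afp8 down-projection branch runs three stages between swiglu and the second grouped GEMM. First, MoeFastHardamardWrapper applies a fast Hadamard transform to the activation in place, a transform commonly used to flatten outliers before low-bit quantization. Second, quantize_moe_input casts the activation to fp8 e4m3 with per-expert scales, saturating to +-448, and records a per-row sum scaled by -7/512; the sum is handed to the GEMM, which suggests an offset correction for the int4 weight encoding, though the diff itself does not spell out the derivation of that constant. Finally, DisPatchW4AFp8GemmWrapper performs the grouped int4-weight by fp8-activation GEMM into ffn_out using the row sums and the per-channel down_proj_scale.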
@@ -302,10 +402,12 @@ paddle::Tensor MoeExpertFFNFunc(
    const paddle::optional<paddle::Tensor>& down_proj_scale,
    const paddle::optional<paddle::Tensor>& down_proj_in_scale,
    const paddle::optional<paddle::Tensor>& expert_idx_per_token,
    const std::string& quant_method, const bool used_in_ep_low_latency) {
    const std::string& quant_method, const bool used_in_ep_low_latency,
    const int estimate_total_token_nums) {

    cudaCheckError();
    const auto t_type = quant_method == "w4a8" ? up_gate_proj_scale.get().dtype() : permute_input.dtype();
    const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() :
                        (quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
                        permute_input.dtype();
    auto ffn_out = paddle::empty_like(permute_input, t_type);

    switch (t_type) {
@@ -320,7 +422,9 @@ paddle::Tensor MoeExpertFFNFunc(
            down_proj_in_scale,
            expert_idx_per_token,
            quant_method,
            ffn_out, used_in_ep_low_latency);
            ffn_out,
            used_in_ep_low_latency,
            estimate_total_token_nums);
            break;
        case paddle::DataType::FLOAT16:
            MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
@@ -333,7 +437,9 @@ paddle::Tensor MoeExpertFFNFunc(
            down_proj_in_scale,
            expert_idx_per_token,
            quant_method,
            ffn_out, used_in_ep_low_latency);
            ffn_out,
            used_in_ep_low_latency,
            estimate_total_token_nums);
            break;
        default:
            PD_THROW("Unsupported data type for MoeExpertFFN");
@@ -351,7 +457,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
    const paddle::optional<paddle::Tensor>& down_proj_scale,
    const paddle::optional<paddle::Tensor>& down_proj_in_scale,
    const paddle::optional<paddle::Tensor>& expert_idx_per_token,
    const std::string& quant_method, const bool used_in_ep_low_latency) {
    const std::string& quant_method, const bool used_in_ep_low_latency,
    const int estimate_total_token_nums) {
    return {MoeExpertFFNFunc(permute_input,
                             tokens_expert_prefix_sum,
                             up_gate_proj_weight,
@@ -361,7 +468,9 @@ std::vector<paddle::Tensor> MoeExpertFFN(
                             down_proj_scale,
                             down_proj_in_scale,
                             expert_idx_per_token,
                             quant_method, used_in_ep_low_latency)};
                             quant_method,
                             used_in_ep_low_latency,
                             estimate_total_token_nums)};
}

std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
@@ -375,7 +484,8 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
    const paddle::optional<std::vector<int64_t>>& down_proj_in_scale_shape,
    const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
    const std::string& quant_method,
    const bool used_in_ep_low_latency) {
    const bool used_in_ep_low_latency,
    const int estimate_total_token_nums) {
    return {permute_input_shape};
}

@@ -388,8 +498,9 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
    const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
    const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
    const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
    const std::string &quant_method, const bool used_in_ep_low_latency) {
    if (quant_method == "w4a8") {
    const std::string &quant_method, const bool used_in_ep_low_latency,
    const int estimate_total_token_nums) {
    if (quant_method == "w4a8" || quant_method == "w4afp8") {
        return {up_gate_proj_scale_dtype.get()};
    } else {
        return {permute_input_dtype};
@@ -460,7 +571,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
               paddle::Optional("down_proj_in_scale"),
               paddle::Optional("expert_idx_per_token")})
    .Outputs({"output_tensor"})
    .Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool"})
    .Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int"})
    .SetKernelFn(PD_KERNEL(MoeExpertFFN))
    .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
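
The new estimate_total_token_nums attribute has to be threaded through every binding above: the .Attrs() declaration, the kernel function, and both infer functions must all accept it. A hypothetical minimal op showing the same plumbing pattern (the op name and attribute here are made up for illustration, not part of FastDeploy):

#include "paddle/extension.h"

// Every function bound via PD_KERNEL / PD_INFER_SHAPE / PD_INFER_DTYPE
// must accept the attributes declared in .Attrs(), in the same order.
std::vector<paddle::Tensor> Copy(const paddle::Tensor& x, const int demo_attr) {
    return {x.copy_to(x.place(), false)};
}
std::vector<std::vector<int64_t>> CopyInferShape(
    const std::vector<int64_t>& x_shape, const int demo_attr) {
    return {x_shape};
}
std::vector<paddle::DataType> CopyInferDtype(
    const paddle::DataType& x_dtype, const int demo_attr) {
    return {x_dtype};
}

PD_BUILD_OP(copy_with_attr)
    .Inputs({"x"})
    .Outputs({"out"})
    .Attrs({"demo_attr: int"})
    .SetKernelFn(PD_KERNEL(Copy))
    .SetInferShapeFn(PD_INFER_SHAPE(CopyInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(CopyInferDtype));
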
@@ -26,8 +26,10 @@ __global__ void process_splitwise_prefill(
    int64_t* step_idx,
    bool* not_need_stop,
    bool* batch_drop,
    int64_t* pre_ids,
    const int64_t* accept_tokens,
    const int* accept_num,
    const int* base_model_seq_lens_this_time,
    const int* base_model_seq_lens_encoder,
    const int* base_model_seq_lens_decoder,
    const int64_t* base_model_step_idx,
@@ -35,11 +37,12 @@ __global__ void process_splitwise_prefill(
    const bool* base_model_is_block_step,
    int64_t* base_model_draft_tokens,
    const int bsz,
    const int max_draft_token,
    const int num_model_step,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len) {
    const int base_model_draft_tokens_len,
    const int pre_ids_len) {
    typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    int64_t not_stop_flag = 0;
@@ -92,8 +95,10 @@ __global__ void draft_model_preprocess_kernel(
    int64_t* step_idx,
    bool* not_need_stop,
    bool* batch_drop,
    int64_t* pre_ids,
    const int64_t* accept_tokens,
    const int* accept_num,
    const int* base_model_seq_lens_this_time,
    const int* base_model_seq_lens_encoder,
    const int* base_model_seq_lens_decoder,
    const int64_t* base_model_step_idx,
@@ -101,11 +106,12 @@ __global__ void draft_model_preprocess_kernel(
    const bool* base_model_is_block_step,
    int64_t* base_model_draft_tokens,
    const int bsz,
    const int max_draft_token,
    const int num_model_step,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len) {
    const int base_model_draft_tokens_len,
    const int pre_ids_len) {
    typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    int64_t not_stop_flag = 0;
@@ -113,13 +119,16 @@ __global__ void draft_model_preprocess_kernel(
    int tid = threadIdx.x;

    if (tid < bsz) {
        auto base_model_step_idx_now = base_model_step_idx[tid];
        const int32_t base_model_step_idx_now = base_model_step_idx[tid];
        auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
        auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
        auto accept_num_now = accept_num[tid];
        const int32_t accept_num_now = accept_num[tid];
        auto* input_ids_now = input_ids + tid * input_ids_len;
        auto* base_model_draft_tokens_now =
            base_model_draft_tokens + tid * base_model_draft_tokens_len;
        auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
        const int32_t base_model_seq_len_this_time = base_model_seq_lens_this_time[tid];
        auto* pre_ids_now = pre_ids + tid * pre_ids_len;
#pragma unroll
        for (int i = 1; i < base_model_draft_tokens_len; i++) {
            base_model_draft_tokens_now[i] = -1;
@@ -133,14 +142,12 @@ __global__ void draft_model_preprocess_kernel(
        if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
            not_stop_flag = 1;
            // 1. first token
            if (base_model_step_idx_now == 0) {
                seq_lens_this_time[tid] = 0;
                not_stop_flag = 0;
            } else if (seq_lens_encoder[tid] > 0) {
            if (seq_lens_encoder[tid] > 0) {
                // Can be extended to first few tokens
                int seq_len_encoder = seq_lens_encoder[tid];
                stop_flags[tid] = false;
                int64_t base_model_first_token = accept_tokens_now[0];
                pre_ids_now[0] = base_model_first_token;
                int position = seq_len_encoder;
                if (TRCUNCATE_FIRST_TOKEN) {
                    input_ids_now[position - 1] = base_model_first_token;
@@ -149,24 +156,24 @@ __global__ void draft_model_preprocess_kernel(
                    input_ids_now[position] = base_model_first_token;
                    seq_lens_this_time[tid] = seq_len_encoder + 1;
                }
            } else if (accept_num_now <=
                       max_draft_token) /*Accept partial draft tokens*/ {
                // Base Model reject stop
            } else {
                if (stop_flags[tid]) {
                    stop_flags[tid] = false;
                    seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
                    step_idx[tid] = base_model_step_idx[tid];
                    // TODO: check
                    seq_lens_decoder[tid] = base_model_seq_len_decoder - base_model_seq_len_this_time;
                    step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time;
                } else {
                    seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
                    step_idx[tid] -= max_draft_token - accept_num_now;
                    // 2: Last base model generated token and first MTP token
                    seq_lens_decoder[tid] -= num_model_step - 1;
                    step_idx[tid] -= num_model_step - 1;
                }
                int64_t modified_token = accept_tokens_now[accept_num_now - 1];
                draft_tokens_now[0] = modified_token;
                seq_lens_this_time[tid] = 1;

            } else /*Accept all draft tokens*/ {
                draft_tokens_now[1] = accept_tokens_now[max_draft_token];
                seq_lens_this_time[tid] = 2;
                for (int i = 0; i < accept_num_now; i++) {
                    draft_tokens_now[i] = accept_tokens_now[i];
                    const int pre_id_pos = base_model_step_idx[tid] - (accept_num_now - i);
                    const int64_t accept_token = accept_tokens_now[i];
                    pre_ids_now[pre_id_pos] = accept_token;
                }
                seq_lens_this_time[tid] = accept_num_now;
            }
        } else {
            stop_flags[tid] = true;
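
The rewritten reject branch above distinguishes two rollbacks: if the base model had already stopped, the draft state is rewound past the whole last verify step; otherwise only the speculative advance of num_model_step - 1 positions is undone. A scalar sketch of that bookkeeping (names mirror the kernel; this is an illustration, not the kernel itself):

// Illustrative scalar model of the accept/reject rollback above.
struct DraftState {
    int seq_len_decoder;
    int step_idx;
};

DraftState rollback(DraftState s, bool stopped,
                    int base_seq_len_decoder, int base_seq_len_this_time,
                    int base_step_idx, int num_model_step) {
    if (stopped) {
        // Rewind to the state before the base model's last verify step.
        s.seq_len_decoder = base_seq_len_decoder - base_seq_len_this_time;
        s.step_idx = base_step_idx - base_seq_len_this_time;
    } else {
        // Undo the speculative advance of num_model_step - 1 positions.
        s.seq_len_decoder -= num_model_step - 1;
        s.step_idx -= num_model_step - 1;
    }
    return s;
}
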
@@ -194,8 +201,10 @@ void DispatchRunner(
    int64_t* step_idx,
    bool* not_need_stop,
    bool* batch_drop,
    int64_t* pre_ids,
    const int64_t* accept_tokens,
    const int* accept_num,
    const int* base_model_seq_lens_this_time,
    const int* base_model_seq_lens_encoder,
    const int* base_model_seq_lens_decoder,
    const int64_t* base_model_step_idx,
@@ -203,11 +212,12 @@ void DispatchRunner(
    const bool* base_model_is_block_step,
    int64_t* base_model_draft_tokens,
    const int bsz,
    const int max_draft_token,
    const int num_model_step,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len,
    const int pre_ids_len,
    const bool splitwise_prefill) {
    constexpr int BlockSize = 512;
    if (splitwise_prefill) {
@@ -222,8 +232,10 @@ void DispatchRunner(
            step_idx,
            not_need_stop,
            batch_drop,
            pre_ids,
            accept_tokens,
            accept_num,
            base_model_seq_lens_this_time,
            base_model_seq_lens_encoder,
            base_model_seq_lens_decoder,
            base_model_step_idx,
@@ -231,11 +243,12 @@ void DispatchRunner(
            base_model_is_block_step,
            base_model_draft_tokens,
            bsz,
            max_draft_token,
            num_model_step,
            accept_tokens_len,
            draft_tokens_len,
            input_ids_len,
            base_model_draft_tokens_len);
            base_model_draft_tokens_len,
            pre_ids_len);
    } else {
        draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
            <<<1, BlockSize, 0, stream>>>(
@@ -248,8 +261,10 @@ void DispatchRunner(
            step_idx,
            not_need_stop,
            batch_drop,
            pre_ids,
            accept_tokens,
            accept_num,
            base_model_seq_lens_this_time,
            base_model_seq_lens_encoder,
            base_model_seq_lens_decoder,
            base_model_step_idx,
@@ -257,11 +272,12 @@ void DispatchRunner(
            base_model_is_block_step,
            base_model_draft_tokens,
            bsz,
            max_draft_token,
            num_model_step,
            accept_tokens_len,
            draft_tokens_len,
            input_ids_len,
            base_model_draft_tokens_len);
            base_model_draft_tokens_len,
            pre_ids_len);
    }
}

@@ -276,8 +292,10 @@ void DispatchTokenMode(
    int64_t* step_idx,
    bool* not_need_stop,
    bool* batch_drop,
    int64_t* pre_ids,
    const int64_t* accept_tokens,
    const int* accept_num,
    const int* base_model_seq_lens_this_time,
    const int* base_model_seq_lens_encoder,
    const int* base_model_seq_lens_decoder,
    const int64_t* base_model_step_idx,
@@ -285,11 +303,12 @@ void DispatchTokenMode(
    const bool* base_model_is_block_step,
    int64_t* base_model_draft_tokens,
    const int bsz,
    const int max_draft_token,
    const int num_model_step,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len,
    const int pre_ids_len,
    const bool truncate_first_token,
    const bool splitwise_prefill) {
    if (truncate_first_token) {
@@ -304,8 +323,10 @@ void DispatchTokenMode(
            step_idx,
            not_need_stop,
            batch_drop,
            pre_ids,
            accept_tokens,
            accept_num,
            base_model_seq_lens_this_time,
            base_model_seq_lens_encoder,
            base_model_seq_lens_decoder,
            base_model_step_idx,
@@ -313,11 +334,12 @@ void DispatchTokenMode(
            base_model_is_block_step,
            base_model_draft_tokens,
            bsz,
            max_draft_token,
            num_model_step,
            accept_tokens_len,
            draft_tokens_len,
            input_ids_len,
            base_model_draft_tokens_len,
            pre_ids_len,
            splitwise_prefill
        );
    } else {
@@ -332,8 +354,10 @@ void DispatchTokenMode(
            step_idx,
            not_need_stop,
            batch_drop,
            pre_ids,
            accept_tokens,
            accept_num,
            base_model_seq_lens_this_time,
            base_model_seq_lens_encoder,
            base_model_seq_lens_decoder,
            base_model_step_idx,
@@ -341,11 +365,12 @@ void DispatchTokenMode(
            base_model_is_block_step,
            base_model_draft_tokens,
            bsz,
            max_draft_token,
            num_model_step,
            accept_tokens_len,
            draft_tokens_len,
            input_ids_len,
            base_model_draft_tokens_len,
            pre_ids_len,
            splitwise_prefill
        );
    }
@@ -363,21 +388,24 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                          const paddle::Tensor& step_idx,
                          const paddle::Tensor& not_need_stop,
                          const paddle::Tensor& batch_drop,
                          const paddle::Tensor& pre_ids,
                          const paddle::Tensor& accept_tokens,
                          const paddle::Tensor& accept_num,
                          const paddle::Tensor& base_model_seq_lens_this_time,
                          const paddle::Tensor& base_model_seq_lens_encoder,
                          const paddle::Tensor& base_model_seq_lens_decoder,
                          const paddle::Tensor& base_model_step_idx,
                          const paddle::Tensor& base_model_stop_flags,
                          const paddle::Tensor& base_model_is_block_step,
                          const paddle::Tensor& base_model_draft_tokens,
                          const int max_draft_token,
                          const int num_model_step,
                          const bool truncate_first_token,
                          const bool splitwise_prefill) {
    int real_bsz = seq_lens_this_time.shape()[0];
    int accept_tokens_len = accept_tokens.shape()[1];
    int input_ids_len = input_ids.shape()[1];
    int draft_tokens_len = draft_tokens.shape()[1];
    int pre_ids_len = pre_ids.shape()[1];
    auto cu_stream = seq_lens_this_time.stream();
    constexpr int BlockSize = 512;
    int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
@@ -395,8 +423,10 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
        const_cast<int64_t*>(step_idx.data<int64_t>()),
        const_cast<bool*>(not_need_stop_gpu.data<bool>()),
        const_cast<bool*>(batch_drop.data<bool>()),
        const_cast<int64_t*>(pre_ids.data<int64_t>()),
        accept_tokens.data<int64_t>(),
        accept_num.data<int>(),
        base_model_seq_lens_this_time.data<int>(),
        base_model_seq_lens_encoder.data<int>(),
        base_model_seq_lens_decoder.data<int>(),
        base_model_step_idx.data<int64_t>(),
@@ -404,11 +434,12 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
        base_model_is_block_step.data<bool>(),
        const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
        real_bsz,
        max_draft_token,
        num_model_step,
        accept_tokens_len,
        draft_tokens_len,
        input_ids_len,
        base_model_draft_tokens_len,
        pre_ids_len,
        truncate_first_token,
        splitwise_prefill);

@@ -429,8 +460,10 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
             "step_idx",
             "not_need_stop",
             "batch_drop",
             "pre_ids",
             "accept_tokens",
             "accept_num",
             "base_model_seq_lens_this_time",
             "base_model_seq_lens_encoder",
             "base_model_seq_lens_decoder",
             "base_model_step_idx",
@@ -445,8 +478,9 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
              "seq_lens_decoder_out",
              "step_idx_out",
              "not_need_stop_out",
              "batch_drop_out"})
    .Attrs({"max_draft_token: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
              "batch_drop_out",
              "pre_ids_out"})
    .Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
                    {"input_ids", "input_ids_out"},
                    {"stop_flags", "stop_flags_out"},
@@ -455,5 +489,6 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"step_idx", "step_idx_out"},
                    {"not_need_stop", "not_need_stop_out"},
                    {"batch_drop", "batch_drop_out"}})
                    {"batch_drop", "batch_drop_out"},
                    {"pre_ids", "pre_ids_out"}})
    .SetKernelFn(PD_KERNEL(DraftModelPreprocess));
@@ -63,10 +63,9 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
        token_this_time = next_tokens_start[seq_len_this_time - 1];
        draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
        base_model_draft_tokens_now[substep + 1] = token_this_time;
        for (int i = 0; i < seq_len_this_time; ++i) {
            pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
        }
        step_idx[tid] += seq_len_this_time;
        pre_ids_now[step_idx[tid]] = token_this_time;

    } else {
        token_this_time = next_tokens_start[0];

@@ -49,9 +49,7 @@ __global__ void ComputeOrderKernel(
            for (int j = 0; j < cur_seq_lens_encoder; j++) {
                position_map[in_offset++] = out_offset++;
            }
        // 2. base model encoder. Base step=0
        } else if (cur_base_model_seq_lens_encoder != 0) {
        // 3. New end
        // 2. Base model stop at last verify-step.
        } else if (cur_base_model_seq_lens_this_time != 0 && cur_seq_lens_this_time == 0) {
#ifdef DEBUG_EAGLE_KERNEL
            printf("batch %d: base=0. draft !=0 \n", i);
@@ -61,20 +59,25 @@ __global__ void ComputeOrderKernel(
        // 4. stopped
        } else if (cur_base_model_seq_lens_this_time == 0 && cur_seq_lens_this_time == 0) /* end */ {
        } else {
            if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
                printf("batch %d: accept_num <= actual_draft_token_num \n", i);
#endif
                position_map[in_offset + accept_num - 1] = out_offset++;
                in_offset += cur_base_model_seq_lens_this_time;
            } else /*Accept all draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
                printf("batch %d: accept_num > actual_draft_token_num \n", i);
#endif
                position_map[in_offset + accept_num - 2] = out_offset++;
                position_map[in_offset + accept_num - 1] = out_offset++;
                in_offset += cur_base_model_seq_lens_this_time;
            for (int i = 0; i < accept_num; i++) {
                position_map[in_offset++] = out_offset++;
            }
            in_offset += cur_base_model_seq_lens_this_time - accept_num;
            // (liuzichang): Temporary. Reserved for debugging:
            // if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
            // #ifdef DEBUG_EAGLE_KERNEL
            //     printf("batch %d: accept_num <= actual_draft_token_num \n", i);
            // #endif
            //     position_map[in_offset + accept_num - 1] = out_offset++;
            //     in_offset += cur_base_model_seq_lens_this_time;
            // } else /*Accept all draft tokens*/ {
            // #ifdef DEBUG_EAGLE_KERNEL
            //     printf("batch %d: accept_num > actual_draft_token_num \n", i);
            // #endif
            //     position_map[in_offset + accept_num - 2] = out_offset++;
            //     position_map[in_offset + accept_num - 1] = out_offset++;
            //     in_offset += cur_base_model_seq_lens_this_time;
            // }
        }
    }
    output_token_num[0] = out_offset;
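
After this change the accept path maps every accepted token to a compacted output slot and then skips the unaccepted remainder of the base model's tokens. A host-side toy version of that position_map construction, under assumed sizes:

#include <cstdio>
#include <vector>

// Toy model of the accept branch above: accepted tokens get consecutive
// compacted output slots; the cursor then skips the rejected remainder.
int main() {
    const int cur_base_model_seq_lens_this_time = 4;  // tokens from base model
    const int accept_num = 3;                         // tokens accepted
    std::vector<int> position_map(16, -1);
    int in_offset = 0, out_offset = 0;
    for (int i = 0; i < accept_num; i++) {
        position_map[in_offset++] = out_offset++;
    }
    in_offset += cur_base_model_seq_lens_this_time - accept_num;
    std::printf("consumed %d inputs, produced %d outputs\n", in_offset, out_offset);
    return 0;
}
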
@@ -0,0 +1,214 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

// Inclusive sum of value[0..num]; used to total seq_lens_this_time
// up to and including the current batch index.
int sum_mixed(const int *value, int num) {
    int sum_value = 0;
    for (int i = 0; i <= num; i++) {
        sum_value += value[i];
    }
    return sum_value;
}

void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
                                      const int64_t *input_ids_len,
                                      const int64_t *pre_ids,
                                      const int64_t *step_idx,
                                      const int *draft_token_num,
                                      int64_t *draft_tokens,
                                      int32_t *seq_lens_this_time,
                                      int32_t *seq_lens_decoder,
                                      int64_t *max_dec_len,
                                      int64_t input_ids_stride,
                                      int64_t pre_ids_stride,
                                      int64_t draft_tokens_stride,
                                      int64_t max_batch_size,
                                      int max_ngram_size = 3,
                                      int min_ngram_size = 1,
                                      const int max_draft_tokens = 10) {
    int threshold = 1024;
    // dynamic in future
    char *env_var = getenv("SPEC_TOKENUM_THRESHOLD");
    if (env_var) {
        threshold = std::stoi(env_var);
    }
    int unprocessed_batch_size = 0;
    for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
        if (seq_lens_decoder[batch_idx] > 0) {
            unprocessed_batch_size++;
        }
    }
    for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
        const int ori_seq_len_this_time = seq_lens_this_time[batch_idx];
        int max_draft_tokens_query = std::min(static_cast<int64_t>(
            max_draft_tokens - ori_seq_len_this_time + 1), max_dec_len[batch_idx] - step_idx[batch_idx] - 1);

        if (ori_seq_len_this_time == 0 || max_draft_tokens_query <= 0) {
            continue;
        }

        const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
        int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
        const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
        const int64_t cur_step_idx = step_idx[batch_idx];
        const int64_t cur_input_ids_len = input_ids_len[batch_idx];
        unprocessed_batch_size--;

        auto sum_token_num = sum_mixed(seq_lens_this_time, batch_idx);
        int left_min_token_num = unprocessed_batch_size;

        // Cap the draft length so the whole step stays under the token budget.
        if (sum_token_num + max_draft_tokens_query + left_min_token_num > threshold) {
            int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num;
            max_draft_tokens_query = std::min(max_draft_tokens_query, tmp_max_draft_tokens);
        }

        if (sum_token_num + left_min_token_num >= threshold - 1) {
            continue;
        }
        bool match_global = false;
        // apply ngram_match in input_ids
        for (int ngram_size = max_ngram_size; ngram_size >= min_ngram_size && !match_global; --ngram_size) {
            // Extract the last n tokens as our search ngram
            if (cur_step_idx < ngram_size) {
                continue;
            }
            const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);

            // Iterate through sliding windows of size ngram_size
            // bool match_input = false;
            for (int64_t i = 0; i <= cur_input_ids_len - ngram_size && !match_global; ++i) {
                // Check if the current window matches the ngram
                bool match_local = true;
                for (int j = 0; j < ngram_size; j++) {
                    if (ngram[j] != cur_input_ids[i + j]) {
                        match_local = false;
                        break;
                    }
                }
                if (match_local) {
                    int64_t start_idx = i + ngram_size;
                    int64_t end_idx = std::min(start_idx + max_draft_tokens_query, cur_input_ids_len);
                    if (start_idx >= end_idx)
                        continue;

                    int64_t cur_draft_token_num = end_idx - start_idx;

                    seq_lens_this_time[batch_idx] = ori_seq_len_this_time + cur_draft_token_num;
                    memcpy(cur_draft_tokens + ori_seq_len_this_time, cur_input_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
                    // To break the current batch_idx for-loop
                    match_global = true;
                    break;
                }
            }
            // apply ngram_match in generated tokens
            if (!match_global) {
                for (int64_t i = 0; i <= cur_step_idx - ngram_size && !match_global; ++i) {
                    // Check if the current window matches the ngram
                    bool match_local = true;

                    for (int j = 0; j < ngram_size; j++) {
                        if (ngram[j] != cur_pre_ids[i + j]) {
                            match_local = false;
                            break;
                        }
                    }
                    if (match_local) {
                        int64_t start_idx = i + ngram_size;
                        int64_t end_idx = std::min(start_idx + max_draft_tokens_query, cur_step_idx);

                        int64_t cur_draft_token_num = end_idx - start_idx;

                        if (start_idx >= end_idx)
                            continue;
                        // printf("match in Output with Ngram_size %d. %lld:[%lld,%lld]\n", ngram_size, cur_draft_token_num, start_idx, end_idx);

                        seq_lens_this_time[batch_idx] = ori_seq_len_this_time + cur_draft_token_num;
                        memcpy(cur_draft_tokens + ori_seq_len_this_time, cur_pre_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
                        match_global = true;
                        break;
                    }
                }
            }
        }
    }
}

void HybridMtpNgram(const paddle::Tensor &input_ids,
                    const paddle::Tensor &input_ids_len,
                    const paddle::Tensor &pre_ids,
                    const paddle::Tensor &step_idx,
                    const paddle::Tensor &draft_token_num,
                    const paddle::Tensor &draft_tokens,
                    const paddle::Tensor &seq_lens_this_time,
                    const paddle::Tensor &seq_lens_decoder,
                    const paddle::Tensor &max_dec_len,
                    const int max_ngram_size,
                    const int min_ngram_size,
                    const int max_draft_tokens) {

    auto input_ids_shape = input_ids.shape();
    const int64_t input_ids_stride = input_ids_shape[1];

    auto pre_ids_shape = pre_ids.shape();
    const int64_t pre_ids_stride = pre_ids_shape[1];

    auto draft_tokens_shape = draft_tokens.shape();
    const int64_t draft_tokens_stride = draft_tokens_shape[1];

    const int64_t max_batch_size = seq_lens_this_time.shape()[0];

    find_candidate_pred_tokens_mixed(input_ids.data<int64_t>(),
                                     input_ids_len.data<int64_t>(),
                                     pre_ids.data<int64_t>(),
                                     step_idx.data<int64_t>(),
                                     draft_token_num.data<int>(),
                                     const_cast<int64_t *>(draft_tokens.data<int64_t>()),
                                     const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
                                     const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
                                     const_cast<int64_t *>(max_dec_len.data<int64_t>()),
                                     input_ids_stride,
                                     pre_ids_stride,
                                     draft_tokens_stride,
                                     max_batch_size,
                                     max_ngram_size,
                                     min_ngram_size,
                                     max_draft_tokens);
}

PD_BUILD_STATIC_OP(hybrid_mtp_ngram)
    .Inputs({"input_ids",
             "input_ids_len",
             "pre_ids",
             "step_idx",
             "draft_token_num",
             "draft_tokens",
             "seq_lens_this_time",
             "seq_lens_decoder",
             "max_dec_len"})
    .Attrs({"max_ngram_size: int", "min_ngram_size: int", "max_draft_tokens: int"})
    .Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
    .SetKernelFn(PD_KERNEL(HybridMtpNgram))
    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, {"seq_lens_this_time", "seq_lens_this_time_out"}});
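
find_candidate_pred_tokens_mixed above is an n-gram lookup: the last ngram_size generated tokens are searched for in the prompt (and then in the generated history), and the tokens that followed the first match are proposed as draft tokens. A self-contained toy version of the core match, with made-up token values:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy version of the n-gram lookup: find the window `tail` inside
// `prompt` and propose the tokens after it as drafts (up to max_draft).
std::vector<int64_t> ngram_draft(const std::vector<int64_t>& prompt,
                                 const std::vector<int64_t>& tail,
                                 size_t max_draft) {
    const size_t n = tail.size();
    if (n == 0 || prompt.size() < n) return {};
    for (size_t i = 0; i + n <= prompt.size(); ++i) {
        bool match = true;
        for (size_t j = 0; j < n; ++j) {
            if (prompt[i + j] != tail[j]) { match = false; break; }
        }
        if (match) {
            const size_t start = i + n;
            const size_t end = std::min(prompt.size(), start + max_draft);
            return std::vector<int64_t>(prompt.begin() + start, prompt.begin() + end);
        }
    }
    return {};
}

int main() {
    const std::vector<int64_t> prompt = {5, 7, 9, 11, 13, 7, 9};
    const auto drafts = ngram_draft(prompt, {7, 9}, 2);  // matches at i=1 -> {11, 13}
    std::printf("%zu draft tokens\n", drafts.size());
    return 0;
}

The real op additionally tries ngram sizes from max_ngram_size down to min_ngram_size and budgets the total tokens per step against the SPEC_TOKENUM_THRESHOLD cap.
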
@@ -23,14 +23,7 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6

struct msgdata {
    int64_t mtype;
    int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
              2];  // stop_flag, bsz, accept_num*bsz, tokens...
};
#include "speculate_msg.h"

void SpeculateGetOutput(const paddle::Tensor& x,
                        int64_t rank_id,
@@ -54,7 +47,7 @@ void SpeculateGetOutput(const paddle::Tensor& x,
        msg_queue_id = inference_msg_queue_id_from_env;
    }

    static struct msgdata msg_rcv;
    static struct speculate_msgdata msg_rcv;

    static key_t key = ftok("./", msg_queue_id);

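Both this file and the send side below previously declared their own local msgdata struct; the shared layout now lives in speculate_msg.h, which removes the risk of the two copies drifting apart. A small standalone sketch of the System V message-queue pattern these ops use (the struct here is a cut-down stand-in, not the real speculate_msgdata):

#include <sys/ipc.h>
#include <sys/msg.h>

// Cut-down stand-in for the shared struct in speculate_msg.h; the real
// one sizes mtext from MAX_BSZ and MAX_DRAFT_TOKENS.
struct speculate_msgdata_demo {
    long mtype;
    int mtext[8];
};

int main() {
    key_t key = ftok("./", 42);                    // both sides derive the same key
    int msgid = msgget(key, IPC_CREAT | 0666);     // create or open the queue
    speculate_msgdata_demo msg{1, {0}};
    msgsnd(msgid, &msg, sizeof(msg.mtext), 0);     // producer side
    msgrcv(msgid, &msg, sizeof(msg.mtext), 1, 0);  // consumer side
    return 0;
}
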
@@ -1,69 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

__global__ void SpeculateHydraSetScoreThresholdKernel(
    float* threshold,
    const int* seq_lens_this_time,
    const int* seq_lens_encoder,
    const int* accept_num,
    const int real_bsz,
    const float default_threshold = 0.3,
    const float upper_threshold = 0.8,
    const float lower_threshold = 0.0,
    const float threshold_step = 0.1,
    const float threshold_step_fac = 0.5) {
    for (int bid = threadIdx.x; bid < real_bsz; bid += blockDim.x) {
        if (seq_lens_encoder[bid] > 0) {
            threshold[bid] = default_threshold;
        } else if (seq_lens_this_time[bid] <= 1) {
            continue;
        } else if (accept_num[bid] >= seq_lens_this_time[bid] &&
                   threshold[bid] >
                       lower_threshold + threshold_step * threshold_step_fac) {
            threshold[bid] -= threshold_step * threshold_step_fac;
        } else if (accept_num[bid] < seq_lens_this_time[bid] &&
                   threshold[bid] < upper_threshold - threshold_step) {
            threshold[bid] += threshold_step;
        }
    }
}

void SpeculateHydraSetScoreThreshold(const paddle::Tensor& seq_lens_this_time,
                                     const paddle::Tensor& seq_lens_encoder,
                                     const paddle::Tensor& accept_num,
                                     const paddle::Tensor& threshold) {
    auto cu_stream = seq_lens_this_time.stream();
    std::vector<int64_t> seq_lens_this_time_shape = seq_lens_this_time.shape();
    const int bsz = seq_lens_this_time_shape[0];

    SpeculateHydraSetScoreThresholdKernel<<<1, 256, 0, cu_stream>>>(
        const_cast<float*>(threshold.data<float>()),
        seq_lens_this_time.data<int>(),
        seq_lens_encoder.data<int>(),
        accept_num.data<int>(),
        bsz);
}

PD_BUILD_STATIC_OP(speculate_hydra_set_score_threshold)
    .Inputs(
        {"seq_lens_this_time", "seq_lens_encoder", "accept_num", "threshold"})
    .Outputs({"threshold_out"})
    .SetInplaceMap({{"threshold", "threshold_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateHydraSetScoreThreshold));
@@ -1,68 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

__global__ void hydra_update_this_time(int* seq_lens_this_time,
                                       const int* seq_lens_encoder,
                                       const int* seq_lens_decoder,
                                       const float* topk_scores,
                                       const float* score_threshold,
                                       int real_bsz,
                                       int idx) {
    int linear_idx = threadIdx.x;
    // verify and set stop flags
    for (; linear_idx < real_bsz; linear_idx += blockDim.x) {
        if (seq_lens_encoder[linear_idx] == 0 &&
            seq_lens_decoder[linear_idx] != 0) {
            if (topk_scores[linear_idx] > score_threshold[linear_idx] &&
                seq_lens_this_time[linear_idx] == idx + 1) {
                seq_lens_this_time[linear_idx]++;
            }
        } else if (seq_lens_encoder[linear_idx] == 0 &&
                   seq_lens_decoder[linear_idx] == 0) {
            seq_lens_this_time[linear_idx] = 0;
        }
    }
}

void HydraUpdateThisTime(const paddle::Tensor& seq_lens_this_time,
                         const paddle::Tensor& seq_lens_encoder,
                         const paddle::Tensor& seq_lens_decoder,
                         const paddle::Tensor& topk_scores,
                         const paddle::Tensor& score_threshold,
                         const int real_bsz,
                         const int idx) {
    constexpr int BlockSize = 512;

    hydra_update_this_time<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
        const_cast<int*>(seq_lens_this_time.data<int>()),
        seq_lens_encoder.data<int>(),
        seq_lens_decoder.data<int>(),
        topk_scores.data<float>(),
        score_threshold.data<float>(),
        real_bsz,
        idx);
}

PD_BUILD_STATIC_OP(speculate_hydra_update_seqlens_this_time)
    .Inputs({"seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "topk_scores",
             "score_threshold"})
    .Outputs({"seq_lens_this_time_out"})
    .Attrs({"real_bsz: int", "idx: int"})
    .SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"}})
    .SetKernelFn(PD_KERNEL(HydraUpdateThisTime));
@@ -1,149 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include "helper.h"

template <typename T, int VecSize>
__global__ void RebuildAppendPaddingKernel(
    T *out,
    const T *full_hidden_states,
    const int *cum_offset,
    const int *seq_len_encoder,
    const int *seq_len_decoder,
    const int *output_padding_offset,
    const int seq_len,
    const int dim_embed,
    const size_t elem_nums) {
    using LoadT = AlignedVector<T, VecSize>;
    LoadT src_vec;
    const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
    for (int64_t i = global_idx * VecSize; i < elem_nums; i += gridDim.x * blockDim.x * VecSize) {
        const int out_token_id = i / dim_embed;
        const int ori_token_id = out_token_id + output_padding_offset[out_token_id];
        const int bi = ori_token_id / seq_len;
        int seq_id = 0;

        if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
        else if (seq_len_encoder[bi] != 0) {
            seq_id = seq_len_encoder[bi] - 1;
        }

        const int input_token_id = ori_token_id - cum_offset[bi] + seq_id;
        const int bias_idx = i % dim_embed;

        Load<T, VecSize>(&full_hidden_states[input_token_id * dim_embed + bias_idx], &src_vec);
        Store<T, VecSize>(src_vec, &out[i]);
    }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
    const paddle::Tensor& full_hidden_states,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& seq_len_encoder,
    const paddle::Tensor& seq_len_decoder,
    const paddle::Tensor& output_padding_offset,
    const int max_seq_len) {
    // src: [token_num, dim_embed]
    // dst: [batch_size, 1, dim_embed]

    typedef PDTraits<D> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;

    int dim_embed = full_hidden_states.shape()[1];
    int output_token_num = output_padding_offset.shape()[0];
    int elem_nums = output_token_num * dim_embed;
    constexpr int PackSize = VEC_16B / sizeof(DataType_);
    assert(elem_nums % PackSize == 0);

    auto out = paddle::full({output_token_num, dim_embed}, 0, full_hidden_states.dtype(), full_hidden_states.place());

    int pack_num = elem_nums / PackSize;
    const int threads_per_block = 128;
    int grid_size = 1;
    GetNumBlocks(pack_num, &grid_size);

    RebuildAppendPaddingKernel<DataType_, PackSize><<<grid_size, threads_per_block, 0, full_hidden_states.stream()>>>(
        reinterpret_cast<DataType_*>(out.data<data_t>()),
        reinterpret_cast<const DataType_*>(full_hidden_states.data<data_t>()),
        cum_offsets.data<int32_t>(),
        seq_len_encoder.data<int32_t>(),
        seq_len_decoder.data<int32_t>(),
        output_padding_offset.data<int32_t>(),
        max_seq_len,
        dim_embed,
        elem_nums);
    return {out};
}

std::vector<paddle::Tensor> RebuildAppendPadding(
    const paddle::Tensor& full_hidden_states,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& seq_len_encoder,
    const paddle::Tensor& seq_len_decoder,
    const paddle::Tensor& output_padding_offset,
    const int max_seq_len) {

    switch (full_hidden_states.dtype()) {
        case paddle::DataType::BFLOAT16:
            return DispatchDtype<paddle::DataType::BFLOAT16>(
                full_hidden_states, cum_offsets, seq_len_encoder, seq_len_decoder, output_padding_offset, max_seq_len);
        case paddle::DataType::FLOAT16:
            return DispatchDtype<paddle::DataType::FLOAT16>(
                full_hidden_states, cum_offsets, seq_len_encoder, seq_len_decoder, output_padding_offset, max_seq_len);
        default:
            PD_THROW("Unsupported data type.");
    }
}

std::vector<std::vector<int64_t>> RebuildAppendPaddingInferShape(
    const std::vector<int64_t>& full_hidden_states_shape,
    const std::vector<int64_t>& cum_offsets_shape,
    const std::vector<int64_t>& seq_len_encoder_shape,
    const std::vector<int64_t>& seq_len_decoder_shape,
    const std::vector<int64_t>& output_padding_offset_shape) {
    const int64_t output_token_num = output_padding_offset_shape[0];
    const int64_t dim_embed = full_hidden_states_shape[1];
    std::vector<int64_t> out_shape = {output_token_num, dim_embed};
    return {out_shape};
}

std::vector<paddle::DataType> RebuildAppendPaddingInferDtype(
    const paddle::DataType& full_hidden_states_dtype,
    const paddle::DataType& cum_offsets_dtype,
    const paddle::DataType& seq_len_encoder_dtype,
    const paddle::DataType& seq_len_decoder_dtype,
    const paddle::DataType& output_padding_offset_dtype) {
    return {full_hidden_states_dtype};
}

PD_BUILD_STATIC_OP(speculate_rebuild_append_padding)
    .Inputs({"full_hidden_states",
             "cum_offsets",
             "seq_len_encoder",
             "seq_len_decoder",
             "output_padding_offset"})
    .Attrs({"max_seq_len: int"})
    .Outputs({"out"})
    .SetKernelFn(PD_KERNEL(RebuildAppendPadding))
    .SetInferShapeFn(PD_INFER_SHAPE(RebuildAppendPaddingInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(RebuildAppendPaddingInferDtype));
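
The kernel removed above recovered, for each compacted output token, its position in the padded [batch, seq_len] layout and from there the source row in the un-padded input. A scalar sketch of that index arithmetic, with assumed offset values:

#include <cstdio>

// Scalar model of the index math in RebuildAppendPaddingKernel:
// map a compacted output token back to its padded position, then
// into the un-padded input buffer.
int main() {
    const int seq_len = 8;                // max_seq_len of the batch
    const int output_padding_offset = 5;  // padding skipped before this token
    const int cum_offset = 5;             // padded slots before this batch row
    const int out_token_id = 3;

    const int ori_token_id = out_token_id + output_padding_offset;  // padded pos
    const int bi = ori_token_id / seq_len;                          // batch row
    const int input_token_id = ori_token_id - cum_offset;           // plus seq_id for encoder rows
    std::printf("batch %d reads input token %d\n", bi, input_token_id);
    return 0;
}
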
@@ -23,14 +23,7 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6

struct msgdata {
    long mtype;
    int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
              2];  // stop_flag, bsz, tokens
};
#include "speculate_msg.h"

void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
                                const paddle::Tensor& accept_num,
@@ -62,7 +55,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
#endif
        msg_queue_id = inference_msg_queue_id_from_env;
    }
    static struct msgdata msg_sed;
    static struct speculate_msgdata msg_sed;
    static key_t key = ftok("./", msg_queue_id);
    static int msgid = msgget(key, IPC_CREAT | 0666);

@@ -15,7 +15,7 @@
#include "helper.h"

template <int THREADBLOCK_SIZE>
__global__ void speculate_update_v3(int *seq_lens_encoder,
__global__ void speculate_update(int *seq_lens_encoder,
                                    int *seq_lens_decoder,
                                    bool *not_need_stop,
                                    int64_t *draft_tokens,
@@ -90,7 +90,7 @@ __global__ void speculate_update_v3(int *seq_lens_encoder,
    }
}

void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
                       const paddle::Tensor &seq_lens_decoder,
                       const paddle::Tensor &not_need_stop,
                       const paddle::Tensor &draft_tokens,
@@ -108,7 +108,7 @@ void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
    constexpr int BlockSize = 512;

    auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
    speculate_update_v3<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
    speculate_update<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
        const_cast<int *>(seq_lens_encoder.data<int>()),
        const_cast<int *>(seq_lens_decoder.data<int>()),
        const_cast<bool *>(not_need_stop_gpu.data<bool>()),
@@ -130,7 +130,7 @@ void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
    not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}

PD_BUILD_STATIC_OP(speculate_update_v3)
PD_BUILD_STATIC_OP(speculate_update)
    .Inputs({"seq_lens_encoder",
             "seq_lens_decoder",
             "not_need_stop",
@@ -152,4 +152,4 @@ PD_BUILD_STATIC_OP(speculate_update_v3)
                    {"not_need_stop", "not_need_stop_out"},
                    {"draft_tokens", "draft_tokens_out"},
                    {"actual_draft_token_nums", "actual_draft_token_nums_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateUpdateV3));
    .SetKernelFn(PD_KERNEL(SpeculateUpdate));
@@ -1,55 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"  // NOLINT

__global__ void update_this_time(int* seq_lens_this_time,
                                 const int* seq_lens_encoder,
                                 const int* seq_lens_decoder,
                                 int real_bsz,
                                 int value) {
    int linear_idx = threadIdx.x;
    // verify and set stop flags
    for (; linear_idx < real_bsz; linear_idx += blockDim.x) {
        if (seq_lens_encoder[linear_idx] == 0 &&
            seq_lens_decoder[linear_idx] != 0) {
            seq_lens_this_time[linear_idx] = value;
        } else if (seq_lens_encoder[linear_idx] == 0 &&
                   seq_lens_decoder[linear_idx] == 0) {
            seq_lens_this_time[linear_idx] = 0;
        }
    }
}

void UpdateThisTime(const paddle::Tensor& seq_lens_this_time,
                    const paddle::Tensor& seq_lens_encoder,
                    const paddle::Tensor& seq_lens_decoder,
                    const int real_bsz,
                    const int value) {
    constexpr int BlockSize = 512;

    update_this_time<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
        const_cast<int*>(seq_lens_this_time.data<int>()),
        seq_lens_encoder.data<int>(),
        seq_lens_decoder.data<int>(),
        real_bsz,
        value);
}

PD_BUILD_STATIC_OP(speculate_update_seq_lens_this_time)
    .Inputs({"seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder"})
    .Outputs({"seq_lens_this_time_out"})
    .Attrs({"real_bsz: int", "value: int"})
    .SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"}})
    .SetKernelFn(PD_KERNEL(UpdateThisTime));
@@ -1,146 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"  // NOLINT

template <int THREADBLOCK_SIZE>
__global__ void speculate_update(int *seq_lens_encoder,
                                 int *seq_lens_decoder,
                                 bool *not_need_stop,
                                 int64_t *draft_tokens,
                                 int *actual_draft_token_nums,
                                 const int64_t *accept_tokens,
                                 const int *accept_num,
                                 const bool *stop_flags,
                                 const int *seq_lens_this_time,
                                 const bool *is_block_step,
                                 const int real_bsz,
                                 const int max_draft_tokens) {
    const int bid = threadIdx.x;
    const int accept_num_now = accept_num[bid];
    int stop_flag_now_int = 0;
    if (!(is_block_step[bid] || bid >= real_bsz)) {
        if (stop_flags[bid]) {
            stop_flag_now_int = 1;
        }
        if (seq_lens_encoder[bid] == 0) {
            seq_lens_decoder[bid] += accept_num_now;
        }

        if (seq_lens_this_time[bid] > 1 &&
            seq_lens_encoder[bid] ==
                0) {  // In append mode, decide from the acceptance result
                      // whether to lower the number of draft tokens next step
            auto current_actual_draft_token_num = actual_draft_token_nums[bid];
            if (accept_num_now - 1 == current_actual_draft_token_num) {
                if (current_actual_draft_token_num + 2 <=
                    max_draft_tokens - 1) {
                    actual_draft_token_nums[bid] =
                        current_actual_draft_token_num + 2;
                } else if (current_actual_draft_token_num + 1 <=
                           max_draft_tokens - 1) {
                    actual_draft_token_nums[bid] =
                        current_actual_draft_token_num + 1;
                } else {
                    actual_draft_token_nums[bid] = max_draft_tokens - 1;
                }
            } else {
                actual_draft_token_nums[bid] =
                    actual_draft_token_nums[bid] - 1 >= 1
                        ? actual_draft_token_nums[bid] - 1
                        : 1;
            }
        }

        if (seq_lens_encoder[bid] != 0) {
            seq_lens_decoder[bid] += seq_lens_encoder[bid];
            seq_lens_encoder[bid] = 0;
        }
        draft_tokens[bid * max_draft_tokens] =
            accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
        if (stop_flag_now_int) {
            seq_lens_decoder[bid] = 0;
        }
    }
    __syncthreads();
    typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    // printf("stop_flag_now_int %d \n", stop_flag_now_int);
    int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);

    if (threadIdx.x == 0) {
        // printf("stop_sum %d \n", stop_sum);
        not_need_stop[0] = stop_sum < real_bsz;
    }
}

void SpeculateUpdateV2(const paddle::Tensor &seq_lens_encoder,
                       const paddle::Tensor &seq_lens_decoder,
                       const paddle::Tensor &not_need_stop,
                       const paddle::Tensor &draft_tokens,
                       const paddle::Tensor &actual_draft_token_nums,
                       const paddle::Tensor &accept_tokens,
                       const paddle::Tensor &accept_num,
                       const paddle::Tensor &stop_flags,
                       const paddle::Tensor &seq_lens_this_time,
                       const paddle::Tensor &is_block_step) {
    int real_bsz = seq_lens_this_time.shape()[0];
    auto max_draft_tokens = draft_tokens.shape()[1];

    constexpr int BlockSize = 512;

    auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
    speculate_update<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
        const_cast<int *>(seq_lens_encoder.data<int>()),
        const_cast<int *>(seq_lens_decoder.data<int>()),
        const_cast<bool *>(not_need_stop_gpu.data<bool>()),
        const_cast<int64_t *>(draft_tokens.data<int64_t>()),
        const_cast<int *>(actual_draft_token_nums.data<int>()),
        accept_tokens.data<int64_t>(),
        accept_num.data<int>(),
        stop_flags.data<bool>(),
        seq_lens_this_time.data<int>(),
        is_block_step.data<bool>(),
        real_bsz,
        max_draft_tokens);

    auto not_need_stop_cpu =
        not_need_stop_gpu.copy_to(not_need_stop.place(), true);
    bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
    not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}

PD_BUILD_STATIC_OP(speculate_update_v2)
    .Inputs({"seq_lens_encoder",
             "seq_lens_decoder",
             "not_need_stop",
             "draft_tokens",
             "actual_draft_token_nums",
             "accept_tokens",
             "accept_num",
             "stop_flags",
             "seq_lens_this_time",
             "is_block_step"})
    .Outputs({"seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "not_need_stop_out",
              "draft_tokens_out",
              "actual_draft_token_nums_out"})
    .SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"not_need_stop", "not_need_stop_out"},
                    {"draft_tokens", "draft_tokens_out"},
                    {"actual_draft_token_nums", "actual_draft_token_nums_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateUpdateV2));
custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h (new file, 154 lines)
@@ -0,0 +1,154 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <int kStages, class GemmType, class OutputType, class SmemLayoutA,
|
||||
class SmemLayoutB, class SmemLayoutC>
|
||||
struct SharedStorage {
|
||||
union {
|
||||
struct {
|
||||
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
|
||||
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
|
||||
};
|
||||
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
|
||||
};
|
||||
|
||||
struct {
|
||||
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline;
|
||||
};
|
||||
};
|
||||
|
||||
template<int kBlockM_, int kBlockN_, int kBlockK_,
|
||||
int kNWarps_, int kStages_,
|
||||
int kTiles_, int M_,
|
||||
int TokenPackSize_,
|
||||
int TAIL_N_ = 0,
|
||||
int kClusterM_ = 1,
|
||||
typename elem_type=cutlass::float_e4m3_t,
|
||||
typename OutputType = cutlass::bfloat16_t>
|
||||
struct Kernel_traits {
|
||||
using Element = elem_type;
|
||||
using ElementAccum = float;
|
||||
using ElementOutput = OutputType;
|
||||
static_assert(cutlass::sizeof_bits_v<Element> == 8);
|
||||
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
|
||||
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
|
||||
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
|
||||
|
||||
static_assert(kNWarps_ == 12 || kNWarps_ == 16);
|
||||
|
||||
    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kBlockK = kBlockK_;
    static constexpr int kTiles = kTiles_;
    static constexpr int TokenPackSize = TokenPackSize_;
    static constexpr int M = M_;
    static constexpr int TAIL_N = TAIL_N_;

    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
    using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>;

    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;

    static constexpr int kStages = kStages_;
    static_assert(kStages > 1);

    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;

    using TiledMma = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
        AtomLayoutMNK{}));

    using TiledMma_TAIL = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK_TAIL>(),
        AtomLayoutMNK{}));

    // A holds int4-packed weights (two per byte), so its K extent is kBlockK / 2 bytes.
    using SmemLayoutAtomA = decltype(
        cutlass::gemm::collective::detail::rs_smem_selector<
            GMMA::Major::K, Element, Int<kBlockM>, Int<kBlockK / 2>>());

    using SmemLayoutA = decltype(
        tile_to_shape(SmemLayoutAtomA{},
            make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{})));

    using SmemLayoutAtomB = decltype(
        cutlass::gemm::collective::detail::rs_smem_selector<
            GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
            decltype(cute::get<2>(TileShape_MNK{}))>());

    using SmemLayoutB = decltype(
        tile_to_shape(SmemLayoutAtomB{},
            make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomB_TAIL = decltype(
        cutlass::gemm::collective::detail::rs_smem_selector<
            GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
            decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());

    using SmemLayoutB_TAIL = decltype(
        tile_to_shape(SmemLayoutAtomB_TAIL{},
            make_shape(
                shape<1>(TileShape_MNK_TAIL{}),
                shape<2>(TileShape_MNK_TAIL{}),
                Int<kStages>{})));

    using SmemLayoutAtomC = decltype(
        cutlass::gemm::collective::detail::rs_smem_selector<
            GMMA::Major::K, ElementOutput,
            decltype(cute::get<0>(TileShape_MNK{})),
            decltype(cute::get<1>(TileShape_MNK{}))>());

    using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));

    using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>;

    using SharedStorage = SharedStorage<
        kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC>;

    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;

    static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
    static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
    // static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
    static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
    using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
    using TiledCopyCThrLayout = decltype(cute::make_layout(
        cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
        LayoutRight{}));
    using TiledCopyCValLayout = decltype(cute::make_layout(
        cute::make_shape(_1{}, Int<kNumVecElem>{}),
        LayoutRight{}));
    using TiledCopyC = decltype(make_tiled_copy(
        TiledCopyCAtom{},
        TiledCopyCThrLayout{},  // Thr layout
        TiledCopyCValLayout{}   // Val layout
    ));
};
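The epilogue copy geometry above follows from the 128-bit store width. A quick standalone check (not part of the patch; NumMmaThreads = 256 is an assumption matching the two consumer warp groups implied by the kNWarps = 4 + kBlockM / 16 setting used by the generated kernels later in this diff):

    // Standalone sketch: epilogue copy geometry for a bf16 output tile.
    #include <cstdio>

    int main() {
        constexpr int kBlockM = 128, kBlockN = 256;
        constexpr int NumMmaThreads = 256;               // assumption: two MMA warp groups
        constexpr int kOutputBits = 16;                  // bf16
        constexpr int kNumVecElem = 128 / kOutputBits;   // 8 elements per 128-bit store
        constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;     // 32
        constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;  // 8
        printf("%d threads per row, %d rows covered per copy pass\n",
               kNumThreadsPerRow, kNumRows);
        return 0;
    }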
405 custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h Normal file
@@ -0,0 +1,405 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/collective_builder.hpp"

// #include "named_barrier.hpp"
#include "utils.hpp"


using namespace cute;

template <typename Ktraits>
struct CollectiveMainloopFwd {

    using Element = typename Ktraits::Element;
    using ElementOutput = typename Ktraits::ElementOutput;
    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;
    using ElementAccum = typename Ktraits::ElementAccum;

    static constexpr int kStages = Ktraits::kStages;
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockN = Ktraits::kBlockN;
    static constexpr int TAIL_N = Ktraits::TAIL_N;
    static constexpr int kBlockK = Ktraits::kBlockK;
    static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
    static constexpr int kTiles = Ktraits::kTiles;
    static constexpr int M = Ktraits::M;
    static constexpr int TokenPackSize = Ktraits::TokenPackSize;

    using GmemTiledCopy = cute::SM90_TMA_LOAD;

    using SmemLayoutA = typename Ktraits::SmemLayoutA;
    using SmemLayoutB = typename Ktraits::SmemLayoutB;
    using SmemLayoutC = typename Ktraits::SmemLayoutC;
    using SmemLayoutB_TAIL = typename Ktraits::SmemLayoutB_TAIL;

    using ShapeT = cute::Shape<int64_t, int64_t, int64_t>;
    using StrideT = cute::Shape<int64_t, _1, int64_t>;
    using LayoutT = cute::Layout<ShapeT, StrideT>;

    using TMA_A = decltype(make_tma_copy(
        GmemTiledCopy{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            ShapeT{},
            StrideT{}
        ),
        SmemLayoutA{}(_, _, _0{}),
        select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
        size<0>(ClusterShape{})));

    using TMA_B = decltype(make_tma_copy(
        GmemTiledCopy{},
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)),
            ShapeT{},
            StrideT{}
        ),
        take<0, 2>(SmemLayoutB{}),
        select<1, 2>(TileShape_MNK{}),
        size<0>(ClusterShape{})));

    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;
    using SmemCopyAtomAB = typename Ktraits::SmemCopyAtomAB;
    using SmemCopyAtomC = typename Ktraits::SmemCopyAtomC;
    using TiledCopyC = typename Ktraits::TiledCopyC;

    static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
    static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);

    struct Arguments {
        Element const* ptr_A;
        LayoutT layout_A;
        Element const* ptr_B;
        LayoutT layout_B;
        ElementOutput * ptr_C;
        LayoutT layout_C;
        const float *weight_scale;
        const float *input_row_sum;
        const int64_t * tokens;
    };

    struct Params {
        LayoutT layout_A;
        LayoutT layout_B;
        TMA_A tma_load_A;
        TMA_B tma_load_B;
        ElementOutput * ptr_C;
        const float *weight_scale;
        const float *input_row_sum;
        const int64_t * tokens;
    };

    Params static
    to_underlying_arguments(Arguments const& args) {
        Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A);
        TMA_A tma_load_A = make_tma_copy(
            GmemTiledCopy{},
            mA,
            SmemLayoutA{}(_, _, _0{}),
            select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
            size<0>(ClusterShape{}));
        Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B);
        TMA_B tma_load_B = make_tma_copy(
            GmemTiledCopy{},
            mB,
            SmemLayoutB{}(_, _, _0{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{}));

        return {args.layout_A, args.layout_B, tma_load_A, tma_load_B,
                args.ptr_C, args.weight_scale, args.input_row_sum, args.tokens};
    }

    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& mainloop_params) {
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_A.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_B.get_tma_descriptor());
    }

    template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
    CUTLASS_DEVICE void
    store(Params const& mainloop_params,
          FrgTensorO & tOrO,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          const float *input_row_sum,
          const float *weight_scale,
          const int64_t tokens,
          const int64_t pre_fix_tokens,
          const int bidm,
          const int bidn,
          const int bidb,
          const int tidx) {

        using packHalf = typename PackedHalf<ElementOutput>::Type;
        Tensor tOrO_out = make_tensor<ElementOutput>(tOrO.layout());

        #pragma unroll
        for (int i = 0; i < size(tOrO); i += 4) {
            const int sum_idx = i * 2;
            tOrO[i] = (tOrO[i] + input_row_sum[sum_idx]) * weight_scale[0];
            tOrO[i + 1] = (tOrO[i + 1] + input_row_sum[sum_idx + 1]) * weight_scale[0];
            tOrO[i + 2] = (tOrO[i + 2] + input_row_sum[sum_idx]) * weight_scale[1];
            tOrO[i + 3] = (tOrO[i + 3] + input_row_sum[sum_idx + 1]) * weight_scale[1];
            *reinterpret_cast<packHalf*>(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]);
            *reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]);
        }

        uint16_t *smem_c = reinterpret_cast<uint16_t *>(shared_storage.smem_c.data());

        uint32_t *reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data());

        cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);

        constexpr int k_copy_times = CUR_N / 16;

        #pragma unroll
        for (int i = 0; i < k_copy_times; i++) {
            uint32_t smem_ptr = cast_smem_ptr_to_uint(reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
            asm volatile (
                "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
                :: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3]));
#endif
        }

        cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
        const int batch_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
        ElementOutput * store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM;

        const int remain_tokens = tokens - bidn * kBlockN;

        const int col = tidx % 2;

        constexpr int kPackSize = 16 / sizeof(ElementOutput);
        constexpr int kNumVecElem = kBlockM / kPackSize;
        constexpr int copy_len = CUR_N * kNumVecElem;
        #pragma unroll
        for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) {
            const int idx_div2 = idx / 2;
            // De-swizzle the stmatrix layout back into row-major order.
            const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 + idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8;
            const int store_global_idx = store_idx * 2 + col;
            const int row = store_global_idx / kNumVecElem;
            const int vec_col = store_global_idx % kNumVecElem;
            if (row >= remain_tokens) {
                continue;
            }
            const int offset = row * (M / kPackSize) + vec_col;
            reinterpret_cast<uint4*>(store_c)[offset] = reinterpret_cast<uint4*>(smem_c)[idx];
        }
    }
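The scaling loop in store() deserves a note: every accumulator lane is offset by a precomputed per-token sum before the per-output-channel weight scale is applied, which is the usual way to fold a weight zero-point correction out of the inner GEMM loop. A scalar reference of the per-element math (a sketch; how input_row_sum is precomputed is outside this diff):

    // Scalar shape of the epilogue in store(): acc is the fp8 GEMM accumulator,
    // row_sum the precomputed per-token correction, w_scale the per-channel
    // weight scale. Assumption: any zero-point factor is already folded into
    // row_sum on the host side.
    inline float w4afp8_dequant(float acc, float row_sum, float w_scale) {
        return (acc + row_sum) * w_scale;
    }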
    template <typename MTensor>
    CUTLASS_DEVICE auto get_local_no_packed_tensor(
        const MTensor &mB,
        const int pre_fix_token,
        const int actual_token,
        const int bidn) const {

        auto g_offset = local_tile(
            mB(_, _, 0),
            cute::make_shape(1, size<1>(mB)),
            make_coord(pre_fix_token, _0{}));

        auto g_tensor = make_tensor(
            g_offset.data(),
            make_layout(
                cute::make_shape(actual_token, size<2>(mB)),
                g_offset.stride()
            ));

        Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));

        return gB;
    }

    template <typename SharedStorage>
    CUTLASS_DEVICE void
    load(Params const& mainloop_params,
         MainloopPipeline pipeline,
         PipelineState& smem_pipe_write,
         SharedStorage &shared_storage,
         const int tokens,
         const int pre_fix_tokens,
         const int bidm,
         const int bidn,
         const int bidb,
         const int tidx) {

        Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
        Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});

        Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape());
        Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape());

        Tensor gA = local_tile(mA(_, _, bidb), select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}), make_coord(bidm, _));

        auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sA), group_modes<0, 2>(gA));

        const int kIters = kTiles / kStages;

        if constexpr (TokenPackSize == 0) {
            Tensor gB = get_local_no_packed_tensor(
                mB,
                pre_fix_tokens,
                tokens,
                bidn);

            auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));

            if (tidx == 0) {
                #pragma unroll
                for (int kiter = 0; kiter < kIters; ++kiter) {
                    #pragma unroll
                    for (int s = 0; s < kStages; s++) {
                        const int i = kiter * kStages + s;
                        pipeline.producer_acquire(smem_pipe_write);
                        copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                             tAgA(_, i), tAsA(_, smem_pipe_write.index()));

                        copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                             tBgB(_, i), tBsB(_, smem_pipe_write.index()));
                        ++smem_pipe_write;
                    }
                }

                #pragma unroll
                for (int i = kIters * kStages; i < kTiles; ++i) {
                    pipeline.producer_acquire(smem_pipe_write);
                    copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                         tAgA(_, i), tAsA(_, smem_pipe_write.index()));

                    copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                         tBgB(_, i), tBsB(_, smem_pipe_write.index()));
                    ++smem_pipe_write;
                }
            }
        } else {
            auto mB_this_batch = make_tensor(
                mB(_, _, bidb).data(),
                make_layout(
                    cute::make_shape(tokens, size<1>(mB)),
                    mB.stride()
                ));
            Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
            auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));

            if (tidx == 0) {
                #pragma unroll
                for (int kiter = 0; kiter < kIters; ++kiter) {
                    #pragma unroll
                    for (int s = 0; s < kStages; s++) {
                        const int i = kiter * kStages + s;
                        pipeline.producer_acquire(smem_pipe_write);
                        copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                             tAgA(_, i), tAsA(_, smem_pipe_write.index()));

                        copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                             tBgB(_, i), tBsB(_, smem_pipe_write.index()));
                        ++smem_pipe_write;
                    }
                }

                #pragma unroll
                for (int i = kIters * kStages; i < kTiles; ++i) {
                    pipeline.producer_acquire(smem_pipe_write);
                    copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                         tAgA(_, i), tAsA(_, smem_pipe_write.index()));

                    copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
                         tBgB(_, i), tBsB(_, smem_pipe_write.index()));
                    ++smem_pipe_write;
                }
            }
        }
    }
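Note how load() addresses B: with TokenPackSize == 0 the expert groups are ragged and `tokens` is a running prefix sum, while with padding every group starts at a fixed stride. A host-side sketch of the same indexing (illustrative only, not part of the diff):

    #include <cstdio>

    // Mirrors the two token-indexing conventions used by the kernel.
    void group_range(const long long* tokens, int bidb, int token_pack_size,
                     long long* start, long long* count) {
        if (token_pack_size == 0) {
            const long long prefix = bidb == 0 ? 0 : tokens[bidb - 1];
            *start = prefix;                 // first row owned by this group
            *count = tokens[bidb] - prefix;  // rows owned by this group
        } else {
            *start = (long long)bidb * token_pack_size;
            *count = tokens[bidb];           // real rows; the rest of the slot is padding
        }
    }

    int main() {
        const long long prefix_sums[3] = {5, 9, 20};  // groups of 5, 4, and 11 tokens
        long long s, c;
        group_range(prefix_sums, 1, 0, &s, &c);
        std::printf("group 1: start=%lld count=%lld\n", s, c);  // start=5 count=4
        return 0;
    }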
    template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
    CUTLASS_DEVICE void
    mma(Params const& mainloop_params,
        TiledMma tiled_mma,
        MainloopPipeline pipeline,
        PipelineState& smem_pipe_read,
        SharedStorage& shared_storage,
        FrgTensorO &tSrS,
        const int tidx) {

        using sMemBLayout = std::conditional_t<
            CUR_N == kBlockN,
            SmemLayoutB,
            SmemLayoutB_TAIL
        >;

        Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
        Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{});

        tiled_mma.accumulate_ = GMMA::ScaleOut::One;

        auto threadMma = tiled_mma.get_thread_slice(tidx);

        auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma);
        auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);

        Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0));
        Tensor tSrB = threadMma.partition_fragment_B(sB);

        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);
        };

        const int kIters = kTiles / kStages;

        constexpr int B_STEPS = CUR_N == 0 ? 1 : (kBlockN / CUR_N);

        #pragma unroll
        for (int kiter = 0; kiter < kIters; ++kiter) {
            #pragma unroll
            for (int s = 0; s < kStages; s++) {
                Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, s));
                consumer_wait(pipeline, smem_pipe_read);
                gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, s * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
                pipeline.consumer_release(smem_pipe_read);
                ++smem_pipe_read;
            }
        }
        #pragma unroll
        for (int i = 0; i < kTiles % kStages; ++i) {
            Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, i));
            consumer_wait(pipeline, smem_pipe_read);

            gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, i * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
            pipeline.consumer_release(smem_pipe_read);
            ++smem_pipe_read;
        }
    }
};
114 custom_ops/gpu_ops/w4afp8_gemm/utils.hpp Normal file
@@ -0,0 +1,114 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp>  // For cute::elect_one_sync()

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>


using namespace cute;

template<typename T>
struct PackedHalf;

template<>
struct PackedHalf<cutlass::half_t> {
    using Type = __half2;
};

template<>
struct PackedHalf<cutlass::bfloat16_t> {
    using Type = nv_bfloat162;
};


template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

template <int numel>
__forceinline__ __device__ void convert_c4_2_fp8(const int32_t * src, int32_t * dst1, int32_t * dst2) {
    #pragma unroll
    for (int i = 0; i < numel; ++i) {
        dst1[i] = (src[i] >> 4) & 0x0f0f0f0f;  // high nibble of each byte
        dst2[i] = src[i] & 0x0f0f0f0f;         // low nibble of each byte
    }
}
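On a single 32-bit word the split in convert_c4_2_fp8 looks like this (a standalone host-side illustration with arbitrary nibble values):

    #include <cassert>

    int main() {
        unsigned packed = 0xA1B2C3D4u;             // four bytes, eight 4-bit weights
        unsigned hi = (packed >> 4) & 0x0f0f0f0f;  // 0x0A0B0C0D: high nibbles
        unsigned lo = packed & 0x0f0f0f0f;         // 0x01020304: low nibbles
        assert(hi == 0x0A0B0C0D && lo == 0x01020304);
        return 0;
    }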
template <int wg_wait=0, bool arrive=true,
          bool commit=true, typename Tensor0, typename Tensor1,
          typename Tensor2, typename Tensor3, typename TiledMma,
          typename ThrCopyA, typename TiledCopyA>
__forceinline__ __device__ void gemm(
        TiledMma &tiled_mma,
        Tensor0 &tCrA,
        Tensor1 &tCsA,
        Tensor2 const &tCrB,
        Tensor3 &tCrC,
        TiledCopyA const &tiled_copy_A,
        ThrCopyA const &thr_copy_A) {
    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
    Tensor tCrA1 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
    Tensor tCrA2 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) {
        warpgroup_arrive();
    }
    constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4;

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));

    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        // Prefetch the next k-block of packed A while the current one is consumed.
        if (k_block < size<2>(tCrA) - 1) {
            cute::copy(tiled_copy_A, tCsA(_, _, k_block + 1), tCrA_copy_view(_, _, k_block + 1));
        }
        int32_t * tCrA_data = reinterpret_cast<int32_t *>(tCrA(_,_,k_block).data());
        int32_t * tCrA1_data = reinterpret_cast<int32_t *>(tCrA1(_,_,k_block).data());
        int32_t * tCrA2_data = reinterpret_cast<int32_t *>(tCrA2(_,_,k_block).data());
        convert_c4_2_fp8<numel>(tCrA_data, tCrA1_data, tCrA2_data);

        // Each packed k-block of A drives two fp8 k-blocks of B.
        cute::gemm(tiled_mma, tCrA1(_,_,k_block), tCrB(_,_,2 * k_block), tCrC);
        cute::gemm(tiled_mma, tCrA2(_,_,k_block), tCrB(_,_, 2 * k_block + 1), tCrC);
    }
    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
351 custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu Normal file
@@ -0,0 +1,351 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#include "helper.h"
#include "paddle/extension.h"
#include "w4afp8_gemm_template.h"
#include "w4afp8_gemm.h"


// Pack two int4 weights per byte: element k goes to the high nibble and
// element k + 32 to the low nibble, matching the unpack order in
// convert_c4_2_fp8.
void weight_convert(const uint8_t *weight, uint8_t *weight_new, int batch, int M, int K) {
    assert(K % 64 == 0);
    for (int b = 0; b < batch; ++b) {
        for (int m = 0; m < M; ++m) {
            for (int k = 0; k < K; k += 64) {
                for (int k_inner = 0; k_inner < 32; ++k_inner) {
                    uint8_t temp = 0;
                    uint8_t left = weight[b * M * K + m * K + k + k_inner];
                    uint8_t right = weight[b * M * K + m * K + k + k_inner + 32];
                    temp |= left << 4;
                    temp |= right;
                    weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] = temp;
                }
            }
        }
    }
}
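A host-side round trip shows the pack/unpack pair agreeing (a standalone sketch, not part of the op):

    #include <cassert>
    #include <vector>

    int main() {
        std::vector<unsigned> w(64), packed(32);
        for (int i = 0; i < 64; ++i) w[i] = i % 13;        // arbitrary 4-bit values
        for (int k = 0; k < 32; ++k)
            packed[k] = ((w[k] << 4) | w[k + 32]) & 0xff;  // same rule as weight_convert
        for (int k = 0; k < 32; ++k) {
            assert(((packed[k] >> 4) & 0x0f) == w[k]);     // high nibble: element k
            assert((packed[k] & 0x0f) == w[k + 32]);       // low nibble: element k + 32
        }
        return 0;
    }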
template <typename T> class NVTraits;

template <> class NVTraits<__nv_fp8_e4m3> {
public:
    typedef cutlass::float_e4m3_t data_t;
};

template <> class NVTraits<__nv_bfloat16> {
public:
    typedef cutlass::bfloat16_t data_t;
};

template <> class NVTraits<half> {
public:
    typedef cutlass::half_t data_t;
};


template <typename OutputType>
void DisPatchW4AFp8Gemm(
    const cutlass::float_e4m3_t* input,
    const cutlass::float_e4m3_t* weight,
    const int64_t * tokens,
    const float * input_row_sum,
    const float * weight_scale,
    OutputType * out,
    const int64_t token_padding_size,
    const int64_t max_tokens,
    const int batch_size,
    const int64_t M,
    const int64_t K,
    cudaStream_t stream) {

    // Round the token count up to a multiple of 16; anything beyond the
    // 256-wide main tile becomes the tail tile width.
    int kBlockN = (max_tokens + 15) / 16 * 16;
    int TailN = 0;
    if (kBlockN > 256) {
        TailN = kBlockN % 256;
        kBlockN = 256;
    }
    if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
        GEMM_SWITCH_BF16(
            M, K, batch_size, token_padding_size, kBlockN, TailN,
            weight,
            input,
            out,
            weight_scale,
            input_row_sum,
            tokens,
            max_tokens,
            stream)
    } else {
        GEMM_SWITCH_FP16(
            M, K, batch_size, token_padding_size, kBlockN, TailN,
            weight,
            input,
            out,
            weight_scale,
            input_row_sum,
            tokens,
            max_tokens,
            stream)
    }
}
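The rounding rule picks the main tile width and the tail tile in one pass; a standalone table of a few cases (illustrative values only):

    #include <cstdio>

    int main() {
        for (int max_tokens : {10, 250, 300, 777}) {
            int kBlockN = (max_tokens + 15) / 16 * 16;   // same rule as DisPatchW4AFp8Gemm
            int TailN = 0;
            if (kBlockN > 256) { TailN = kBlockN % 256; kBlockN = 256; }
            printf("max_tokens=%d -> kBlockN=%d TailN=%d\n", max_tokens, kBlockN, TailN);
        }
        return 0;
    }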
std::vector<paddle::Tensor> W4AFp8Gemm(
    const paddle::Tensor& input,
    const paddle::Tensor& weight,
    const paddle::Tensor& tokens,  // if token_padding_size == 0: prefix sum of token counts per group; otherwise: token count of each group
    const paddle::Tensor& input_row_sum,
    const paddle::Tensor& weight_scale,
    const int64_t token_padding_size,
    const int64_t max_tokens,
    const bool is_bfloat16) {

    const int batch_size = weight.dims()[0];
    const int M = weight.dims()[1];
    const int K = weight.dims()[2] * 2;  // two int4 weights per stored byte

    if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) {
        PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN'].");
    }

    if (token_padding_size == 0) {
        const int all_tokens = input.dims()[0];
        if (is_bfloat16) {
            paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::BFLOAT16, input.place());
            phi::dtype::bfloat16 *out_data = out.data<phi::dtype::bfloat16>();
            DisPatchW4AFp8Gemm(
                reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
                reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
                tokens.data<int64_t>(),
                input_row_sum.data<float>(),
                weight_scale.data<float>(),
                reinterpret_cast<cutlass::bfloat16_t*>(out_data),
                token_padding_size,
                max_tokens,
                batch_size,
                M,
                K,
                input.stream());
            return {out};
        } else {
            paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::FLOAT16, input.place());
            phi::dtype::float16 *out_data = out.data<phi::dtype::float16>();
            DisPatchW4AFp8Gemm(
                reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
                reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
                tokens.data<int64_t>(),
                input_row_sum.data<float>(),
                weight_scale.data<float>(),
                reinterpret_cast<cutlass::half_t*>(out_data),
                token_padding_size,
                max_tokens,
                batch_size,
                M,
                K,
                input.stream());
            return {out};
        }
    } else {
        if (is_bfloat16) {
            paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place());
            phi::dtype::bfloat16 * out_data = out.data<phi::dtype::bfloat16>();
            DisPatchW4AFp8Gemm(
                reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
                reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
                tokens.data<int64_t>(),
                input_row_sum.data<float>(),
                weight_scale.data<float>(),
                reinterpret_cast<cutlass::bfloat16_t*>(out_data),
                token_padding_size,
                max_tokens,
                batch_size,
                M,
                K,
                input.stream());
            return {out};
        } else {
            paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::FLOAT16, input.place());
            phi::dtype::float16 * out_data = out.data<phi::dtype::float16>();

            DisPatchW4AFp8Gemm(
                reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
                reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
                tokens.data<int64_t>(),
                input_row_sum.data<float>(),
                weight_scale.data<float>(),
                reinterpret_cast<cutlass::half_t*>(out_data),
                token_padding_size,
                max_tokens,
                batch_size,
                M,
                K,
                input.stream());
            return {out};
        }
    }
}

template <typename InputType, typename OutputType>
void DisPatchW4AFp8GemmWrapper(
    const InputType* input,
    const InputType* weight,
    const int64_t* total_rows_before_expert,
    const float* input_row_sum,
    const float* row_scale,
    const float* weight_scale,
    OutputType * out,
    const int64_t token_padding_size,
    const int64_t max_tokens,
    const int num_experts,
    const int64_t M,
    const int64_t K,
    cudaStream_t stream) {
    using InType = typename NVTraits<InputType>::data_t;
    using OutType = typename NVTraits<OutputType>::data_t;
    DisPatchW4AFp8Gemm(
        reinterpret_cast<const InType*>(input),
        reinterpret_cast<const InType*>(weight),
        total_rows_before_expert,
        input_row_sum,
        weight_scale,
        reinterpret_cast<OutType*>(out),
        token_padding_size,
        max_tokens,
        num_experts,
        M,
        K,
        stream);
}


std::vector<paddle::Tensor> W4AFp8GemmWeightConvert(const paddle::Tensor& weight) {
    const int batch_size = weight.dims()[0];
    const int M = weight.dims()[1];
    const int K = weight.dims()[2];
    paddle::Tensor weight_new = paddle::empty({batch_size, M, K / 2}, paddle::DataType::UINT8, weight.place());
    weight_convert(weight.data<uint8_t>(), weight_new.data<uint8_t>(), batch_size, M, K);
    return {weight_new};
}

template <typename T, int kPackSize>
__global__ void permute_scale_kernel(
        T* input_data,
        const int numel) {
    using LoadT = AlignedVector<T, kPackSize>;
    LoadT input_vec;
    LoadT dst_vec;
    const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize;
    if (load_idx >= numel) {
        return;
    }
    Load<T, kPackSize>(&input_data[load_idx], &input_vec);

    // Interleave the first and second halves of each 16-element pack.
    for (int i = 0; i < kPackSize; i += 2) {
        dst_vec[i] = input_vec[i / 2];
        dst_vec[i + 1] = input_vec[i / 2 + 8];
    }

    Store<T, kPackSize>(dst_vec, &input_data[load_idx]);
}
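The permutation is easiest to see on indices: for a 16-element pack, output slot 2j takes input j and slot 2j+1 takes input j+8, which appears to match the per-thread register pairing used by the epilogue. A standalone check:

    #include <cstdio>

    int main() {
        int src[16], dst[16];
        for (int i = 0; i < 16; ++i) src[i] = i;
        for (int i = 0; i < 16; i += 2) {  // same rule as permute_scale_kernel
            dst[i] = src[i / 2];
            dst[i + 1] = src[i / 2 + 8];
        }
        for (int i = 0; i < 16; ++i) printf("%d ", dst[i]);
        // prints: 0 8 1 9 2 10 3 11 4 12 5 13 6 14 7 15
        return 0;
    }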
void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
    const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1;
    const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0];
    if (col % 16 != 0) {
        PD_THROW("Only supported when col is divisible by 16.");
    }
    const int numel = row * col;
    const int threads = 128;
    const int kPackSize = 16;
    const int grid_size = (numel / kPackSize + threads - 1) / threads;

    if (scale.dtype() == paddle::DataType::BFLOAT16) {
        permute_scale_kernel<phi::dtype::bfloat16, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
            const_cast<phi::dtype::bfloat16*>(scale.data<phi::dtype::bfloat16>()),
            numel
        );
    } else if (scale.dtype() == paddle::DataType::FLOAT16) {
        permute_scale_kernel<phi::dtype::float16, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
            const_cast<phi::dtype::float16*>(scale.data<phi::dtype::float16>()),
            numel
        );
    } else if (scale.dtype() == paddle::DataType::FLOAT32) {
        permute_scale_kernel<float, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
            const_cast<float*>(scale.data<float>()),
            numel
        );
    }
}

PD_BUILD_STATIC_OP(w4afp8_gemm_scale_permute)
    .Inputs({"weight_scale"})
    .Outputs({"permute_scale"})
    .SetInplaceMap({{"weight_scale", "permute_scale"}})
    .SetKernelFn(PD_KERNEL(W4AFp8GemmScalePermute));

PD_BUILD_STATIC_OP(w4afp8_gemm)
    .Inputs({"input",
             "weight",
             "tokens",
             "input_row_sum",
             "weight_scale"})
    .Outputs({"out"})
    .Attrs({"token_padding_size: int64_t",
            "max_tokens: int64_t",
            "is_bfloat16: bool"})
    .SetKernelFn(PD_KERNEL(W4AFp8Gemm));

PD_BUILD_STATIC_OP(w4afp8_gemm_weight_convert)
    .Inputs({"weight"})
    .Outputs({"converted_weight"})
    .SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert));

template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>(
    const __nv_fp8_e4m3* input,
    const __nv_fp8_e4m3* weight,
    const int64_t * tokens,
    const float * input_row_sum,
    const float * row_scale,
    const float * weight_scale,
    __nv_bfloat16 * out,
    const int64_t token_padding_size,
    const int64_t max_tokens,
    const int num_experts,
    const int64_t M,
    const int64_t K,
    cudaStream_t stream
);

template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>(
    const __nv_fp8_e4m3* input,
    const __nv_fp8_e4m3* weight,
    const int64_t * tokens,
    const float * input_row_sum,
    const float * row_scale,
    const float * weight_scale,
    half * out,
    const int64_t token_padding_size,
    const int64_t max_tokens,
    const int num_experts,
    const int64_t M,
    const int64_t K,
    cudaStream_t stream
);
47 custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h Normal file
@@ -0,0 +1,47 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>
#include "helper.h"


std::vector<paddle::Tensor> W4AFp8Gemm(
    const paddle::Tensor& input,
    const paddle::Tensor& weight,
    const paddle::Tensor& tokens,  // if token_padding_size == 0: prefix sum of token counts per group; otherwise: token count of each group
    const paddle::Tensor& input_row_sum,
    const paddle::Tensor& weight_scale,
    const int64_t token_padding_size,
    const int64_t max_tokens,
    const bool is_bfloat16);

template <typename InputType, typename OutputType>
void DisPatchW4AFp8GemmWrapper(
    const InputType* input,
    const InputType* weight,
    const int64_t * tokens,
    const float * input_row_sum,
    const float * row_scale,
    const float * weight_scale,
    OutputType * out,
    const int64_t token_padding_size,
    const int64_t max_tokens,
    const int num_experts,
    const int64_t M,
    const int64_t K,
    cudaStream_t stream);
252 custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp Normal file
@@ -0,0 +1,252 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"

#include "kernel_traits.h"
#include "mainloop_fwd.h"

template <typename Ktraits>
void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) w4afp8_gemm_kernel(
        CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {

    using Element = typename Ktraits::Element;
    static_assert(cutlass::sizeof_bits_v<Element> == 8);

    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;

    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
    static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
    static constexpr int kBlockN = Ktraits::kBlockN;
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int M = Ktraits::M;
    static constexpr int TokenPackSize = Ktraits::TokenPackSize;
    static constexpr int TAIL_N = Ktraits::TAIL_N;

    using CollectiveMainloop = CollectiveMainloopFwd<Ktraits>;

    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;
    using ElementOutput = typename Ktraits::ElementOutput;

    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);

    const int bidm = blockIdx.x;
    const int bidn = blockIdx.y;
    const int bidb = blockIdx.z;
    const int tidx = threadIdx.x;

    if (tidx == 0) {
        CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
    }

    // Obtain warp index
    int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;

    PipelineParams pipeline_params;
    pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB;
    int warp_group_idx = cutlass::canonical_warp_group_idx();
    pipeline_params.role = warp_group_idx == 0
        ? MainloopPipeline::ThreadCategory::Producer
        : MainloopPipeline::ThreadCategory::Consumer;
    pipeline_params.is_leader = warp_group_thread_idx == 0;
    pipeline_params.num_consumers = NumMmaThreads;

    MainloopPipeline pipeline(shared_storage.pipeline, pipeline_params, ClusterShape{});

    CollectiveMainloop collective_mainloop;

    if constexpr (size(ClusterShape{}) > 1) {
        cute::cluster_arrive_relaxed();
        cute::cluster_wait();
    } else {
        __syncthreads();
    }

    const int pre_fix_tokens = TokenPackSize == 0 ? (bidb == 0 ? 0 : mainloop_params.tokens[bidb - 1]) : 0;

    const int tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb] - pre_fix_tokens : mainloop_params.tokens[bidb];

    // Tiles past this group's token count have nothing to do.
    if (bidn * kBlockN >= tokens) {
        return;
    }

    float* input_row_sum = reinterpret_cast<float*>(
        shared_memory + sizeof(typename Ktraits::SharedStorage));

    if (warp_group_idx == 0) {
        cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 40 : 32>();
        PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
        collective_mainloop.load(
            mainloop_params,
            pipeline,
            smem_pipe_write,
            shared_storage,
            tokens,
            pre_fix_tokens,
            bidm,
            bidn,
            bidb,
            tidx);
    } else {
        cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 232 : 160>();
        PipelineState smem_pipe_read;

        typename Ktraits::TiledMma tiled_mma;

        typename Ktraits::TiledMma_TAIL tiled_mma_tail;

        const int mma_tidx = tidx - NumCopyThreads;
        const int lane_id = mma_tidx % 4 * 2;

        const float2 weight_scale = reinterpret_cast<const float2*>(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4];

        if constexpr (TokenPackSize == 0) {
            const int input_sum_idx = pre_fix_tokens + bidn * kBlockN;
            if (mma_tidx < kBlockN) {
                reinterpret_cast<float*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
            }
        } else {
            const int input_sum_idx = bidb * TokenPackSize + bidn * kBlockN;
            if (mma_tidx < kBlockN / 4) {
                reinterpret_cast<float4*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float4*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
            }
        }

        const int remain_tokens = tokens - bidn * kBlockN;

        if (TAIL_N > 0 && remain_tokens < kBlockN) {
            Tensor tSrS_tail = partition_fragment_C(tiled_mma_tail, select<0, 1>(TileShape_MNK_TAIL{}));
            collective_mainloop.mma<TAIL_N>(
                mainloop_params,
                tiled_mma_tail,
                pipeline,
                smem_pipe_read,
                shared_storage,
                tSrS_tail,
                mma_tidx);
            collective_mainloop.store<TAIL_N>(
                mainloop_params,
                tSrS_tail,
                shared_storage,
                tiled_mma_tail,
                input_row_sum + lane_id,
                reinterpret_cast<const float*>(&weight_scale),
                tokens,
                pre_fix_tokens,
                bidm,
                bidn,
                bidb,
                mma_tidx);
        } else {
            Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}));
            collective_mainloop.mma<kBlockN>(
                mainloop_params,
                tiled_mma,
                pipeline,
                smem_pipe_read,
                shared_storage,
                tSrS,
                mma_tidx);
            collective_mainloop.store<kBlockN>(
                mainloop_params,
                tSrS,
                shared_storage,
                tiled_mma,
                input_row_sum + lane_id,
                reinterpret_cast<const float*>(&weight_scale),
                tokens,
                pre_fix_tokens,
                bidm,
                bidn,
                bidb,
                mma_tidx);
        }
    }

}

template <int Batch>
auto get_gmem_layout(const int Rows, const int Cols) {
    return make_layout(
        make_shape(
            static_cast<int64_t>(Rows),
            static_cast<int64_t>(Cols),
            static_cast<int64_t>(Batch)),
        make_stride(
            static_cast<int64_t>(Cols),
            cute::_1{},
            static_cast<int64_t>(Rows * Cols)));
}


template <typename InputType, typename OutputType, typename Kernel_traits, int M, int K, int Batch, int TokenPackSize>
void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale,
              const float *input_row_sum, const int64_t * tokens, const int64_t max_tokens, cudaStream_t stream) {

    using ElementOutput = typename Kernel_traits::ElementOutput;
    using Element = typename Kernel_traits::Element;
    using CollectiveMainloop = CollectiveMainloopFwd<Kernel_traits>;
    using ClusterShape = typename Kernel_traits::ClusterShape_MNK;

    constexpr int M_nums = (M + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    const int N_nums = (max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;

    typename CollectiveMainloop::Params mainloop_params =
        CollectiveMainloop::to_underlying_arguments({
            static_cast<Element const*>(A),
            get_gmem_layout<Batch>(M, K / 2),
            static_cast<Element const*>(B),
            get_gmem_layout<Batch>(TokenPackSize == 0 ? max_tokens * Batch : TokenPackSize, K),
            static_cast<ElementOutput*>(C),
            get_gmem_layout<Batch>(M, TokenPackSize == 0 ? max_tokens : TokenPackSize),
            weight_scale,
            input_row_sum,
            tokens
        });

    void *kernel;
    kernel = (void *)w4afp8_gemm_kernel<Kernel_traits>;

    int smem_size = sizeof(typename Kernel_traits::SharedStorage) + sizeof(float) * Kernel_traits::kBlockN;

    if (smem_size >= 48 * 1024) {
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }

    dim3 grid_dims;
    grid_dims.x = M_nums;
    grid_dims.y = N_nums;
    grid_dims.z = Batch;
    static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
    dim3 block_dims(ctaSize);
    dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
    cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
    cutlass::launch_kernel_on_cluster(
        launch_params, kernel, mainloop_params);
}
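For a sense of scale, the launch geometry run_gemm sets up can be reproduced on the host (a standalone sketch with the kBlockM = 128 default from the generated kernels; the M / max_tokens / Batch values are illustrative):

    #include <cstdio>

    int main() {
        constexpr int kBlockM = 128, kBlockN = 256;
        const int M = 8192, max_tokens = 1024, Batch = 8;         // illustrative values
        const int grid_x = (M + kBlockM - 1) / kBlockM;           // 64 row tiles
        const int grid_y = (max_tokens + kBlockN - 1) / kBlockN;  // 4 token tiles
        printf("grid = (%d, %d, %d)\n", grid_x, grid_y, Batch);
        // Blocks whose bidn * kBlockN exceeds a group's token count
        // return immediately inside w4afp8_gemm_kernel.
        return 0;
    }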
@@ -293,6 +293,7 @@ elif paddle.is_compiled_with_cuda():
        "gpu_ops/fused_rotary_position_encoding.cu",
        "gpu_ops/noaux_tc.cu",
        "gpu_ops/custom_all_reduce/all_reduce.cu",
        "gpu_ops/limit_content_len.cu",
    ]

    # pd_disaggregation
@@ -494,6 +495,8 @@ elif paddle.is_compiled_with_cuda():
    if cc >= 90 and nvcc_version >= 12.0:
        # Hopper optimized mla
        sources += find_end_files("gpu_ops/mla_attn", ".cu")
        os.system("python utils/auto_gen_w4afp8_gemm_kernel.py")
        sources += find_end_files("gpu_ops/w4afp8_gemm", ".cu")

setup(
    name="fastdeploy_ops",
207 custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py Normal file
@@ -0,0 +1,207 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

file_dir = "./gpu_ops/w4afp8_gemm/"

gemm_template_head = """
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
"""
gemm_template_case = """
void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
    const cutlass::float_e4m3_t * weight,
    const cutlass::float_e4m3_t * input,
    {cutlass_type} * out,
    const float *weight_scale,
    const float *input_row_sum,
    const int64_t *tokens,
    const int64_t max_tokens,
    cudaStream_t stream);
"""

gemm_template_cu_head = """
#include "paddle/extension.h"
#include "w4afp8_gemm_template.h"
#include "w4afp8_gemm_kernel.hpp"

"""
gemm_template_cu_template = """
void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
    const cutlass::float_e4m3_t * weight,
    const cutlass::float_e4m3_t * input,
    {cutlass_type} * out,
    const float *weight_scale,
    const float *input_row_sum,
    const int64_t *tokens,
    const int64_t max_tokens,
    cudaStream_t stream) {{

    constexpr static int M = {M};
    constexpr static int K = {K};
    constexpr static int Batch = {BATCH};
    constexpr static int TokenPackSize = {PADDING};
    constexpr static int kBlockN = {N};
    constexpr static int kBlockN_TAIL = {TAILN};
    constexpr static int kBlockM = 128;
    constexpr static int kBlockK = 128;
    constexpr static int kNWarps = 4 + kBlockM / 16;
    constexpr static int kStages = 5;
    constexpr int kCluster = 1;
    static_assert(K % kBlockK == 0);
    constexpr int kTiles = K / kBlockK;

    using Kernel_traits = Kernel_traits<
        kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles,
        M, TokenPackSize, kBlockN_TAIL, kCluster, cutlass::float_e4m3_t,
        {cutlass_type}>;
    run_gemm<cutlass::float_e4m3_t, {cutlass_type},
        Kernel_traits, M, K, Batch, TokenPackSize>
        (weight, input, out, weight_scale,
        input_row_sum, tokens, max_tokens, stream);
}}
"""

gemm_case = [
    [8192, 3584, 8, 0],  # eb45T ffn1
    [8192, 3584, 8, 2048],  # eb45T ffn1
    [7168, 8192, 8, 0],  # eb45T ffn2
    [7168, 8192, 8, 2048],  # eb45T ffn2
]

dtype = ["BF16", "FP16"]


def get_cutlass_type(type):
    if type == "BF16":
        return "cutlass::bfloat16_t"
    elif type == "FP16":
        return "cutlass::half_t"


template_head_file = open(f"{file_dir}w4afp8_gemm_template.h", "w")
template_head_file.write(gemm_template_head)

for type in dtype:
    for case in gemm_case:
        for n in range(16, 257, 16):
            template_head_file.write(
                gemm_template_case.format(
                    M=case[0],
                    K=case[1],
                    N=n,
                    BATCH=case[2],
                    TYPE=type,
                    PADDING=case[3],
                    TAILN=0,
                    cutlass_type=get_cutlass_type(type),
                )
            )
            template_head_file.write(
                gemm_template_case.format(
                    M=case[0],
                    K=case[1],
                    N=256,
                    BATCH=case[2],
                    TYPE=type,
                    PADDING=case[3],
                    TAILN=n - 16,
                    cutlass_type=get_cutlass_type(type),
                )
            )

            template_cu_file = open(
                f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_TAILN{0}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
            )
            template_cu_file.write(gemm_template_cu_head)
            template_cu_file.write(
                gemm_template_cu_template.format(
                    M=case[0],
                    K=case[1],
                    N=n,
                    BATCH=case[2],
                    TYPE=type,
                    PADDING=case[3],
                    TAILN=0,
                    cutlass_type=get_cutlass_type(type),
                )
            )

            template_cu_file.close()

            template_cu_file = open(
                f"{file_dir}w4afp8_gemm_M{case[0]}_N{256}_TAILN{n-16}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
            )
            template_cu_file.write(gemm_template_cu_head)
            template_cu_file.write(
                gemm_template_cu_template.format(
                    M=case[0],
                    K=case[1],
                    N=256,
                    BATCH=case[2],
                    TYPE=type,
                    PADDING=case[3],
                    TAILN=n - 16,
                    cutlass_type=get_cutlass_type(type),
                )
            )

            template_cu_file.close()

for type in dtype:
    template_head_file.write("\n")
    template_head_file.write(
        """#define GEMM_SWITCH_{TYPE}(_M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN, ...) {{ \\
    if (_M == 0 && _K == 0 && _BATCH == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _TailN == 0) {{ \\""".format(
            TYPE=type
        )
    )

    template_head_file.write("\n")

    for case in gemm_case:
        for n in range(16, 257, 16):
            template_head_file.write(
                """    }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
        w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
                    M=case[0], K=case[1], N=n, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=0
                )
            )
            template_head_file.write("\n")
            template_head_file.write(
                """    }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
        w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
                    M=case[0], K=case[1], N=256, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=n - 16
                )
            )
            template_head_file.write("\n")

    template_head_file.write(
        """    } else { \\
        PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d batch=%d token_padding_size=%d kBlockN=%d tailN=%d\\n", _M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN)); \\
    } \\
}"""
    )

template_head_file.close()
@@ -294,16 +294,24 @@ class SpeculativeConfig:
        self,
        args,
    ):
        # speculative method, choose in [None, "ngram_match", "mtp"]
        self.method_list = ["ngram_match", "mtp"]
        self.mtp_strategy_list = ["default", "with_ngram"]

        # speculative method, choose in [None, "ngram_match", "mtp", "hybrid_mtp_ngram"]
        self.method: Optional[str] = None
        # mtp strategy in mtp-method
        self.mtp_strategy = "default"
        # the max length of speculative tokens
        self.num_speculative_tokens: int = 1
        # the model runner step of draft model/mtp...
        self.num_model_steps: int = 1
        # the max length of candidate tokens for speculative method
        self.max_candidate_len: int = 5
        # the max length of verify window for speculative method
        self.verify_window: int = 2
        # ngram match
        self.max_ngram_size: int = 5
        self.min_ngram_size: int = 2
        # model for mtp/eagle/draft_model
        self.model: Optional[str] = None
        # quantization of model
@@ -390,6 +398,33 @@ class SpeculativeConfig:
            logger.info("{:<20}:{:<6}{}".format(k, "", v))
        logger.info("=============================================================")

    def check_legality_parameters(
        self,
    ) -> None:
        """Check the legality of parameters passed in from the command line"""
        if self.method is not None:
            assert (
                self.method in self.method_list
            ), f"speculative method only support {self.method_list} now, but get {self.method}."

        assert (
            self.num_speculative_tokens >= 1 and self.num_speculative_tokens <= 5
        ), f"num_speculative_tokens only support in range[1, 5], but get {self.num_speculative_tokens}."
        assert (
            self.num_model_steps >= 1 and self.num_model_steps <= 5
        ), f"num_model_steps only support in range[1, 5], but get {self.num_model_steps}."

        if self.method in ["mtp", "hybrid_mtp_ngram"]:
            if self.num_speculative_tokens < self.num_model_steps:
                logger.warning(
                    f"Get num_model_steps > num_speculative_tokens. Reset num_speculative_tokens to {self.num_model_steps}"
                )
                self.num_speculative_tokens = self.num_model_steps

        assert (
            self.mtp_strategy in self.mtp_strategy_list
        ), f"mtp_strategy_list only support {self.mtp_strategy_list}, but get {self.mtp_strategy}"

    def __str__(self) -> str:
        return self.to_json_string()
@@ -283,6 +283,7 @@ class FlashAttentionBackend(AttentionBackend):
                metadata.kv_token_num_cpu[0].item(),
                self.max_seq_len,
                getattr(layer, "cache_quant_type_str", "none"),
                self.rope_3d,
            )

            res = self.flash_attn_func(
@@ -49,6 +49,7 @@ def gqa_rope_write_cache(
    kv_token_num: int = 1,
    max_seq_len: int = 0,
    cache_quant_type: str = "none",
    rope_3d: bool = False,
):
    if current_platform.is_cuda():
        from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
@@ -81,6 +82,7 @@ def gqa_rope_write_cache(
            kv_token_num,
            max_seq_len,
            cache_quant_type,
            rope_3d,
        )
        return q, k, v, qkv_
    else:
@@ -46,7 +46,7 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
        """
        Triton MoE create weight process.
        """
        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
        assert len(up_gate_proj_weights) == layer.num_local_experts
        assert len(down_proj_weights) == layer.num_local_experts
        assert self.quant_method.name() == "wint8"
@@ -51,7 +51,7 @@ class GCUFusedMoeMethod(MoEMethodBase):
|
||||
Paddle gcu create weight process.
|
||||
"""
|
||||
# bf16
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
|
||||
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
|
||||
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
|
||||
for idx, weight_tensor in enumerate([stacked_up_gate_proj_weights, stacked_down_proj_weights]):
|
||||
@@ -276,7 +276,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
|
||||
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
||||
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
||||
|
||||
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
|
||||
up_gate_proj_weights, down_proj_weights, _, _ = layer.load_experts_weight(
|
||||
state_dict,
|
||||
up_gate_proj_expert_weight_key,
|
||||
down_proj_expert_weight_key,
|
||||
@@ -312,7 +312,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
|
||||
"""
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
|
||||
self.check(layer, up_gate_proj_weights, down_proj_weights)
|
||||
|
||||
def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
|
||||
|
@@ -12,13 +12,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .fused_moe_cutlass_backend import CutlassW4A8MoEMethod, CutlassWeightOnlyMoEMethod
|
||||
from .fused_moe_cutlass_backend import (
|
||||
CutlassW4A8MoEMethod,
|
||||
CutlassW4AFP8MoEMethod,
|
||||
CutlassWeightOnlyMoEMethod,
|
||||
)
|
||||
from .fused_moe_triton_backend import TritonWeightOnlyMoEMethod
|
||||
from .moe import FusedMoE
|
||||
|
||||
__all__ = [
|
||||
CutlassWeightOnlyMoEMethod,
|
||||
CutlassW4A8MoEMethod,
|
||||
CutlassW4AFP8MoEMethod,
|
||||
FusedMoE,
|
||||
TritonWeightOnlyMoEMethod,
|
||||
]
|
||||
|
@@ -355,7 +355,7 @@ class EPPrefillRunner(EPRunner):
|
||||
):
|
||||
(
|
||||
num_tokens_per_rank,
|
||||
_,
|
||||
num_tokens_per_rdma_rank,
|
||||
num_tokens_per_expert,
|
||||
is_token_in_rank,
|
||||
_,
|
||||
@@ -365,6 +365,7 @@ class EPPrefillRunner(EPRunner):
|
||||
dispatch_args = {
|
||||
"x": (x, x_scale_tensor) if x_scale_tensor is not None else x,
|
||||
"num_tokens_per_rank": num_tokens_per_rank,
|
||||
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": self.ep_engine.ep_config,
|
||||
|
@@ -31,6 +31,7 @@ if current_platform.is_cuda():
|
||||
moe_expert_dispatch,
|
||||
moe_expert_reduce,
|
||||
noaux_tc,
|
||||
w4afp8_gemm_scale_permute,
|
||||
)
|
||||
elif current_platform.is_iluvatar():
|
||||
from fastdeploy.model_executor.ops.iluvatar import (
|
||||
@@ -75,7 +76,9 @@ class CutlassMoEMethod(MoEMethodBase):
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
# bf16
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
|
||||
layer.extract_moe_ffn_weights(state_dict)
|
||||
)
|
||||
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
|
||||
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
|
||||
for idx, weight_tensor in enumerate([stacked_up_gate_proj_weights, stacked_down_proj_weights]):
|
||||
@@ -98,6 +101,7 @@ class CutlassMoEMethod(MoEMethodBase):
|
||||
token_nums_per_expert: paddle.Tensor,
|
||||
expert_idx_per_token: paddle.Tensor,
|
||||
used_in_ep_low_latency: bool = False,
|
||||
estimate_total_token_nums: int = -1,
|
||||
):
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
@@ -115,6 +119,7 @@ class CutlassMoEMethod(MoEMethodBase):
|
||||
expert_idx_per_token,
|
||||
self.moe_quant_type,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums,
|
||||
)
|
||||
return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
|
||||
permute_input,
|
||||
@@ -128,6 +133,7 @@ class CutlassMoEMethod(MoEMethodBase):
|
||||
expert_idx_per_token,
|
||||
self.moe_quant_type,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums,
|
||||
)
|
||||
|
||||
def apply_ep_prefill(
|
||||
@@ -167,13 +173,13 @@ class CutlassMoEMethod(MoEMethodBase):
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
(self.up_gate_proj_in_scale if hasattr(self, "up_gate_proj_in_scale") else None),
|
||||
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
|
||||
recv_num_tokens_per_expert_list,
|
||||
token_all_num,
|
||||
self.moe_quant_type,
|
||||
)
|
||||
if self.moe_quant_type != "w4a8":
|
||||
# only w4a8 need expert_idx_per_token
|
||||
if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
|
||||
# only w4a8 and w4afp8 need expert_idx_per_token
|
||||
# Other need not this tensor, so we make it None.
|
||||
expert_idx_per_token = None
|
||||
else:
|
||||
@@ -211,15 +217,17 @@ class CutlassMoEMethod(MoEMethodBase):
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
"""
|
||||
estimate_total_token_nums = gate_out.shape[0] * layer.top_k
|
||||
# 1. Select topk experts and weights
|
||||
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
|
||||
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts")
|
||||
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
|
||||
use_fp8 = self.moe_quant_type == "w4afp8"
|
||||
# 2. EP Dispatch
|
||||
permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
|
||||
x, topk_idx, topk_weights, expertwise_scale=expertwise_scale
|
||||
x, topk_idx, topk_weights, expertwise_scale=expertwise_scale, use_fp8=use_fp8
|
||||
)
|
||||
# 3. Compute ffn
|
||||
if self.moe_quant_type == "w4a8":
|
||||
if self.moe_quant_type == "w4a8" or self.moe_quant_type == "w4afp8":
|
||||
num_local_experts, max_num, _ = permute_input.shape
|
||||
expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile([1, max_num])
|
||||
elif self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]:
|
||||
@@ -233,6 +241,7 @@ class CutlassMoEMethod(MoEMethodBase):
|
||||
token_nums_per_expert.cast("int64"),
|
||||
expert_idx_per_token,
|
||||
True,
|
||||
estimate_total_token_nums,
|
||||
)
|
||||
|
||||
# 4. EP combine
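In the low-latency decode path above, the dispatched activations arrive as a dense `[num_local_experts, max_num, hidden]` buffer, so every slot's expert id is simply its row index; the `arange(...)[:, None].tile(...)` expression materializes exactly that. A toy reproduction of the construction (paddle assumed available):

```python
import paddle

# Same expression as the decode path above, with toy sizes.
num_local_experts, max_num = 4, 3
expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile([1, max_num])
print(expert_idx_per_token.numpy())
# [[0 0 0]
#  [1 1 1]
#  [2 2 2]
#  [3 3 3]]
```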

@@ -295,7 +304,7 @@ class CutlassMoEMethod(MoEMethodBase):
            topk_only_mode=False,
        )

        if self.moe_quant_type != "w4a8":
        if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
            # only w4a8 need expert_idx_per_token
            # Other need not this tensor, so we make it None.
            expert_idx_per_token = None

@@ -332,7 +341,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
        self.moe_quant_type = "w4a8"
        self.pack_num = 2

    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
        """
        Paddle cutlass process prequanted weights.
        """

@@ -343,10 +352,10 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
        up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
        down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)

        up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
            state_dict,
            up_gate_proj_expert_weight_key,
            down_proj_expert_weight_key,
        up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
            layer.load_experts_weight(
                state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
            )
        )

        up_gate_proj_weight_scale = []

@@ -355,22 +364,62 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
        up_gate_proj_in_scale = []
        down_proj_in_scale = []

        if isinstance(state_dict, list):
            state_dict = dict(state_dict)

        if layer.ep_size > 1:
            for expert_idx in range(layer.num_experts):
                scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
            for expert_idx in ep_rank_to_expert_id_list:
                scale_tensor = get_tensor(
                    (
                        state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)]
                        if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
                        else up_gate_proj_expert_in_scale_key.format(expert_idx)
                    ),
                    layer.fd_config.model_config.model,
                )
                up_gate_proj_in_scale_all_experts.append(scale_tensor)

        for expert_idx in logical_expert_ids:
            up_gate_proj_weight_scale.append(
                get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
                get_tensor(
                    (
                        state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx))
                        if up_gate_proj_expert_weight_scale_key.format(expert_idx) in state_dict
                        else up_gate_proj_expert_weight_scale_key.format(expert_idx)
                    ),
                    layer.fd_config.model_config.model,
                )
            )
            down_proj_weight_scale.append(
                get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
                get_tensor(
                    (
                        state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx))
                        if down_proj_expert_weight_scale_key.format(expert_idx) in state_dict
                        else down_proj_expert_weight_scale_key.format(expert_idx)
                    ),
                    layer.fd_config.model_config.model,
                )
            )
            up_gate_proj_in_scale.append(
                get_tensor(state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx)))
                get_tensor(
                    (
                        state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
                        if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
                        else up_gate_proj_expert_in_scale_key.format(expert_idx)
                    ),
                    layer.fd_config.model_config.model,
                )
            )
            down_proj_in_scale.append(
                get_tensor(
                    (
                        state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
                        if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
                        else down_proj_expert_in_scale_key.format(expert_idx)
                    ),
                    layer.fd_config.model_config.model,
                )
            )
            down_proj_in_scale.append(get_tensor(state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))))

        up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
        down_proj_weight = paddle.stack(down_proj_weights, axis=0)

@@ -396,7 +445,9 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
        """
        Paddle cutlass create weight process.
        """
        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
        up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
            layer.extract_moe_ffn_weights(state_dict)
        )
        self.check(layer, up_gate_proj_weights, down_proj_weights)
        for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
            weight_name = self.added_weight_attrs[idx]

@@ -407,9 +458,13 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
            quanted_weight = paddle.stack(weight_list, axis=0)
            create_and_set_parameter(layer, weight_name, quanted_weight)

        self.create_w4a8_scale_weights(layer, layer.weight_key_map, state_dict)
        self.create_w4a8_scale_weights(
            layer, layer.weight_key_map, state_dict, logical_expert_ids, ep_rank_to_expert_id_list
        )

    def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict):
    def create_w4a8_scale_weights(
        self, layer: nn.Layer, weight_key_map: dict, state_dict: dict, logical_expert_ids, ep_rank_to_expert_id_list
    ):
        """
        Get w4a8 weights from state dict and process them.
        Args:

@@ -418,8 +473,15 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
            state_dict (dict): The state dict.
        """

        def _extract_scale_tensor(state_dict, key_template, expert_idx):
            return get_tensor(state_dict.pop(key_template.format(expert_idx)))
        def _extract_scale_tensor(layer: nn.Layer, state_dict, key_template, expert_idx):
            return get_tensor(
                (
                    state_dict.pop(key_template.format(expert_idx))
                    if key_template.format(expert_idx) in state_dict
                    else key_template.format(expert_idx)
                ),
                layer.fd_config.model_config.model,
            )
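The conditional repeated throughout these hunks — pop the tensor when its key is resident in `state_dict`, otherwise hand `get_tensor` the key string plus the model path so it can fetch the weight on demand — is what the new `_extract_scale_tensor` captures. A hedged sketch of the idea in isolation, with a hypothetical `load_from_disk` standing in for whatever `get_tensor` does when given a key string:

```python
# Illustration of the pop-or-lazy-load pattern; load_from_disk is a
# hypothetical stand-in, not FastDeploy's real loader.
def load_from_disk(model_path, key):
    raise FileNotFoundError(f"{key} not found under {model_path}")

def get_tensor(tensor_or_key, model_path=None):
    if isinstance(tensor_or_key, str):
        return load_from_disk(model_path, tensor_or_key)  # lazy path
    return tensor_or_key  # already materialized

def extract(state_dict, key_template, expert_idx, model_path):
    key = key_template.format(expert_idx)
    # Pop when resident so the tensor is not kept twice; otherwise pass the
    # key through and let get_tensor fetch it on demand.
    return get_tensor(state_dict.pop(key) if key in state_dict else key, model_path)

sd = {"experts.0.scale": 1.0}
print(extract(sd, "experts.{}.scale", 0, "/path/to/model"))  # -> 1.0, popped from sd
```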

        def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
            processed_in_scale = 1 / paddle.concat(in_scales)

@@ -461,17 +523,249 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):

        # 2. Extract scale tensor from state dict
        if layer.ep_size > 1:
            for expert_idx in range(layer.num_experts):
                scale_tensor = get_tensor(state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)])
            for expert_idx in ep_rank_to_expert_id_list:
                scale_tensor = get_tensor(
                    (
                        state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)]
                        if scale_key_map["up_gate_proj_in_scale"].format(expert_idx) in state_dict
                        else scale_key_map["up_gate_proj_in_scale"].format(expert_idx)
                    ),
                    layer.fd_config.model_config.model,
                )
                up_gate_proj_in_scales_all_experts.append(1 / scale_tensor)
            create_and_set_parameter(
                layer, "up_gate_proj_in_scale_all_experts", paddle.concat(up_gate_proj_in_scales_all_experts)
            )

        for local_expert_idx in range(layer.num_local_experts):
            expert_idx = local_expert_idx + layer.expert_id_offset
        for expert_idx in logical_expert_ids:
            for name, scale_key_template in scale_key_map.items():
                scale_tensor = _extract_scale_tensor(state_dict, scale_key_template, expert_idx)
                scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
                scale_weight_map[name].append(scale_tensor)

        # 3. Process scale tensor and set to layer
        in_scales = []
        for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
            in_scales.append(_process_in_scale(in_scale_name, scale_weight_map[in_scale_name]))

        for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
            _process_weight_scale(
                weight_scale_name,
                scale_weight_map[weight_scale_name],
                in_scales[i],
            )


class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
    """
    w4afp8 MoE Method
    """

    def __init__(self, quant_config):
        super().__init__(quant_config)
        self.quant_config = quant_config
        self.moe_quant_type = "w4afp8"
        self.pack_num = 2

    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
        """
        Paddle cutlass process prequanted weights.
        """
        up_gate_proj_expert_weight_key = layer.weight_key_map.get("up_gate_proj_expert_weight_key", None)
        down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None)
        up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
        down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
        up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
        down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)

        up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
            layer.load_experts_weight(
                state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
            )
        )

        up_gate_proj_weight_scale = []
        down_proj_weight_scale = []
        up_gate_proj_in_scale_all_experts = []
        up_gate_proj_in_scale = []
        down_proj_in_scale = []

        if isinstance(state_dict, list):
            state_dict = dict(state_dict)

        if layer.ep_size > 1:
            for expert_idx in ep_rank_to_expert_id_list:
                scale_tensor = get_tensor(
                    (
                        state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)]
                        if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
                        else up_gate_proj_expert_in_scale_key.format(expert_idx)
                    ),
                    layer.fd_config.model_config.model,
                )
                up_gate_proj_in_scale_all_experts.append(scale_tensor)

        for expert_idx in logical_expert_ids:
            up_gate_proj_weight_scale.append(
                get_tensor(
                    (
                        state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx))
                        if up_gate_proj_expert_weight_scale_key.format(expert_idx) in state_dict
                        else up_gate_proj_expert_weight_scale_key.format(expert_idx)
                    ),
                    layer.fd_config.model_config.model,
                )
            )
            down_proj_weight_scale.append(
                get_tensor(
                    (
                        state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx))
                        if down_proj_expert_weight_scale_key.format(expert_idx) in state_dict
                        else down_proj_expert_weight_scale_key.format(expert_idx)
                    ),
                    layer.fd_config.model_config.model,
                )
            )
            up_gate_proj_in_scale.append(
                get_tensor(
                    (
                        state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
                        if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
                        else up_gate_proj_expert_in_scale_key.format(expert_idx)
                    ),
                    layer.fd_config.model_config.model,
                )
            )
            down_proj_in_scale.append(
                get_tensor(
                    (
                        state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
                        if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
                        else down_proj_expert_in_scale_key.format(expert_idx)
                    ),
                    layer.fd_config.model_config.model,
                )
            )

        up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
        down_proj_weight = paddle.stack(down_proj_weights, axis=0)
        up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
        down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
        up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0)
        up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0)
        down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0)

        name_tensor_map = {
            "up_gate_proj_weight": up_gate_proj_weight,
            "down_proj_weight": down_proj_weight,
            "up_gate_proj_weight_scale": up_gate_proj_weight_scale,
            "down_proj_weight_scale": down_proj_weight_scale,
            "up_gate_proj_in_scale_all_experts": up_gate_proj_in_scale_all_experts,
            "up_gate_proj_in_scale": up_gate_proj_in_scale,
            "down_proj_in_scale": down_proj_in_scale,
        }
        for name, tensor in name_tensor_map.items():
            create_and_set_parameter(layer, name, tensor)

    def create_weights(self, layer: nn.Layer, state_dict):
        """
        Paddle cutlass create weight process.
        """
        up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
            layer.extract_moe_ffn_weights(state_dict)
        )
        self.check(layer, up_gate_proj_weights, down_proj_weights)
        for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
            weight_name = self.added_weight_attrs[idx]
            weight_list = []
            for i in range(layer.num_local_experts):
                quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80)
                weight_list.append(quant_weight)
            quanted_weight = paddle.stack(weight_list, axis=0)
            create_and_set_parameter(layer, weight_name, quanted_weight)

        self.create_w4afp8_scale_weights(
            layer, layer.weight_key_map, state_dict, logical_expert_ids, ep_rank_to_expert_id_list
        )

    def create_w4afp8_scale_weights(
        self, layer: nn.Layer, weight_key_map: dict, state_dict: dict, logical_expert_ids, ep_rank_to_expert_id_list
    ):
        """
        Get w4afp8 weights from state dict and process them.
        Args:
            layer (nn.Layer): The layer to add parameters to.
            weight_key_map (dict): The weight key map.
            state_dict (dict): The state dict.
        """

        def _extract_scale_tensor(layer: nn.Layer, state_dict, key_template, expert_idx):
            return get_tensor(
                (
                    state_dict.pop(key_template.format(expert_idx))
                    if key_template.format(expert_idx) in state_dict
                    else key_template.format(expert_idx)
                ),
                layer.fd_config.model_config.model,
            )

        def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
            processed_in_scale = 1 / paddle.concat(in_scales)
            create_and_set_parameter(layer, name, processed_in_scale)
            return processed_in_scale

        def _permute_weight_scale(weight_scale: paddle.Tensor):
            weight_scale = w4afp8_gemm_scale_permute(weight_scale)
            return weight_scale

        def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], processed_in_scale: paddle.Tensor):
            processed_weight_scale = (
                paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None]
            )
            processed_weight_scale = _permute_weight_scale(processed_weight_scale)
            create_and_set_parameter(layer, name, processed_weight_scale)
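A quick numeric check of the scale preprocessing above. 448 is the FP8 (E4M3) max (`QUANT_SCALING_FACTOR = 448` appears in the w4afp8 config file below) and 7 is the int4 max magnitude; the `2 ** (-9)` factor is taken on faith as what the GEMM kernel expects, so treat the normalization as a sketch. The custom `w4afp8_gemm_scale_permute` step is omitted here because it is a GPU op:

```python
import paddle

# Toy per-expert scales: two experts, eight output channels each.
in_scales = [paddle.to_tensor([0.5]), paddle.to_tensor([0.25])]
processed_in_scale = 1 / paddle.concat(in_scales)            # [2.0, 4.0]
weight_scales = [paddle.ones([8]) * 0.01, paddle.ones([8]) * 0.02]

denom = 448 * 7 * 2 ** (-9)                                  # = 6.125
processed = paddle.stack(weight_scales, axis=0) / denom / processed_in_scale[:, None]
print(denom, processed.shape)                                # 6.125 [2, 8]
```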

        # 1. Init scale containers and maps
        up_gate_proj_weight_scales = []
        down_proj_weight_scales = []
        up_gate_proj_in_scales_all_experts = []
        up_gate_proj_in_scales = []
        down_proj_in_scales = []

        scale_weight_map = {
            "up_gate_proj_weight_scale": up_gate_proj_weight_scales,
            "down_proj_weight_scale": down_proj_weight_scales,
            "up_gate_proj_in_scale": up_gate_proj_in_scales,
            "down_proj_in_scale": down_proj_in_scales,
        }
        scale_key_map = {
            "up_gate_proj_weight_scale": weight_key_map.get("up_gate_proj_expert_weight_scale_key", None),
            "down_proj_weight_scale": weight_key_map.get("down_proj_expert_weight_scale_key", None),
            "up_gate_proj_in_scale": weight_key_map.get("up_gate_proj_expert_in_scale_key", None),
            "down_proj_in_scale": weight_key_map.get("down_proj_expert_in_scale_key", None),
        }
        for name, value in scale_key_map.items():
            if value is None:
                raise ValueError(f"scale {name} should not be none in w4a8 mode.")

        # 2. Extract scale tensor from state dict
        if layer.ep_size > 1:
            for expert_idx in ep_rank_to_expert_id_list:
                scale_tensor = get_tensor(
                    (
                        state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)]
                        if scale_key_map["up_gate_proj_in_scale"].format(expert_idx) in state_dict
                        else scale_key_map["up_gate_proj_in_scale"].format(expert_idx)
                    ),
                    layer.fd_config.model_config.model,
                )
                up_gate_proj_in_scales_all_experts.append(1 / scale_tensor)
            create_and_set_parameter(
                layer, "up_gate_proj_in_scale_all_experts", paddle.concat(up_gate_proj_in_scales_all_experts)
            )

        for expert_idx in logical_expert_ids:
            for name, scale_key_template in scale_key_map.items():
                scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
                scale_weight_map[name].append(scale_tensor)

        # 3. Process scale tensor and set to layer

@@ -498,7 +792,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
        self.moe_quant_type = self.quant_config.algo
        self.pack_num = 1

    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
        """
        Paddle cutlass process prequanted weights.
        """

@@ -507,7 +801,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
        up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
        down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)

        up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
        up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
            state_dict,
            up_gate_proj_expert_weight_key,
            down_proj_expert_weight_key,

@@ -541,7 +835,9 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
        """
        Paddle cutlass create weight process.
        """
        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
        up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
            layer.extract_moe_ffn_weights(state_dict)
        )
        self.check(layer, up_gate_proj_weights, down_proj_weights)

        for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):

@@ -37,7 +37,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        deepgemm create weight process.
        """

        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
        up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
            layer.extract_moe_ffn_weights(state_dict)
        )

        self.check(layer, up_gate_proj_weights, down_proj_weights)

@@ -62,7 +64,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous()
        create_and_set_parameter(layer, scale_name, quanted_weight_scale)

    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
        """
        Paddle cutlass process prequanted weights.
        """

@@ -71,10 +73,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
        down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)

        up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
            state_dict,
            up_gate_proj_expert_weight_key,
            down_proj_expert_weight_key,
        up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
            state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
        )
        # self.check(layer, up_gate_proj_weights, down_proj_weights)
        up_gate_proj_weight_scale = []

@@ -143,7 +143,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
        """
        Marlin MoE create weight process.
        """
        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
        assert len(up_gate_proj_weights) == layer.num_local_experts
        assert len(down_proj_weights) == layer.num_local_experts
        assert up_gate_proj_weights[0].shape == [

@@ -56,7 +56,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
        """
        Triton MoE create weight process.
        """
        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
        assert len(up_gate_proj_weights) == layer.num_local_experts
        assert len(down_proj_weights) == layer.num_local_experts

@@ -267,7 +267,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
        """process_prequanted_weights"""

        up_gate_proj_tensor, down_proj_tensor = layer.extract_moe_ffn_weights(state_dict)
        up_gate_proj_tensor, down_proj_tensor, _, _ = layer.extract_moe_ffn_weights(state_dict)
        assert up_gate_proj_tensor[0].shape == [
            layer.hidden_size,
            layer.moe_intermediate_size * 2,

@@ -534,7 +534,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
        """
        Triton MoE create weight process.
        """
        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)

        self.check(layer, up_gate_proj_weights, down_proj_weights)

@@ -88,7 +88,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
        up_gate_proj_expert_code_zp_key = layer.weight_key_map.get("up_gate_proj_expert_code_zp_key", None)
        down_proj_expert_code_zp_key = layer.weight_key_map.get("down_proj_expert_code_zp_key", None)

        up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
        up_gate_proj_weights, down_proj_weights, _, _ = layer.load_experts_weight(
            state_dict,
            up_gate_proj_expert_weight_key,
            down_proj_expert_weight_key,

@@ -36,7 +36,7 @@ class XPUMoEMethod(MoEMethodBase):
        Paddle cutlass create weight process.
        """
        # bf16
        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
        for weights in [up_gate_proj_weights, down_proj_weights]:
            for idx, weight in enumerate(weights):
                weights[idx] = weight.transpose([1, 0])

@@ -130,7 +130,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
        """
        Paddle cutlass create weight process.
        """
        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
        assert len(up_gate_proj_weights) == layer.num_local_experts
        assert len(down_proj_weights) == layer.num_local_experts
        assert up_gate_proj_weights[0].shape == [

@@ -63,6 +63,7 @@ class FusedMoE(nn.Layer):
        routed_scaling_factor: float = 1.0,
        layer_idx: int = -1,
        moe_tag: str = "",
        redundant_table_manger: RedundantExpertManger = None,
        weight_key_map: dict = {},
    ):
        """

@@ -118,15 +119,8 @@ class FusedMoE(nn.Layer):
        # now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
        self.quant_method = get_moe_method()

        self.redundant_table_manger = None
        self.redundant_table_manger = redundant_table_manger
        if self.ep_size > 1:
            if fd_config.model_config.enable_redundant_experts is True:
                self.redundant_table_manger = RedundantExpertManger(
                    n_routed_experts=fd_config.model_config.moe_num_experts,
                    num_hidden_layers=fd_config.model_config.num_hidden_layers,
                    redundant_experts_num=fd_config.model_config.redundant_experts_num,
                    ep_size=self.ep_size,
                )
            self.quant_method.init_ep(self)

        if fd_config.load_config.dynamic_load_weight:

@@ -223,6 +217,7 @@ class FusedMoE(nn.Layer):
        state_dict: dict,
        up_gate_proj_expert_weight_key: str,
        down_proj_expert_weight_key: str,
        is_rearrange: bool = False,
    ):
        """
        Load experts weight from state_dict.

@@ -238,6 +233,7 @@ class FusedMoE(nn.Layer):
                self.expert_id_offset + self.num_local_experts,
            )
        ]
        ep_rank_to_expert_id_list = [i for i in range(self.num_experts)]
        if self.redundant_table_manger is not None:
            (
                ep_rank_to_expert_id_list,

@@ -250,7 +246,13 @@ class FusedMoE(nn.Layer):
        ]
        up_gate_proj_weights = []
        down_proj_weights = []
        is_ffn_merged = up_gate_proj_expert_weight_key.format(self.expert_id_offset) in state_dict

        if isinstance(state_dict, list):
            state_dict = dict(state_dict)
        is_ffn_merged = (
            up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else self.expert_id_offset)
            in state_dict
        )
        if is_ffn_merged:
            for expert_idx in logical_expert_ids:
                down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)

@@ -309,7 +311,7 @@ class FusedMoE(nn.Layer):
                    self.fd_config.model_config.model,
                )
            )
        return up_gate_proj_weights, down_proj_weights, logical_expert_ids
        return up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list

    def extract_moe_ffn_weights(self, state_dict: dict):
        """

@@ -332,10 +334,12 @@ class FusedMoE(nn.Layer):
        assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
        assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."

        up_gate_proj_weights, down_proj_weights, logical_expert_ids = self.load_experts_weight(
            state_dict,
            up_gate_proj_expert_weight_key,
            down_proj_expert_weight_key,
        up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
            self.load_experts_weight(
                state_dict,
                up_gate_proj_expert_weight_key,
                down_proj_expert_weight_key,
            )
        )
        assert (
            len(up_gate_proj_weights) == self.num_local_experts

@@ -344,7 +348,7 @@ class FusedMoE(nn.Layer):
            len(down_proj_weights) == self.num_local_experts
        ), "down_proj_weights length should be equal to num_local_experts."

        return up_gate_proj_weights, down_proj_weights
        return up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list

    def extract_gate_correction_bias(self, gate_correction_bias_key, state_dict):
        """

@@ -386,7 +390,7 @@ class FusedMoE(nn.Layer):

        if self.fd_config.model_config.is_quantized:
            if getattr(self.fd_config.quant_config, "is_permuted", False):
                self.quant_method.process_prequanted_weights(self, state_dict)
                self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
            else:
                self.quant_method.create_weights(self, state_dict)
        else:

@@ -36,7 +36,7 @@ class MixQuantConfig(QuantConfigBase):
        image_moe_quant_type: str = None,
        is_channel_wise: bool = False,
        has_zero_point: bool = False,
        is_permuted: bool = False,
        is_permuted: bool = True,
    ) -> None:
        super().__init__()
        self.dense_quant_type = dense_quant_type

@@ -65,7 +65,7 @@ class MixQuantConfig(QuantConfigBase):
            config.get("image_moe_quant_type", None),
            config.get("is_channel_wise", False),
            config.get("has_zero_point", False),
            config.get("is_permuted", False),
            config.get("is_permuted", True),
        )

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:

@@ -73,13 +73,13 @@ class MixQuantConfig(QuantConfigBase):
            if layer.moe_tag == "Image":
                return (
                    get_quantization_config(self.image_moe_quant_type)
                    .from_config(layer.fd_config.quant_config)
                    .from_config({"is_permuted": self.is_permuted})
                    .get_quant_method(layer)
                )
            else:
                return (
                    get_quantization_config(self.moe_quant_type)
                    .from_config(layer.fd_config.quant_config)
                    .from_config({"is_permuted": self.is_permuted})
                    .get_quant_method(layer)
                )
        elif isinstance(layer, Attention):

@@ -34,7 +34,7 @@ class W4A8Config(QuantConfigBase):

    @classmethod
    def from_config(cls, config: dict) -> "W4A8Config":
        is_permuted = getattr(config, "is_permuted", False)
        is_permuted = config.get("is_permuted", True)
        return cls(is_permuted)
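The replaced line called `getattr` on a plain dict, and attribute lookup never sees dict keys, so the old code always returned the default; `config.get` is the correct accessor (the flipped default to `True` is a separate behavior change this hunk also makes). A two-line illustration:

```python
config = {"is_permuted": False}
print(getattr(config, "is_permuted", True))  # True  -- attribute lookup misses dict keys
print(config.get("is_permuted", True))       # False -- dict lookup sees the stored value
```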

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:

@@ -20,6 +20,7 @@ import paddle

import fastdeploy

from ..moe import FusedMoE
from .quant_base import QuantConfigBase, QuantMethodBase

QUANT_SCALING_FACTOR = 448

@@ -30,24 +31,32 @@ class W4AFP8Config(QuantConfigBase):
    quantization config for weight 4bits and activation fp8
    """

    def __init__(self, weight_scale_dict, act_scale_dict) -> None:
    def __init__(self, weight_scale_dict, act_scale_dict, is_permuted) -> None:
        super().__init__()
        self.weight_scale_dict = weight_scale_dict
        self.act_scale_dict = act_scale_dict
        self.quant_max_bound = 448
        self.quant_min_bound = -448
        self.quant_round_type = 1
        self.is_permuted = is_permuted

    def name(self) -> str:
        return "w4afp8"

    @classmethod
    def from_config(cls, config: dict) -> "W4AFP8Config":
        weight_scale_dict = config["weight_scale_dict"]
        act_scale_dict = config["act_scale_dict"]
        return cls(weight_scale_dict, act_scale_dict)
        weight_scale_dict = config.get("weight_scale_dict", None)
        act_scale_dict = config.get("act_scale_dict", None)
        is_permuted = config.get("is_permuted", True)
        return cls(weight_scale_dict, act_scale_dict, is_permuted)

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
        if isinstance(layer, FusedMoE):
            from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
                CutlassW4AFP8MoEMethod,
            )

            return CutlassW4AFP8MoEMethod(self)
        return W4AFP8LinearMethod(self)


@@ -33,6 +33,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
    min_p_sampling,
    top_k_top_p_sampling,
)
from fastdeploy.model_executor.ops.gpu import limit_content_len
from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import LogprobsTensors, SamplerOutput

@@ -304,6 +305,7 @@ class SpeculativeSampler(nn.Layer):
        self.speculative_verify_window = fd_config.speculative_config.verify_window
        self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
        self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
        self.fd_config = fd_config

    def pre_process(self, skip_idx_list: List[int] = []):
        """pre process before running"""

@@ -382,6 +384,22 @@ class SpeculativeSampler(nn.Layer):
            self.speculative_benchmark_mode,
        )

        if hasattr(self.fd_config.model_config, "think_end_id") and self.fd_config.model_config.think_end_id > 0:
            limit_content_len(
                share_inputs["accept_tokens"],
                self.fd_config.model_config.think_end_id,
                share_inputs["max_content_len"],
                share_inputs["max_think_len"],
                share_inputs["step_idx"],
                sampling_metadata.eos_token_ids,
                share_inputs["max_dec_len"],
                share_inputs["limit_content_status"],
                share_inputs["enable_thinking"],
                share_inputs["accept_num"],
                share_inputs["seq_lens_decoder"],
                share_inputs["stop_flags"],
            )

        return None


@@ -429,8 +447,8 @@ class MTPSampler(nn.Layer):
            sampling_metadata.min_dec_lens,
            sampling_metadata.eos_token_ids,
            share_inputs["seq_lens_this_time"],
            share_inputs["seq_lens_encoder"],
            share_inputs["seq_lens_decoder"],
            share_inputs["output_padding_offset"],
            share_inputs["output_cum_offsets"],
            max_model_len,
        )
        probs = F.softmax(logits)

@@ -112,7 +112,11 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
            num_local_ffn_keys.append(down_proj_in_scale_key)

        # for EP w4a8, we need all expert's activation_scale for up_gate_proj
        for j in range(fd_config.model_config.moe_num_experts):
        num_experts = fd_config.model_config.moe_num_experts
        if isinstance(num_experts, list):
            num_experts = num_experts[0]

        for j in range(num_experts):
            up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
            num_local_ffn_keys.append(up_gate_proj_in_scale_key)

@@ -46,6 +46,7 @@ from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
from fastdeploy.model_executor.models.utils import WeightMeta
from fastdeploy.worker.experts_manager import RedundantExpertManger


class Ernie4_5_MLP(nn.Layer):

@@ -94,13 +95,15 @@ class Ernie4_5_MLP(nn.Layer):


class Ernie4_5_MoE(nn.Layer):
    def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None:
    def __init__(
        self, fd_config: FDConfig, layer_id: int, prefix: str, redundant_table_manger: RedundantExpertManger = None
    ) -> None:
        super().__init__()
        moe_quant_type = ""
        if hasattr(fd_config.quant_config, "moe_quant_type"):
            moe_quant_type = fd_config.quant_config.moe_quant_type

        if moe_quant_type == "w4a8":
        if moe_quant_type == "w4a8" or moe_quant_type == "w4afp8":
            weight_key_map = {
                "gate_weight_key": f"{prefix}.gate.weight",
                "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",

@@ -154,6 +157,7 @@ class Ernie4_5_MoE(nn.Layer):
            top_k=fd_config.model_config.moe_k,
            layer_idx=layer_id,
            weight_key_map=weight_key_map,
            redundant_table_manger=redundant_table_manger,
        )

        self.num_shared_experts = fd_config.model_config.moe_num_shared_experts

@@ -170,6 +174,9 @@ class Ernie4_5_MoE(nn.Layer):
        if self.num_shared_experts > 0:
            self.shared_experts.load_state_dict(state_dict)

    def update_state_dict(self, state_dict):
        self.fused_moe.load_state_dict(state_dict, True)

    def forward(self, hidden_states: paddle.Tensor):
        out = self.fused_moe(hidden_states)
        if self.num_shared_experts > 0:

@@ -226,6 +233,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
    def __init__(
        self,
        fd_config: FDConfig,
        redundant_table_manger: RedundantExpertManger = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

@@ -244,6 +252,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
            self.mlp = Ernie4_5_MoE(
                fd_config=fd_config,
                layer_id=layer_id,
                redundant_table_manger=redundant_table_manger,
                prefix=f"{prefix}.mlp",
            )
        else:

@@ -273,6 +282,9 @@ class Ernie4_5_DecoderLayer(nn.Layer):
        self.input_layernorm.load_state_dict(state_dict)
        self.post_attention_layernorm.load_state_dict(state_dict)

    def update_state_dict(self, state_dict):
        self.mlp.update_state_dict(state_dict)

    def forward(
        self,
        forward_meta: ForwardMeta,

@@ -313,6 +325,16 @@ class Ernie4_5_Model(nn.Layer):

        self.num_layers = fd_config.model_config.num_hidden_layers
        fd_config.model_config.pretrained_config.prefix_name = "ernie"
        self.fd_config = fd_config

        self.redundant_table_manger = None
        if fd_config.model_config.enable_redundant_experts is True:
            self.redundant_table_manger = RedundantExpertManger(
                n_routed_experts=fd_config.model_config.moe_num_experts,
                num_hidden_layers=fd_config.model_config.num_hidden_layers,
                redundant_experts_num=fd_config.model_config.redundant_experts_num,
                ep_size=fd_config.parallel_config.expert_parallel_size,
            )

        self.embed_tokens = VocabParallelEmbedding(
            fd_config=fd_config,

@@ -326,6 +348,7 @@ class Ernie4_5_Model(nn.Layer):
            [
                Ernie4_5_DecoderLayer(
                    fd_config=fd_config,
                    redundant_table_manger=self.redundant_table_manger,
                    prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}",
                )
                for i in range(self.num_layers)

@@ -354,6 +377,22 @@ class Ernie4_5_Model(nn.Layer):
            logger.info(f"Start load layer {i}")
            self.layers[i].load_state_dict(state_dict)

    def update_state_dict(self, state_dict):
        """
        Update model parameters from a given state dictionary.

        Args:
            state_dict (dict[str, np.ndarray | paddle.Tensor]):
                A dictionary containing model parameters, where keys are parameter names
                and values are NumPy arrays or PaddlePaddle tensors.
        """
        for i in range(
            self.fd_config.model_config.moe_layer_start_index,
            self.fd_config.model_config.num_hidden_layers,
        ):
            logger.info(f"Start update layer {i}")
            self.layers[i].update_state_dict(state_dict)

    def forward(
        self,
        ids_remove_padding: paddle.Tensor,

@@ -244,6 +244,7 @@ class Ernie4_5_MTPModel(nn.Layer):

        self.num_layers = fd_config.model_config.num_hidden_layers
        self.embed_tokens = fd_config.speculative_config.sharing_model.ernie.embed_tokens
        self.norm = fd_config.speculative_config.sharing_model.ernie.norm

        self.layers = nn.LayerList(
            [

@@ -314,6 +315,8 @@ class Ernie4_5_MTPModel(nn.Layer):

        hidden_states = hidden_states + residual

        hidden_states = self.norm(hidden_states)

        return hidden_states


@@ -46,7 +46,6 @@ from fastdeploy.platforms import current_platform

if current_platform.is_cuda():
    from fastdeploy.model_executor.ops.gpu import (
        extract_text_token_output,
        text_image_gather_scatter,
        text_image_index_out,
    )

@@ -99,8 +98,8 @@ class Ernie4_5_VLMoE(nn.Layer):
        assert text_moe_layer_start_index <= text_moe_layer_end_index

        moe_quant_type = ""
        if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
            moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
        if hasattr(fd_config.quant_config, "moe_quant_type"):
            moe_quant_type = fd_config.quant_config.moe_quant_type

        if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index:
            if moe_quant_type == "tensor_wise_fp8" or (

@@ -472,26 +471,6 @@ class Ernie4_5_VLModel(nn.Layer):
        )

        hidden_states = hidden_states + residual

        # -----------------------
        hidden_states = hidden_states.cast("float32")
        score_text = hidden_states

        if image_input is not None:
            token_type_ids = token_type_ids.reshape([-1])
            text_pos_shifted = token_type_ids[:token_num] == 0
            score_text = hidden_states[text_pos_shifted.reshape([-1])]
        max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time.squeeze(-1), k=1)
        hidden_states = extract_text_token_output(
            max_seq_len,
            max_seq_len_index.cast("int32"),
            image_token_num.cast("int32"),
            forward_meta.seq_lens_this_time,
            forward_meta.cu_seqlens_q,
            score_text,
        ).cast(self._dtype)
        # -----------------------

        out = self.norm(hidden_states)

        return out

@@ -59,7 +59,7 @@ else:
        speculate_set_value_by_flags_and_idx,
        speculate_step_paddle,
        speculate_step_system_cache,
        speculate_update_v3,
        speculate_update,
        step_paddle,
        step_system_cache,
        update_inputs,

@@ -288,7 +288,7 @@ def post_process_normal(

def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False):
    """"""
    speculate_update_v3(
    speculate_update(
        model_output.seq_lens_encoder,
        model_output.seq_lens_decoder,
        model_output.not_need_stop,

@@ -281,12 +281,13 @@ class TokenProcessor:

    def _compute_speculative_status(self):
        # TODO(liuzichang): Supplement more statistics
        interval = 50
        interval = 10
        if self.speculative_stats_step % interval == 0:
            accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens
            spec_logger.info(
                f"Speculate global accept ratio(Accept draft_tokens/Generated tokens): {accept_ratio}"
                f" total step: {self.total_step}. total output token num: {self.number_of_output_tokens}"
                f" average accept len: {self.number_of_output_tokens / self.total_step}"
            )
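A quick arithmetic check of the metrics logged above, with illustrative numbers:

```python
# Worked example of the two reported statistics (numbers are made up):
total_step, number_of_output_tokens = 100, 250
accept_ratio = 1 - total_step * 1.0 / number_of_output_tokens   # 0.6
average_accept_len = number_of_output_tokens / total_step       # 2.5 tokens per verify step
```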
|
||||
|
||||
if self.cfg.speculative_config.method in ["mtp"]:
|
||||
|
@@ -45,6 +45,10 @@ class Proposer(ABC):
|
||||
self.max_model_len = self.parallel_config.max_model_len
|
||||
self.speculative_method = self.speculative_config.method
|
||||
self.max_draft_token_num = self.speculative_config.num_speculative_tokens
|
||||
self.num_model_steps = self.speculative_config.num_model_steps
|
||||
|
||||
self.max_ngram_size = self.speculative_config.max_ngram_size
|
||||
self.min_ngram_size = self.speculative_config.min_ngram_size
|
||||
|
||||
spec_logger.info(f"Speculate config: {self.speculative_config}")
|
||||
|
||||
|
@@ -35,6 +35,7 @@ from fastdeploy.model_executor.ops.gpu import (
|
||||
draft_model_update,
|
||||
eagle_get_hidden_states,
|
||||
eagle_get_self_hidden_states,
|
||||
hybrid_mtp_ngram,
|
||||
mtp_save_first_token,
|
||||
mtp_step_paddle,
|
||||
share_external_data,
|
||||
@@ -57,6 +58,8 @@ class MTPProposer(Proposer):
|
||||
self._update_cfg(main_model)
|
||||
self._load_model()
|
||||
self.main_model_inputs = main_model_inputs
|
||||
self.mtp_strategy = self.speculative_config.mtp_strategy
|
||||
self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
|
||||
|
||||
# [mixed, prefill, decoder]
|
||||
self.role = "mixed"
|
||||
@@ -266,12 +269,19 @@ class MTPProposer(Proposer):
|
||||
)
|
||||
|
||||
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
|
||||
self.model_inputs["rope_emb"] = get_rope(
|
||||
rotary_dim=self.model_config.head_dim,
|
||||
position_ids=tmp_position_ids,
|
||||
base=self.model_config.rope_theta,
|
||||
model_config=self.model_config,
|
||||
)
|
||||
if len(self.main_model_inputs["rope_emb"].shape) == 5:
|
||||
self.model_inputs["rope_emb"] = get_rope(
|
||||
rotary_dim=self.model_config.head_dim,
|
||||
position_ids=tmp_position_ids,
|
||||
base=self.model_config.rope_theta,
|
||||
model_config=self.model_config,
|
||||
)
|
||||
else:
|
||||
self.model_inputs["max_content_len"] = paddle.clone(self.main_model_inputs["max_content_len"])
|
||||
self.model_inputs["max_think_len"] = paddle.clone(self.main_model_inputs["max_think_len"])
|
||||
self.model_inputs["limit_content_status"] = paddle.clone(self.main_model_inputs["limit_content_status"])
|
||||
self.model_inputs["enable_thinking"] = paddle.clone(self.main_model_inputs["enable_thinking"])
|
||||
self.model_inputs["rope_emb"] = paddle.clone(self.main_model_inputs["rope_emb"])
|
||||
# self.model_inputs["caches"] = self.cache_kvs
|
||||
# Inherit generation hyperparameters from the main model for consistency
|
||||
self.model_inputs["top_p"] = self.main_model_inputs["top_p"]
|
||||
@@ -291,9 +301,12 @@ class MTPProposer(Proposer):
|
||||
# Integrate the updated results in model forward
|
||||
self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"]
|
||||
self.model_inputs["substep"] = 0
|
||||
self.max_num_seqs = self.main_model_inputs["draft_tokens"].shape[0]
|
||||
|
||||
# Input tokens
|
||||
self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")
|
||||
self.model_inputs["draft_tokens"] = paddle.full(
|
||||
shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
|
||||
)
|
||||
|
||||
self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"])
|
||||
|
||||
@@ -311,10 +324,11 @@ class MTPProposer(Proposer):
|
||||
|
||||
self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
|
||||
self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
|
||||
if self.max_draft_token_num > 1:
|
||||
if self.num_model_steps > 1:
|
||||
self.last_seq_lens_this_time = paddle.full_like(
|
||||
self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
|
||||
)
|
||||
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
|
||||
|
||||
def insert_prefill_inputs(self, req_dicts: List[Request]):
|
||||
"""
|
||||
@@ -339,6 +353,7 @@ class MTPProposer(Proposer):
|
||||
request = req_dicts[i]
|
||||
idx = request.idx
|
||||
length = len(request.prompt_token_ids)
|
||||
self.input_ids_len[idx] = length
|
||||
|
||||
if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
|
||||
length = len(request.prompt_token_ids)
|
||||
@@ -432,15 +447,17 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["step_idx"],
|
||||
self.model_inputs["not_need_stop"],
|
||||
self.model_inputs["batch_drop"],
|
||||
self.model_inputs["pre_ids"],
|
||||
self.main_model_inputs["accept_tokens"],
|
||||
self.main_model_inputs["accept_num"],
|
||||
self.main_model_inputs["seq_lens_this_time"],
|
||||
self.main_model_inputs["seq_lens_encoder"],
|
||||
self.main_model_inputs["seq_lens_decoder"],
|
||||
self.main_model_inputs["step_idx"],
|
||||
self.main_model_inputs["stop_flags"],
|
||||
self.main_model_inputs["is_block_step"],
|
||||
self.main_model_inputs["draft_tokens"],
|
||||
self.max_draft_token_num,
|
||||
self.num_model_steps,
|
||||
self.speculative_method in ["eagle", "mtp"],
|
||||
self.role == "prefill",
|
||||
)
|
||||
@@ -454,7 +471,7 @@ class MTPProposer(Proposer):
|
||||
self.main_model_inputs["accept_num"],
|
||||
self.main_model_inputs["seq_lens_this_time"],
|
||||
self.main_model_inputs["seq_lens_encoder"],
|
||||
self.max_draft_token_num,
|
||||
self.num_model_steps,
|
||||
)
|
||||
if isinstance(target_hidden_states, list):
|
||||
target_hidden_states = target_hidden_states[0]
|
||||
@@ -494,7 +511,7 @@ class MTPProposer(Proposer):
|
||||
"""
|
||||
Main process for MTP inference
|
||||
"""
|
||||
for substep in range(self.max_draft_token_num):
|
||||
for substep in range(self.num_model_steps):
|
||||
if self.model_inputs["not_need_stop"]:
|
||||
self.model_inputs["substep"] = substep
|
||||
# Remove padding
|
||||
@@ -514,6 +531,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["seq_lens_encoder"],
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
)
|
||||
|
||||
# Initialize forward meta data
|
||||
self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
|
||||
self.model_inputs["cum_offsets"].copy_(cum_offsets, False)
|
||||
@@ -540,7 +558,7 @@ class MTPProposer(Proposer):
|
||||
eos_token_ids=self.model_inputs["eos_token_id"],
|
||||
)
|
||||
|
||||
if self.max_draft_token_num > 1:
|
||||
if self.num_model_steps > 1:
|
||||
self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
|
||||
|
||||
model_output = self.model(
|
||||
@@ -574,7 +592,7 @@ class MTPProposer(Proposer):
|
||||
|
||||
self._post_process(sampled_token_ids)
|
||||
|
||||
if substep != self.max_draft_token_num - 1:
|
||||
if substep != self.num_model_steps - 1:
|
||||
target_hidden_states = self._get_self_hidden_states(hidden_states)
|
||||
|
||||
def _get_self_hidden_states(self, hidden_states):
|
||||
@@ -646,11 +664,37 @@ class MTPProposer(Proposer):
|
||||
self.max_draft_token_num,
|
||||
)
|
||||
|
||||
    def _extend_draft_token_with_ngram_match(self):
        # TODO(liuzichang): Optimize this kernel to a CUDA kernel to reduce latency
        device = paddle.CUDAPinnedPlace()

        draft_tokens = self.main_model_inputs["draft_tokens"].cpu()
        seq_lens_this_time = self.main_model_inputs["seq_lens_this_time"].cpu()
        seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
        hybrid_mtp_ngram(
            self.model_inputs["input_ids"]._copy_to(device, True),
            self.input_ids_len,
            self.model_inputs["pre_ids"]._copy_to(device, True),
            self.model_inputs["step_idx"].cpu(),
            self.main_model_inputs["actual_draft_token_num"].cpu(),
            draft_tokens,
            seq_lens_this_time,
            seq_lens_decoder,
            self.model_inputs["max_dec_len"].cpu(),
            self.max_ngram_size,
            self.min_ngram_size,
            self.max_draft_token_num,
        )
        self.main_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
        self.main_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
    def _run_impl(self, full_hidden_states):
        """"""
        target_hidden_states = self._prepare_inputs(full_hidden_states)
        self._propose(target_hidden_states=target_hidden_states)
        self._update_status()
        if self.hybrid_mode:
            self._extend_draft_token_with_ngram_match()

    def is_chunk_prefill_enabled(self):
        """"""
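The _propose loop above chains num_model_steps forwards of the draft model, feeding each step's hidden states back in via _get_self_hidden_states. A toy rendition of that chaining with stand-in callables; names and signatures here are illustrative, not FastDeploy's API:

    # Toy multi-step draft loop: each substep runs one forward of a stand-in
    # "draft model" and feeds its output back in for the next substep.
    def propose(draft_model, sampler, hidden, num_model_steps):
        drafts = []
        for substep in range(num_model_steps):
            hidden = draft_model(hidden)      # one draft-model forward
            drafts.append(sampler(hidden))    # sample one draft token per step
        return drafts

    # Example with trivial stand-ins:
    print(propose(lambda h: h + 1, lambda h: h * 10, 0, 3))  # [10, 20, 30]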
@@ -1210,21 +1210,20 @@ class GPUModelRunner(ModelRunnerBase):
                    self.share_inputs["image_features"],
                    self.forward_meta,
                )
                hidden_states = model_output
            else:
                model_output = self.model(
                    ids_remove_padding=self.share_inputs["ids_remove_padding"],
                    forward_meta=self.forward_meta,
                )
                hidden_states = rebuild_padding(
                    model_output,
                    self.share_inputs["cum_offsets"],
                    self.share_inputs["seq_lens_this_time"],
                    self.share_inputs["seq_lens_decoder"],
                    self.share_inputs["seq_lens_encoder"],
                    (self.share_inputs["output_padding_offset"] if self.speculative_decoding else None),
                    self.parallel_config.max_model_len,
                )
            hidden_states = rebuild_padding(
                model_output,
                self.share_inputs["cum_offsets"],
                self.share_inputs["seq_lens_this_time"],
                self.share_inputs["seq_lens_decoder"],
                self.share_inputs["seq_lens_encoder"],
                (self.share_inputs["output_padding_offset"] if self.speculative_decoding else None),
                self.parallel_config.max_model_len,
            )

            # 4. Compute logits, Sample
            logits = self.model.compute_logits(hidden_states)
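Both branches now feed model_output through rebuild_padding, which maps the packed, padding-free token layout used during the forward pass back to a per-sequence view. A toy numpy illustration of the underlying bookkeeping; the gather of last-token states shown here is an assumption about intent, not the kernel's exact contract:

    import numpy as np

    # Three sequences of lengths 2, 3, 1 packed back to back: 6 hidden states of dim 4.
    seq_lens = np.array([2, 3, 1])
    packed = np.arange(6 * 4, dtype=np.float32).reshape(6, 4)

    # cum_offsets-style prefix sums tell where each sequence starts in the packed layout.
    starts = np.concatenate([[0], np.cumsum(seq_lens)[:-1]])

    # Gather the last hidden state of every sequence, which is what sampling consumes.
    last_token_idx = starts + seq_lens - 1
    last_hidden = packed[last_token_idx]      # shape [batch, hidden_dim]
    print(last_hidden.shape)                  # (3, 4)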
193
scripts/offline_w4a8.py
Normal file
@@ -0,0 +1,193 @@
import argparse
import json
import os
import re
import time

import paddle
from paddleformers.trainer import strtobool
from paddleformers.transformers.configuration_utils import PretrainedConfig
from paddleformers.transformers.model_utils import shard_checkpoint
from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from paddleformers.utils.log import logger
from safetensors.numpy import save_file as safe_save_file

from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.load_weight_utils import (
    get_all_safetensors,
    safetensors_weights_iterator,
)
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute


def parse_arguments():
    """
    parse_arguments
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name_or_path",
        default=None,
        required=True,
        help="The directory of the model.",
    )

    parser.add_argument(
        "--output_dir",
        default="merged_output",
        required=True,
        help="The directory of the merged model output.",
    )

    parser.add_argument(
        "--safe_serialization",
        type=strtobool,
        default="True",
        help="Whether to merge the model into safetensors format.",
    )

    parser.add_argument(
        "--moe_quant_type",
        default="w4a8",
        choices=["w4a8", "w4afp8"],
        help="The MoE quant type of the model.",
    )

    return parser.parse_args()


def reorder():
    def fn(weight, moe_quant_type):
        from paddle.nn.quant import weight_quantize

        quant_weight, _ = weight_quantize(weight.cuda(), algo=moe_quant_type, arch=80)
        return quant_weight.cpu()

    return fn


def deal_in_scale():
    def fn(in_scale):
        processed_in_scale = 1 / in_scale
        return processed_in_scale

    return fn


def deal_weight_scale():
    def fn(weight_scale, processed_in_scale, moe_quant_type):
        if moe_quant_type == "w4a8":
            processed_weight_scale = weight_scale / (127 * 112) / processed_in_scale
            return processed_weight_scale
        elif moe_quant_type == "w4afp8":
            processed_weight_scale = weight_scale / (448 * 7 * 2 ** (-9)) / processed_in_scale
            processed_weight_scale = w4afp8_gemm_scale_permute(processed_weight_scale.cuda())
            return processed_weight_scale

    return fn
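A quick worked example of the scale arithmetic in deal_weight_scale, with made-up calibration values. Reading the constants off the code: 127 is the int8 range for w4a8 activations, 448 the fp8-e4m3 maximum for w4afp8, and 7 the half-range of the unsigned 4-bit weight code; treating 112 and 2 ** (-9) as kernel-internal fixed-point conventions is an assumption here:

    in_scale = 0.05                       # hypothetical activation scale from calibration
    weight_scale = 1.25                   # hypothetical per-channel weight scale

    processed_in_scale = 1 / in_scale     # deal_in_scale(): store the reciprocal
    # deal_weight_scale(), w4a8 branch:
    print(weight_scale / (127 * 112) / processed_in_scale)            # ~4.4e-06
    # deal_weight_scale(), w4afp8 branch (before the layout permute):
    print(weight_scale / (448 * 7 * 2 ** (-9)) / processed_in_scale)  # ~1.0e-02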
# tmp support w4a8
def deal_quant(state_dict, save_state_dict, moe_quant_type):
    param_mapping = [
        # (pattern, fn)
        (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.activation_scale", deal_in_scale()),
        (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.weight_scale", deal_weight_scale()),
        (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.quant_weight", reorder()),
    ]
    for pattern, fn in param_mapping:
        for key in list(state_dict.keys()):
            # print(f"deal {key}")
            match = re.search(pattern, key)
            if match:
                # print(f"{key} is match")
                weight_or_scale = state_dict.pop(key)
                if "weight_scale" in key:
                    in_scale_key = key.replace("weight_scale", "activation_scale")
                    in_scale = save_state_dict[in_scale_key]
                    save_state_dict[key] = fn(weight_or_scale, in_scale, moe_quant_type)
                elif "activation_scale" in key:
                    save_state_dict[key] = fn(weight_or_scale)
                else:
                    save_state_dict[key] = fn(weight_or_scale, moe_quant_type)


def save_safetensors(state_dict, args):
    """
    save_safetensors
    """
    logger.info("Move to numpy.")
    for k in list(state_dict.keys()):
        if isinstance(state_dict[k], paddle.Tensor):
            state_dict[k] = state_dict.pop(k).cpu().numpy()

    logger.info("Save safetensors files.")
    shards, index = shard_checkpoint(
        state_dict,
        max_shard_size="5GB",
        weights_name=SAFE_WEIGHTS_NAME,
        shard_format="naive",
    )
    for shard_file, shard in shards.items():
        save_file = os.path.join(args.output_dir, shard_file)
        logger.info(f"Saving {save_file}")
        safe_save_file(shard, save_file, metadata={"format": "np"})

    save_index_file = os.path.join(args.output_dir, SAFE_WEIGHTS_INDEX_NAME)
    with open(save_index_file, "w", encoding="utf-8") as f:
        content = json.dumps(index, indent=2) + "\n"
        f.write(content)


def main():
    """
    main
    """
    args = parse_arguments()
    pretrained_config, _ = PretrainedConfig.get_config_dict(args.model_name_or_path)
    pretrained_config = PretrainedConfig.from_dict(pretrained_config)
    vocab_file_names = [
        "tokenizer.model",
        "spm.model",
        "ernie_token_100k.model",
    ]
    for i in range(len(vocab_file_names)):
        if os.path.exists(os.path.join(args.model_name_or_path, vocab_file_names[i])):
            ErnieBotTokenizer.resource_files_names["vocab_file"] = vocab_file_names[i]
            break
    tokenizer = ErnieBotTokenizer.from_pretrained(args.model_name_or_path)
    _, safetensor_files = get_all_safetensors(args.model_name_or_path)
    weights_iterator = safetensors_weights_iterator(safetensor_files)
    state_dict = {}
    save_state_dict = {}
    start = time.perf_counter()
    for k, v in weights_iterator:
        state_dict[k] = get_tensor(v).cpu()
    end = time.perf_counter()
    logger.info("Finish Quantize.")
    logger.info(f"load and quantize took: {end - start:.6f} seconds")
    deal_quant(state_dict, save_state_dict, args.moe_quant_type)
    for key in list(state_dict.keys()):
        save_state_dict[key] = state_dict.pop(key)
    logger.info("Begin to save model")
    os.makedirs(args.output_dir, exist_ok=True)
    start = time.perf_counter()
    if not args.safe_serialization:
        paddle.save(
            save_state_dict,
            os.path.join(args.output_dir, "model_state.pdparams"),
        )
    else:
        save_safetensors(save_state_dict, args)
    pretrained_config.is_permuted = True
    pretrained_config.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    end = time.perf_counter()
    logger.info(f"save model took: {end - start:.6f} seconds")
    logger.info("Finish.")


if __name__ == "__main__":
    main()
36
scripts/run_offline_w4a8.sh
Normal file
@@ -0,0 +1,36 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -ex
rm -rf log
rm -f core*

export devices=0
export CUDA_VISIBLE_DEVICES=${devices}
model_path=${1:-"/PATH/MODEL_PATH"}
output_path=${2:-"/PATH/OUTPUT_MODEL"}
moe_quant_type=${3:-"w4a8"}
for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do
    unset ${name}
done
export PADDLE_TRAINER_ID=0
export PADDLE_TRAINERS_NUM=1
export TRAINER_INSTANCES_NUM=1
export TRAINER_INSTANCES=`hostname -i`
self_ip=`hostname -i`

python offline_w4a8.py \
    --model_name_or_path ${model_path} \
    --output_dir ${output_path} \
    --safe_serialization "True" \
    --moe_quant_type ${moe_quant_type}
75
test/operators/test_hybrid_mtp_ngram.py
Normal file
@@ -0,0 +1,75 @@
import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import hybrid_mtp_ngram


class TestNgramMatchMixed(unittest.TestCase):
    def setUp(self):
        self.max_bsz = 2
        self.max_draft_tokens = 5
        self.max_len = 32
        self.max_dec_len = 10
        self.max_ngram_size = 5
        self.min_ngram_size = 2

        # Initialize input tensors
        self.input_ids = paddle.full(shape=[self.max_bsz, self.max_len], fill_value=-1, dtype="int64").cpu()
        self.input_ids_len = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int64").cpu()
        self.pre_ids = paddle.full(shape=[self.max_bsz, self.max_len], fill_value=-1, dtype="int64").cpu()
        self.step_idx = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int64").cpu()
        self.draft_token_num = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32").cpu()
        self.draft_tokens = paddle.full(
            shape=[self.max_bsz, self.max_draft_tokens + 1],
            fill_value=-1,
            dtype="int64",
        ).cpu()
        self.seq_lens_this_time = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32").cpu()
        self.seq_lens_decoder = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32").cpu()
        self.max_dec_len = paddle.full(
            shape=[self.max_bsz, 1],
            fill_value=self.max_dec_len,
            dtype="int64",
        ).cpu()

        # Set concrete test data
        self.input_ids[:, :10] = np.arange(0, 10)
        self.input_ids_len[:] = 10
        pre_ids_np = np.array([10, 9, 8, 7, 6, 10, 9, 8, 7], dtype="int32")
        self.pre_ids[:, : pre_ids_np.shape[0]] = pre_ids_np
        self.step_idx[:] = 8

        self.draft_token_num[:] = 5
        self.draft_tokens[:, :2] = np.array([8, 7])
        self.seq_lens_this_time[:] = 2
        self.seq_lens_decoder[:] = 12
        self.max_dec_len[:] = 512

        # Expected results
        self.ref_seq_lens_this_time = np.array([[6], [6]], dtype="int32")
        self.ref_draft_tokens = np.array([[8, 7, 6, 10, 9, 8], [8, 7, 6, 10, 9, 8]], dtype="int64")

    def test_ngram_match_mixed(self):
        hybrid_mtp_ngram(
            self.input_ids,
            self.input_ids_len,
            self.pre_ids,
            self.step_idx,
            self.draft_token_num,
            self.draft_tokens,
            self.seq_lens_this_time,
            self.seq_lens_decoder,
            self.max_dec_len,
            self.max_ngram_size,
            self.min_ngram_size,
            self.max_draft_tokens,
        )

        np.testing.assert_allclose(self.seq_lens_this_time.numpy(), self.ref_seq_lens_this_time)
        np.testing.assert_allclose(self.draft_tokens.numpy(), self.ref_draft_tokens)


if __name__ == "__main__":
    unittest.main()
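The expected values in this test follow from a suffix-lookup rule: the tail of the generated history ([10, 9, 8, 7] at the end of pre_ids) also occurs at the start of the history, so the tokens that followed it there (6, 10, 9, 8) are copied in as extra draft tokens, growing seq_lens_this_time from 2 to 6. A pure-Python sketch that reproduces the expected numbers; illustrative only, the CUDA op additionally handles batching, bounds, and padding:

    def ngram_extend(history, draft, budget, min_n=2, max_n=5):
        # Try the longest suffix of `history` that reappears earlier in `history`;
        # copy the tokens that followed that earlier occurrence as draft tokens.
        for n in range(max_n, min_n - 1, -1):
            suffix = history[-n:]
            for i in range(len(history) - n - 1, -1, -1):
                if history[i : i + n] == suffix:
                    return draft + history[i + n : i + n + budget]
        return draft

    history = [10, 9, 8, 7, 6, 10, 9, 8, 7]          # pre_ids up to step_idx
    print(ngram_extend(history, [8, 7], budget=4))   # -> [8, 7, 6, 10, 9, 8]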
103
test/operators/test_w4afp8_gemm.py
Normal file
@@ -0,0 +1,103 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import w4afp8_gemm, w4afp8_gemm_weight_convert


def w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N):
    all_tokens = int(tokens.sum())
    out = paddle.zeros([all_tokens, N], dtype="bfloat16")
    pre_fix_token = 0
    for i in range(BATCH):
        input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
        weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i]
        out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
        out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
        pre_fix_token += tokens[i]
    return out


def permute_scale(weight_scale):
    weight_scale = weight_scale.reshape([BATCH, N])
    temp = paddle.zeros([16])
    for b in range(BATCH):
        for n in range(0, N, 16):
            temp[:] = weight_scale[b, n : n + 16]
            for j in range(0, 16, 2):
                weight_scale[b, n + j] = temp[j // 2]
                weight_scale[b, n + j + 1] = temp[j // 2 + 8]
    return weight_scale
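permute_scale rearranges every 16 consecutive scales so that the two 8-wide halves are interleaved pairwise, presumably matching the order in which the GEMM kernel's threads read them. The same shuffle expressed in numpy on plain indices, as an illustrative check rather than part of the test:

    import numpy as np

    s = np.arange(16)
    # Interleave the two halves: (s[0], s[8], s[1], s[9], ..., s[7], s[15]),
    # which is exactly what the j-loop in permute_scale produces per 16-wide group.
    interleaved = np.stack([s[:8], s[8:]], axis=1).reshape(-1)
    print(interleaved)  # [ 0  8  1  9  2 10  3 11  4 12  5 13  6 14  7 15]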
paddle.seed(0)
tokens_per_group = 32
N = 8192
K = 3584
BATCH = 8
TokenPadding = 0

tokens = [tokens_per_group] * BATCH
tokens_perfix_sum = np.cumsum(tokens)


tokens = paddle.to_tensor(tokens, dtype="int64")
tokens_perfix_sum = paddle.to_tensor(tokens_perfix_sum, dtype="int64")

all_tokens = int(tokens.sum())

input_fp8 = paddle.randn([all_tokens, K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
input_bf16 = input_fp8.astype("bfloat16")
weight = paddle.randn([BATCH, N, K], dtype="bfloat16") / 10

weight_scale = 7 / weight.abs().max(axis=-1).reshape([BATCH, N, 1])
weight_quant = (weight * weight_scale).astype("int") + 7
weight_quant = paddle.clip(weight_quant, 0, 14)
weight_quant = weight_quant.astype("bfloat16")
weight_dequant_scale = 1 / weight_scale.astype("float32")
input_row_sum = input_bf16.sum(axis=1) * -7 / 512
max_tokens = int(tokens.max())

out_naive = w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N)
weight_dequant_scale = paddle.to_tensor(permute_scale(weight_dequant_scale) * 512)

weight_int4 = w4afp8_gemm_weight_convert(weight_quant.astype("uint8").cpu())

if TokenPadding == 0:
    out_cuda = w4afp8_gemm(
        input_fp8,
        weight_int4.cuda(),
        tokens_perfix_sum,
        input_row_sum.astype("float32"),
        weight_dequant_scale.astype("float32"),
        int(TokenPadding),
        max_tokens,
        True,
    )
else:
    out_cuda = w4afp8_gemm(
        input_fp8,
        weight_int4.cuda(),
        tokens,
        input_row_sum.astype("float32"),
        weight_dequant_scale.astype("float32"),
        int(TokenPadding),
        max_tokens,
        True,
    )

gap = (out_cuda - out_naive).abs()
assert float(gap.mean()) < 0.07