Compare commits

...

26 Commits

Author SHA1 Message Date
xiaoxiaohehe001
9f1882d9a8 fa3_rope (#4190) 2025-09-21 22:04:59 +08:00
xiaoxiaohehe001
5223065d59 support mm mtp (#4013) 2025-09-09 13:55:45 +08:00
freeliuzc
c753f1fc9e [Feature][MTP]Support new mtp (#3656)
* update multi-draft-token strategy

* fix format

* support hybrid mtp with ngram speculative decoding method
2025-08-27 19:38:26 +08:00
Yuan Xiaolan
62659a7a73 support w4afp8 offline quant (#3438) 2025-08-15 17:32:12 +08:00
xiaoxiaohehe001
4f17f9aa6e add w4a8 online quant eplb 2025-08-15 12:54:08 +08:00
xiaoxiaohehe001
7642611b12 Merge branch 'feature/online/45T_20250730' of https://github.com/PaddlePaddle/FastDeploy into feature/online/45T_20250730 2025-08-14 00:49:28 +08:00
Yuan Xiaolan
2513cd929b support w4afp8 EP inference (#3382) 2025-08-13 21:41:34 +08:00
xiaoxiaohehe001
4dbaa3d74c Fix w4a8 scale load (#3334)
* fix_eplb

* fix eplb part3

* support_fp8_rope3d

* fix w4a8 scale
2025-08-11 21:02:42 +08:00
xiaoxiaohehe001
44043e0c88 fix w4a8 scale 2025-08-11 20:57:55 +08:00
xiaoxiaohehe001
7b8db880b7 Merge branch 'PaddlePaddle:feature/online/45T_20250730' into feature/online/45T_20250730 2025-08-11 20:54:32 +08:00
yangjianfengo1
c7993d35cb Support w4afp8 (#3324) 2025-08-11 19:00:18 +08:00
xiaoxiaohehe001
c7cb31051b [Fix] support_fp8_rope3d (#3278)
* support_fp8_rope3d
2025-08-08 19:38:21 +08:00
xiaoxiaohehe001
5e7ab3dfe3 Merge branch 'PaddlePaddle:feature/online/45T_20250730' into feature/online/45T_20250730 2025-08-08 19:35:38 +08:00
xiaoxiaohehe001
abed681444 support_fp8_rope3d 2025-08-08 19:34:06 +08:00
xiaoxiaohehe001
548f53e433 Feature/online/45 t 20250730 (#3276)
* fix_eplb

* fix eplb part3
2025-08-08 19:30:59 +08:00
xiaoxiaohehe001
ee742f55f1 Merge branch 'PaddlePaddle:feature/online/45T_20250730' into feature/online/45T_20250730 2025-08-08 19:30:04 +08:00
xiaoxiaohehe001
794ab9705f Fix eplb part3 (#3206)
* fix_eplb

* fix eplb part3
2025-08-05 10:58:17 +08:00
xiaoxiaohehe001
0e0891ad12 fix eplb part3 2025-08-05 10:53:00 +08:00
xiaoxiaohehe001
9e87f3341b Merge branch 'PaddlePaddle:feature/online/45T_20250730' into feature/online/45T_20250730 2025-08-05 10:44:57 +08:00
xiaoxiaohehe001
869626b0f4 fix_eplb (#3160) 2025-08-03 01:50:07 +08:00
xiaoxiaohehe001
1b1287e145 fix_eplb 2025-08-03 01:49:10 +08:00
freeliuzc
9307f2619b 【Fix】【MTP】fix mtp bug (#3140) 2025-08-01 15:45:00 +08:00
carryyu
fbe03866d1 fix eplb part 1 2025-07-31 17:11:48 +08:00
Yuan Xiaolan
89ad20bea2 fix w4a8 scale (#3115) 2025-07-31 16:50:06 +08:00
Yuan Xiaolan
02398135a8 fix is_permuted (#3100) 2025-07-30 22:35:22 +08:00
Yuan Xiaolan
d65a0a6a2c support W4A8 EPLB (#3075) (#3094) 2025-07-30 19:46:42 +08:00
66 changed files with 3608 additions and 786 deletions

.gitignore (vendored, 4 additions)
View File

@@ -164,3 +164,7 @@ build
.ccls-cache
third_party
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h

View File

@@ -286,6 +286,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
@@ -309,6 +310,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,

View File

@@ -199,8 +199,9 @@ __global__ void append_decode_cache_T_rope_kernel(
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -512,7 +513,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads) {
const int kv_num_heads,
const bool rope_3d) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -555,8 +557,9 @@ __global__ void append_decode_cache_int8_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
@@ -633,10 +636,11 @@ __global__ void append_decode_cache_int8_rope_kernel(
const T *cache_v_scale_cur = cache_v_scale + v_head_idx * HeadDim + head_bias;
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
if constexpr (!is_scale_channel_wise) {
scale = __ldg(&cache_k_scale[kv_head_idx]);
}
@@ -763,7 +767,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads) {
const int kv_num_heads,
const bool rope_3d) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -813,9 +818,9 @@ __global__ void append_decode_cache_int8_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -908,6 +913,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
const T *cache_v_scale_cur = cache_v_scales + v_head_idx * HeadDim + head_bias;
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
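Both decode-path rope kernels in this file apply the same change: when rope_3d is set, the cos/sin lookup index gains a per-batch offset of batch_id * max_seq_len * head_dim. A minimal device-side sketch of that selection, mirroring the new_emb_idx lines above (the helper name and the exact layout of the 3-D rope table are assumptions, not part of this diff):

#include <cstdint>

// Sketch only: reproduces the new_emb_idx computation from the hunks above.
// Assumes the 3-D rope table stores one max_seq_len * head_dim slab per batch,
// so the per-batch stride is max_seq_len * head_dim; emb_idx is the 2-D index
// (write_seq_id * half_head_size + h_bias / 2) the kernel already computed.
__device__ __forceinline__ uint32_t rope_emb_index(uint32_t emb_idx,
                                                   uint32_t batch_id,
                                                   uint32_t max_seq_len,
                                                   uint32_t head_dim,
                                                   bool rope_3d) {
  return rope_3d ? emb_idx + batch_id * max_seq_len * head_dim : emb_idx;
}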

View File

@@ -248,7 +248,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads);
kv_num_heads,
rope_3d);
} else {
append_decode_cache_int8_rope_kernel<T, 4, 0, 128, is_scale_channel_wise, IsFP8>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -271,7 +272,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads);
kv_num_heads,
rope_3d);
}
}
}

View File

@@ -37,7 +37,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim) {
const int last_dim,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
@@ -62,6 +63,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t base_idx =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
h_bias;
@@ -80,8 +82,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
Load<T, VecSize>(&qkv[base_idx], &src_vec);
// do rope
if (hi < q_num_head + kv_num_head) {
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
@@ -118,6 +120,7 @@ void gqa_rotary_qk_split_variable(
const int seq_len,
const int input_output_len,
const int dim_head,
const bool rope_3d,
const cudaStream_t &stream) {
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int PackSize = 16 / sizeof(T);
@@ -146,7 +149,8 @@ void gqa_rotary_qk_split_variable(
num_heads,
kv_num_heads,
seq_len,
dim_head);
dim_head,
rope_3d);
}
template <typename T,
@@ -890,7 +894,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int kv_token_num,
const int max_seq_len,
const std::string& cache_quant_type) {
const std::string& cache_quant_type,
const bool rope_3d) {
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -953,8 +958,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
num_heads,
kv_num_heads,
max_seq_len,
rotary_embs.dims()[2],
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
head_dim,
rope_3d,
stream);
if (token_num < kv_token_num) {

View File

@@ -193,7 +193,8 @@ __global__ void append_speculate_cache_rope_kernel(
const int head_size,
const int block_size,
const int elem_cnt,
const int gqa_group_size) {
const int gqa_group_size,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
@@ -253,8 +254,9 @@ __global__ void append_speculate_cache_rope_kernel(
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -476,7 +478,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size) {
const int gqa_group_size,
const bool rope_3d) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -522,8 +525,9 @@ __global__ void append_speculate_cache_int8_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
if (qkv_out_scales) {
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
}
@@ -583,10 +587,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
T scale;
if (head_idx < num_heads + gqa_group_size) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
scale = __ldg(&cache_k_scales[kv_head_idx]);
} else {
scale = __ldg(&cache_v_scales[kv_head_idx]);

View File

@@ -39,7 +39,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style) {
const bool use_neox_style,
const bool rope_3d) {
int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums =
@@ -96,7 +97,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads);
kv_num_heads,
rope_3d);
}
}
@@ -125,7 +127,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style) {
const bool use_neox_style,
const bool rope_3d) {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -191,7 +194,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads);
kv_num_heads,
rope_3d);
}
}
@@ -313,6 +317,7 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
@@ -368,7 +373,8 @@ void SpeculateWriteCacheWithRoPEKernel(
bsz,
token_nums,
stream,
use_neox_rotary_style);
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
append_speculate_cache_int8_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -401,7 +407,8 @@ void SpeculateWriteCacheWithRoPEKernel(
bsz,
token_nums,
stream,
use_neox_rotary_style);
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -434,7 +441,8 @@ void SpeculateWriteCacheWithRoPEKernel(
bsz,
token_nums,
stream,
use_neox_rotary_style);
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -500,6 +508,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
@@ -526,6 +535,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
@@ -551,6 +561,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
@@ -578,6 +589,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,

View File

@@ -35,6 +35,7 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,

View File

@@ -107,7 +107,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const int kv_token_num, const int max_seq_len,
const std::string &cache_quant_type);
const std::string &cache_quant_type,
const bool rope_3d);
std::vector<paddle::Tensor>
PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
@@ -188,7 +189,8 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency);
const std::string& quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums);
paddle::Tensor MoeExpertFFNWint2Func(
const paddle::Tensor& permute_input,
@@ -334,6 +336,19 @@ void TextImageIndexOut(const paddle::Tensor &token_type_ids,
const paddle::Tensor &text_input,
const paddle::Tensor &image_input);
void LimitContentLen(const paddle::Tensor& next_tokens,
const paddle::Tensor& end_thinking_tokens,
const paddle::Tensor& max_content_len,
const paddle::Tensor& max_think_len,
const paddle::Tensor& step_idx,
const paddle::Tensor& eos_token_ids,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& limit_content_status,
const paddle::Tensor& enable_thinking,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags);
void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
paddle::Tensor &image_input,
paddle::Tensor &token_type_ids,
@@ -603,7 +618,7 @@ void SpeculateVerify(
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
@@ -648,6 +663,20 @@ void NgramMatch(const paddle::Tensor &input_ids,
const int max_draft_tokens);
void HybridMtpNgram(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int min_ngram_size,
const int max_draft_tokens);
// MTP
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
@@ -664,8 +693,10 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
@@ -1082,7 +1113,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");
m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
m.def("speculate_update",&SpeculateUpdate, "Speculate Update Kernel");
m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");
@@ -1092,6 +1123,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("ngram_match", &NgramMatch, "ngram_match function");
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");
m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");

View File

@@ -193,6 +193,12 @@ public:
typedef uint8_t data_t;
};
template <> class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
public:
typedef __nv_fp8_e4m3 DataType;
typedef paddle::float8_e4m3fn data_t;
};
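This hunk adds a PDTraits specialization for FLOAT8_E4M3FN so later kernels can resolve Paddle's fp8 dtype to the CUDA device type. A minimal usage sketch, based on the aliases that appear in the MoE diffs further down (hedged; the alias names there are traits_fp8 / DataType_fp8 / data_t_fp8):

#include <cuda_fp8.h>  // __nv_fp8_e4m3

// Sketch: map the Paddle fp8 dtype to its host- and device-side representations.
using traits_fp8   = PDTraits<paddle::DataType::FLOAT8_E4M3FN>;
using DataType_fp8 = traits_fp8::DataType;  // __nv_fp8_e4m3
using data_t_fp8   = traits_fp8::data_t;    // paddle::float8_e4m3fn
// A paddle::Tensor holding fp8 data can then be handed to a CUDA kernel via
// reinterpret_cast<const DataType_fp8*>(tensor.data<data_t_fp8>()).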
template <typename T, int Size> struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];

View File

@@ -0,0 +1,186 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
__global__ void limit_content_len(
int64_t* next_tokens,
const int64_t* end_thinking_tokens,
int* max_content_lens,
const int* max_think_lens,
int64_t* step_idx,
const int64_t* eos_token_ids,
int64_t* max_dec_lens,
int* limit_content_status,
const bool* enable_thinking,
int* accept_num,
int* seq_lens_decoder,
bool* stop_flags,
const int tokens_per_step,
const int bs,
const int end_thinking_token_num,
const int eos_token_id_len) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= bs) return;
if (!enable_thinking[idx]) return;
const int original_accept_num = accept_num[idx];
if (original_accept_num <= 0) return;
int current_limit_content_status = limit_content_status[idx];
// If we are already in the response phase and the stop flag is set, return early; nothing more to do.
if (current_limit_content_status == 2 && stop_flags[idx]) {
return;
}
const int max_think_len_reg = max_think_lens[idx];
const int64_t end_thinking_token_reg = end_thinking_tokens[0];
int64_t current_max_dec_len = max_dec_lens[idx];
int new_accept_num = original_accept_num;
const int64_t current_base_step = step_idx[idx] - original_accept_num + 1;
for (int token_offset = 0; token_offset < original_accept_num; token_offset++) {
const int token_idx = idx * tokens_per_step + token_offset;
int64_t next_token_reg = next_tokens[token_idx];
const int64_t current_step = current_base_step + token_offset;
bool condition_triggered = false;
bool is_eos = false;
// ======================= Thinking-phase control =======================
// Stage 1: still thinking (status == 0); check whether thinking must be force-ended
if (current_limit_content_status < 1) {
bool should_transform = false;
// When thinking-length control is enabled, check whether the budget has been exceeded
if (max_think_len_reg > 0 && current_step >= max_think_len_reg) {
should_transform = true;
} else {
// Check whether an EOS token was generated
for (int j = 0; j < eos_token_id_len; j++) {
if (eos_token_ids[j] == next_token_reg) {
is_eos = true;
should_transform = true;
break;
}
}
}
if (should_transform) {
// Force-replace the current token with the end-of-thinking token
next_token_reg = end_thinking_token_reg;
// Advance the status to 1, meaning "ending thinking"
current_limit_content_status = 1;
condition_triggered = true; // the token was modified, so the accepted tokens must be truncated here
// Clear stop_flags only when the transition was triggered by EOS
if (is_eos && stop_flags[idx]) {
stop_flags[idx] = false;
}
}
}
// ======================= End-of-thinking handling =======================
// Stage 2: check whether the end-of-thinking condition is met (status < 2)
// This covers two scenarios:
// 1. status == 0: the model generated end_thinking_token on its own
// 2. status == 1: the previous stage force-injected end_thinking_token
if (current_limit_content_status < 2) {
if (next_token_reg == end_thinking_token_reg) {
// Confirm thinking has ended; advance the status to 2 (response phase)
current_limit_content_status = 2;
}
}
next_tokens[token_idx] = next_token_reg;
if (condition_triggered) {
new_accept_num = token_offset + 1;
break;
}
}
// Update the global per-request state
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[idx] -= discarded_tokens;
seq_lens_decoder[idx] -= discarded_tokens;
}
accept_num[idx] = new_accept_num;
limit_content_status[idx] = current_limit_content_status;
max_dec_lens[idx] = current_max_dec_len;
}
void LimitContentLen(const paddle::Tensor& next_tokens,
const paddle::Tensor& end_thinking_tokens,
const paddle::Tensor& max_content_len,
const paddle::Tensor& max_think_len,
const paddle::Tensor& step_idx,
const paddle::Tensor& eos_token_ids,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& limit_content_status,
const paddle::Tensor& enable_thinking,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags) {
const int batch_size = next_tokens.shape()[0];
const int tokens_per_step = next_tokens.shape()[1];
const int end_thinking_token_num = end_thinking_tokens.shape()[0];
const int end_length = eos_token_ids.shape()[0];
PD_CHECK(end_thinking_token_num == 1, "limit_content_len only support end_thinking_token_num = 1 for now.");
dim3 grid(1);
dim3 block(1024);
limit_content_len<<<grid, block>>>(
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_thinking_tokens.data<int64_t>(),
const_cast<int *>(max_content_len.data<int>()),
max_think_len.data<int>(),
const_cast<int64_t *>(step_idx.data<int64_t>()),
eos_token_ids.data<int64_t>(),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
const_cast<int *>(limit_content_status.data<int>()),
enable_thinking.data<bool>(),
const_cast<int *>(accept_num.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
tokens_per_step,
batch_size,
end_thinking_token_num,
end_length);
}
PD_BUILD_STATIC_OP(limit_content_len)
.Inputs({"next_tokens",
"end_thinking_tokens",
"max_content_len",
"max_think_len",
"step_idx",
"eos_token_ids",
"max_dec_len",
"limit_content_status",
"enable_thinking",
"accept_num",
"seq_lens_decoder",
"stop_flags"})
.Outputs({"next_tokens_out", "max_dec_len_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"},
{"max_dec_len", "max_dec_len_out"}})
.SetKernelFn(PD_KERNEL(LimitContentLen));
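Per request, the kernel above is a small state machine over limit_content_status: 0 while the model is still thinking, 1 for the single step in which the end-of-thinking token is force-injected, and 2 once the response phase has started. A hedged host-side restatement of the per-token transition (illustrative names; the kernel above is authoritative):

// Sketch of limit_content_len's per-token status update (not the kernel itself).
enum class ThinkStatus : int { kThinking = 0, kEndingThinking = 1, kResponding = 2 };

inline ThinkStatus step_status(ThinkStatus s, long long& token,
                               long long end_thinking_token,
                               bool over_think_budget, bool is_eos) {
  if (s == ThinkStatus::kThinking && (over_think_budget || is_eos)) {
    token = end_thinking_token;        // force-close the thinking span
    s = ThinkStatus::kEndingThinking;  // transient: the check below promotes it
  }
  if (s != ThinkStatus::kResponding && token == end_thinking_token) {
    s = ThinkStatus::kResponding;      // response phase from here on
  }
  return s;
}

When the forced replacement happens, the kernel also truncates accept_num to the current offset and rolls step_idx and seq_lens_decoder back by the number of discarded tokens.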

View File

@@ -269,7 +269,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
}
template <typename T, typename OutT, int NUM_EXPERTS_PER_RANK = 8, int RoundType = 1>
template <typename T, typename OutT, int NUM_EXPERTS_PER_RANK = 8, int Kthread = 512, int RoundType = 1>
__global__ void permute_x_kernel(const T *src_x,
const int64_t *topk_idx,
const float *topk_weights,
@@ -285,9 +285,9 @@ __global__ void permute_x_kernel(const T *src_x,
int *dst_indices,
int *cumsum_idx_gpu,
int64_t *token_nums_per_expert_cumsum,
int64_t *expert_idx_per_token,
int64_t *expert_idx_per_token, // [num_rows, moe_topk]
float max_bound = 127.0,
float min_bound = -127.0) { // [num_rows, moe_topk]
float min_bound = -127.0) {
const int src_token_idx = blockIdx.x;
const int tid = threadIdx.x;
constexpr int vec_size = sizeof(int4) / sizeof(T);
@@ -330,10 +330,17 @@ __global__ void permute_x_kernel(const T *src_x,
if (up_gate_proj_in_scale) {
for (int i = 0; i < vec_size; i++) {
float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast<float>(src_vec[i]);
if (RoundType == 0) {
res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
if constexpr (std::is_same<OutT, int8_t>::value) {
// w4aint8
if (RoundType == 0) {
res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
} else {
res_vec[i] = static_cast<OutT>(ClipFunc<float>(round(quant_value), min_bound, max_bound));
}
} else {
res_vec[i] = static_cast<OutT>(round(quant_value));
// w4afp8
float value = ClipFunc<float>(quant_value, min_bound, max_bound);
res_vec[i] = static_cast<OutT>(value);
}
}
} else {
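The branch above keys the element-wise quantization on OutT: the int8 (w4a8) path rounds first and then clips to [-127, 127], while the fp8 (w4afp8) path only clips to [-448, 448] and lets the conversion to e4m3 perform the rounding. A standalone sketch of the two paths (function names are illustrative):

#include <cstdint>
#include <cuda_fp8.h>

// Sketch of the per-element quantization used above (not the kernel itself).
__device__ __forceinline__ int8_t quant_int8(float v, int round_type) {
  // round_type == 0 uses rint (round to nearest even), otherwise round half away from zero
  float r = (round_type == 0) ? rintf(v) : roundf(v);
  r = fminf(fmaxf(r, -127.0f), 127.0f);   // clip after rounding
  return static_cast<int8_t>(r);
}

__device__ __forceinline__ __nv_fp8_e4m3 quant_fp8(float v) {
  v = fminf(fmaxf(v, -448.0f), 448.0f);   // clip to the finite e4m3 range
  return static_cast<__nv_fp8_e4m3>(v);   // the cast performs the fp8 rounding
}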
@@ -373,6 +380,10 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
typedef typename traits_fp8::DataType DataType_fp8;
typedef typename traits_fp8::data_t data_t_fp8;
auto stream = input.stream();
auto place = input.place();
const int gridx = min(132 * 8, num_rows);
@@ -420,6 +431,50 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
-127.0
);
}
} else if (moe_quant_type == "w4afp8") {
if (num_experts_per_rank == 8) {
permute_x_kernel<data_t, data_t_fp8, 8, 512><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t_fp8>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
448.0f,
-448.0f
);
} else if (num_experts_per_rank == 16) {
permute_x_kernel<data_t, data_t_fp8, 16, 512><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t_fp8>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
448.0f,
-448.0f
);
}
} else {
if (num_experts_per_rank == 8) {
permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
@@ -493,7 +548,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
auto permute_input = GetEmptyTensor(
{token_nums_this_rank, hidden_size},
moe_quant_type == "w4a8" ? paddle::DataType::INT8 : input_type,
moe_quant_type == "w4a8" ? paddle::DataType::INT8 : moe_quant_type == "w4afp8" ? paddle::DataType::FLOAT8_E4M3FN : input_type,
place);
auto num_experts_per_rank_tensor = GetEmptyTensor(
{num_experts_per_rank},

View File

@@ -88,7 +88,7 @@ struct nv_type_traits<int8_t> {
constexpr int kLogN = 7; \
__VA_ARGS__ \
} else { \
PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupport!", logN)); \
PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupported!", logN)); \
}
#define DISPATCH_SP_VS(vec_size, VEC_SIZE, ...) \
@@ -108,7 +108,7 @@ struct nv_type_traits<int8_t> {
constexpr int VEC_SIZE = 1; \
__VA_ARGS__ \
} else { \
PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupport!", vec_size)); \
PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupported!", vec_size)); \
}
#define DISPATCH_logN(logN, kLogN, ...) \
@@ -605,26 +605,6 @@ void moe_fast_hardamard_kernel(const T *x,
exchange_smem_pre<kNChunks, kChunksPerSmemSize, VecSize, kWarpSize, kNWarps, false, vec_t>(x_vals, smem_exchange);
}
if constexpr (kNChunks > 1) {
// T x_vals_transposed[VecSize][kNChunks] = {init_value};
// #pragma unroll
// for (int c = 0; c < kNChunks; ++c) {
// #pragma unroll
// for (int i = 0; i < VecSize; ++i) { x_vals_transposed[i][c] = x_vals[c][i]; }
// }
// if constexpr (kNChunks == 28) {
// hadamard_mult_thread_chunk_28<VecSize>(x_vals_transposed);
// } else if constexpr (kNChunks == 36) {
// hadamard_mult_thread_chunk_36<VecSize>(x_vals_transposed);
// } else {
// constexpr int kLogNChunks = cilog2(kNChunks);
// static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
// hadamard_mult_thread<kLogNChunks, VecSize>(x_vals_transposed);
// }
// #pragma unroll
// for (int c = 0; c < kNChunks; ++c) {
// #pragma unroll
// for (int i = 0; i < VecSize; ++i) { x_vals[c][i] = x_vals_transposed[i][c]; }
// }
if constexpr (kNChunks == 28) {
hadamard_mult_thread_28_transpose<T, VecSize>(x_vals);
} else if constexpr (kNChunks == 36) {

View File

@@ -72,6 +72,285 @@ __host__ __device__ constexpr static U arrayConvert(T const& input)
return u;
}
struct uint8 {
uint4 u;
uint4 v;
};
template<int BYTES> struct BytesToType {};
template<>
struct BytesToType<32> {
using Type = uint8;
static_assert(sizeof(Type) == 32);
};
template<> struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<> struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<> struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<> struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<> struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
template <template <typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
if (threadIdx.x == 0) {
result_broadcast = result;
}
__syncthreads();
return result_broadcast;
}
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
};
template <typename InType, typename OutType>
__forceinline__ __device__ OutType QuantHelperFunc(const InType input,
const float scale,
const float max_bound,
const float min_bound) {
float quant_value = max_bound * scale * static_cast<float>(input);
return static_cast<OutType>(ClipFunc<float>(quant_value, min_bound, max_bound));
}
template <typename T, typename OutT, int VecSize, int Kthread>
__global__ void masked_quantize_moe_input_kernel(const T* permuted_inputs,
const int64_t* expert_idx_per_token,
const float* quant_scales,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert,
OutT* out) {
using LoadT = AlignedVector<T, VecSize>;
using LoadOutT = AlignedVector<OutT, VecSize>;
LoadT input_vec;
LoadOutT output_vec;
float scale_factor = -7.0f / 512.0f;
using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type;
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert;
const auto expert_id = token_idx / num_max_tokens_per_expert;
if (token_idx_in_expert >= recv_expert_count[expert_id]) {
auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x;
token_idx += num_iters_to_next_expert * gridDim.x;
continue;
}
int64_t expert_idx = expert_idx_per_token[token_idx];
float quant_scale = quant_scales[expert_idx];
float thread_row_sum = 0.0f;
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
output_vec[i] = QuantHelperFunc<T, OutT>(input_vec[i], quant_scale, quant_max_bound, quant_min_bound);
thread_row_sum += static_cast<float>(output_vec[i]);
}
*(reinterpret_cast<vec_t*>(&out[offset])) = *(reinterpret_cast<const vec_t*>(&output_vec));
}
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
}
}
template <typename T, typename OutT, int VecSize, int Kthread>
__global__ void quantize_moe_input_kernel(const T* permuted_inputs,
const int64_t* expert_idx_per_token,
const float* quant_scales,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert,
OutT* out) {
using LoadT = AlignedVector<T, VecSize>;
using LoadOutT = AlignedVector<OutT, VecSize>;
LoadT input_vec;
LoadOutT output_vec;
using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type;
float scale_factor = -7.0f / 512.0f;
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
int64_t expert_idx = expert_idx_per_token[token_idx];
float quant_scale = quant_scales[expert_idx];
float thread_row_sum = 0.0f;
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
output_vec[i] = QuantHelperFunc<T, OutT>(input_vec[i], quant_scale, quant_max_bound, quant_min_bound);
thread_row_sum += static_cast<float>(output_vec[i]);
}
*(reinterpret_cast<vec_t*>(&out[offset])) = *(reinterpret_cast<const vec_t*>(&output_vec));
}
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
}
}
template <typename T, typename OutT>
void quantize_moe_input(
const T* permuted_inputs,
const int64_t* expert_idx_per_token,
const float* quant_scales,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
OutT* out,
cudaStream_t stream) {
constexpr int VecSize = 16 / sizeof(T);
constexpr int threads_per_block = 128;
const int dev_id = 0;
int sm_count;
int act_blocks_per_sm;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
auto kernel = used_in_ep_low_latency ? masked_quantize_moe_input_kernel<T, OutT, VecSize, threads_per_block> : quantize_moe_input_kernel<T, OutT, VecSize, threads_per_block>;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, threads_per_block, 0);
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
dim3 grid;
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
kernel<<<grid, threads_per_block, 0, stream>>>(
permuted_inputs,
expert_idx_per_token,
quant_scales,
quant_max_bound,
quant_min_bound,
token_num,
dim,
permuted_input_row_sum,
recv_expert_count,
num_max_tokens_per_expert,
out);
}
template <typename T, int VecSize, int Kthread>
__global__ void masked_compute_row_sum_kernel(
const T* permuted_inputs,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert) {
using LoadT = AlignedVector<T, VecSize>;
LoadT input_vec;
float scale_factor = -7.0f / 512.0f;
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert;
const auto expert_id = token_idx / num_max_tokens_per_expert;
if (token_idx_in_expert >= recv_expert_count[expert_id]) {
auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x;
token_idx += num_iters_to_next_expert * gridDim.x;
continue;
}
float thread_row_sum = 0.0f;
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
thread_row_sum += static_cast<float>(input_vec[i]);
}
}
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
}
}
template <typename T, int VecSize, int Kthread>
__global__ void compute_row_sum_kernel(
const T* permuted_inputs,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert) {
using LoadT = AlignedVector<T, VecSize>;
LoadT input_vec;
float scale_factor = -7.0f / 512.0f;
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
float thread_row_sum = 0.0f;
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
thread_row_sum += static_cast<float>(input_vec[i]);
}
}
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
}
}
template <typename T>
void compute_row_sum(
const T* permuted_inputs,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
cudaStream_t stream) {
constexpr int VecSize = 16 / sizeof(T);
constexpr int threads_per_block = 128;
const int dev_id = 0;
int sm_count;
int act_blocks_per_sm;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
auto kernel = used_in_ep_low_latency ? masked_compute_row_sum_kernel<T, VecSize, threads_per_block> : compute_row_sum_kernel<T, VecSize, threads_per_block>;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, threads_per_block, 0);
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
dim3 grid;
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
kernel<<<grid, threads_per_block, 0, stream>>>(
permuted_inputs,
token_num,
dim,
permuted_input_row_sum,
recv_expert_count,
num_max_tokens_per_expert);
}
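Both quantize_moe_input and compute_row_sum launch a persistent grid sized to one wave of resident blocks and then grid-stride over tokens. The sizing pattern, extracted as a standalone sketch (device index 0 is hard-coded exactly as in the launchers above):

#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// Sketch: size a persistent grid to one full wave of resident thread blocks.
template <typename KernelT>
int blocks_for_one_wave(KernelT kernel, int threads_per_block, int64_t work_items) {
  int dev_id = 0, sm_count = 0, blocks_per_sm = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel,
                                                threads_per_block,
                                                /*dynamicSMemSize=*/0);
  const int64_t wave = static_cast<int64_t>(sm_count) * blocks_per_sm;
  return static_cast<int>(std::min(wave, work_items));
}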
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing
// the output in the softmax kernel when we extend this module to support

View File

@@ -20,6 +20,7 @@
#include "helper.h"
#include "moe/fast_hardamard_kernel.h"
#include "moe/fused_moe_helper.h"
#include "w4afp8_gemm/w4afp8_gemm.h"
template <paddle::DataType T>
void MoeFFNKernel(const paddle::Tensor& permute_input,
@@ -33,7 +34,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method,
paddle::Tensor ffn_out,
bool used_in_ep_low_latency) {
bool used_in_ep_low_latency,
const int estimate_total_token_nums) {
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
@@ -60,19 +62,22 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024; // for nf4 stream-k
Allocator* allocator = paddle::GetAllocator(place);
Allocator::AllocationPtr workspace;
if (quant_method == "weight_only_int4" || quant_method == "w4a8") {
if (quant_method == "weight_only_int4" || quant_method == "w4a8" || quant_method == "w4afp8") {
inter_dim = inter_dim * 2;
}
if (quant_method == "w4a8") {
if (quant_method == "w4a8" || quant_method == "w4afp8") {
workspace = allocator->Allocate(
SizeOf(paddle::DataType::INT8) * workspace_size);
}
const int64_t inter_size = inter_dim;
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
typedef typename traits_fp8::DataType DataType_fp8;
typedef typename traits_fp8::data_t data_t_fp8;
int num_experts_ = num_experts;
int num_max_tokens_per_expert;
int num_max_tokens_per_expert = 256;
int expanded_active_expert_rows;
paddle::Tensor fc1_out_tensor;
@@ -161,13 +166,49 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
reinterpret_cast<NvType *>(fc1_out),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
inter_size,
hidden_size,
reinterpret_cast<char*>(workspace->ptr()),
workspace_size,
num_experts,
stream);
} else if (quant_method == "w4afp8") {
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
typedef typename traits_fp8::DataType DataType_fp8;
typedef typename traits_fp8::data_t data_t_fp8;
Allocator::AllocationPtr ffn1_input_row_sum;
ffn1_input_row_sum = allocator->Allocate(
sizeof(float) * expanded_active_expert_rows);
compute_row_sum(
permute_input.data<data_t_fp8>(),
expanded_active_expert_rows,
hidden_size,
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
num_max_tokens_per_expert,
used_in_ep_low_latency,
stream);
float* row_scale = nullptr;
DisPatchW4AFp8GemmWrapper(
reinterpret_cast<const DataType_fp8 *>(permute_input.data<data_t_fp8>()),
reinterpret_cast<const DataType_fp8 *>(up_gate_proj_weight.data<int8_t>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
row_scale,
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<float>(),
reinterpret_cast<NvType *>(fc1_out),
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
num_max_tokens_per_expert,
num_experts,
inter_size,
hidden_size,
stream);
} else {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm_bias_act(
@@ -194,7 +235,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
}
auto act_out = act_out_tensor.data<data_t>();
if (quant_method == "weight_only_int8") {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
int8_moe_gemm_runner.moe_gemm(
@@ -267,13 +307,73 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
reinterpret_cast<NvType *>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
hidden_size,
inter_size / 2,
reinterpret_cast<char*>(workspace->ptr()),
workspace_size,
num_experts,
stream);
} else if (quant_method == "w4afp8") {
data_t *ffn2_shift = nullptr;
data_t *ffn2_smooth = nullptr;
float* row_scale = nullptr;
Allocator::AllocationPtr fp8_act_out;
fp8_act_out = allocator->Allocate(
SizeOf(paddle::DataType::INT8) * act_out_tensor.numel());
Allocator::AllocationPtr ffn2_input_row_sum;
ffn2_input_row_sum = allocator->Allocate(
sizeof(float) * expanded_active_expert_rows);
// note(yuanxiaolan): optimize this
MoeFastHardamardWrapper<data_t, data_t>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
ffn2_shift, // ffn2_shift->data<T>(),
ffn2_smooth, // ffn2_smooth->data<T>(),
nullptr,
1,
448.0f,
-448.0f,
expanded_active_expert_rows,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
act_out_tensor.data<data_t>(),
stream
);
quantize_moe_input<data_t, data_t_fp8>(act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
down_proj_in_scale ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())->data<float>() : nullptr,
448.0f,
-448.0f,
expanded_active_expert_rows,
inter_size / 2,
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
num_max_tokens_per_expert,
used_in_ep_low_latency,
reinterpret_cast<data_t_fp8 *>(fp8_act_out->ptr()),
stream
);
DisPatchW4AFp8GemmWrapper(
reinterpret_cast<const DataType_fp8 *>(fp8_act_out->ptr()),
reinterpret_cast<const DataType_fp8 *>(down_proj_weight.data<int8_t>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
row_scale,
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
->data<float>(),
reinterpret_cast<NvType*>(ffn_out_data),
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
num_max_tokens_per_expert,
num_experts,
hidden_size,
inter_size / 2,
stream);
} else {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm(
@@ -302,10 +402,12 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency) {
const std::string& quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums) {
cudaCheckError();
const auto t_type = quant_method == "w4a8" ? up_gate_proj_scale.get().dtype() : permute_input.dtype();
const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() :
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, t_type);
switch (t_type) {
@@ -320,7 +422,9 @@ paddle::Tensor MoeExpertFFNFunc(
down_proj_in_scale,
expert_idx_per_token,
quant_method,
ffn_out, used_in_ep_low_latency);
ffn_out,
used_in_ep_low_latency,
estimate_total_token_nums);
break;
case paddle::DataType::FLOAT16:
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
@@ -333,7 +437,9 @@ paddle::Tensor MoeExpertFFNFunc(
down_proj_in_scale,
expert_idx_per_token,
quant_method,
ffn_out, used_in_ep_low_latency);
ffn_out,
used_in_ep_low_latency,
estimate_total_token_nums);
break;
default:
PD_THROW("Unsupported data type for MoeExpertFFN");
@@ -351,7 +457,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency) {
const std::string& quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums) {
return {MoeExpertFFNFunc(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
@@ -361,7 +468,9 @@ std::vector<paddle::Tensor> MoeExpertFFN(
down_proj_scale,
down_proj_in_scale,
expert_idx_per_token,
quant_method, used_in_ep_low_latency)};
quant_method,
used_in_ep_low_latency,
estimate_total_token_nums)};
}
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
@@ -375,7 +484,8 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
const paddle::optional<std::vector<int64_t>>& down_proj_in_scale_shape,
const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
const std::string& quant_method,
const bool used_in_ep_low_latency) {
const bool used_in_ep_low_latency,
const int estimate_total_token_nums) {
return {permute_input_shape};
}
@@ -388,8 +498,9 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
const std::string &quant_method, const bool used_in_ep_low_latency) {
if (quant_method == "w4a8") {
const std::string &quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums) {
if (quant_method == "w4a8" || quant_method == "w4afp8") {
return {up_gate_proj_scale_dtype.get()};
} else {
return {permute_input_dtype};
@@ -460,7 +571,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
paddle::Optional("down_proj_in_scale"),
paddle::Optional("expert_idx_per_token")})
.Outputs({"output_tensor"})
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool"})
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int"})
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
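With this change the op's output dtype depends on the quantization method: w4a8 inherits the dtype of up_gate_proj_scale, w4afp8 always produces bfloat16, and every other method keeps the input dtype. A small helper restating that selection (the name is illustrative; the inline conditional in MoeExpertFFNFunc above is the actual code):

#include <string>
#include "paddle/extension.h"  // paddle::Tensor, paddle::optional, paddle::DataType

// Sketch: restates the t_type selection used by MoeExpertFFNFunc.
inline paddle::DataType moe_ffn_out_dtype(
    const std::string& quant_method,
    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
    const paddle::Tensor& permute_input) {
  if (quant_method == "w4a8")   return up_gate_proj_scale.get().dtype();
  if (quant_method == "w4afp8") return paddle::DataType::BFLOAT16;
  return permute_input.dtype();
}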

View File

@@ -26,8 +26,10 @@ __global__ void process_splitwise_prefill(
int64_t* step_idx,
bool* not_need_stop,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
@@ -35,11 +37,12 @@ __global__ void process_splitwise_prefill(
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len) {
const int base_model_draft_tokens_len,
const int pre_ids_len) {
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int64_t not_stop_flag = 0;
@@ -92,8 +95,10 @@ __global__ void draft_model_preprocess_kernel(
int64_t* step_idx,
bool* not_need_stop,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
@@ -101,11 +106,12 @@ __global__ void draft_model_preprocess_kernel(
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len) {
const int base_model_draft_tokens_len,
const int pre_ids_len) {
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int64_t not_stop_flag = 0;
@@ -113,13 +119,16 @@ __global__ void draft_model_preprocess_kernel(
int tid = threadIdx.x;
if (tid < bsz) {
auto base_model_step_idx_now = base_model_step_idx[tid];
const int32_t base_model_step_idx_now = base_model_step_idx[tid];
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
auto accept_num_now = accept_num[tid];
const int32_t accept_num_now = accept_num[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * base_model_draft_tokens_len;
auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
const int32_t base_model_seq_len_this_time = base_model_seq_lens_this_time[tid];
auto* pre_ids_now = pre_ids + tid * pre_ids_len;
#pragma unroll
for (int i = 1; i < base_model_draft_tokens_len; i++) {
base_model_draft_tokens_now[i] = -1;
@@ -133,14 +142,12 @@ __global__ void draft_model_preprocess_kernel(
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
not_stop_flag = 1;
// 1. first token
if (base_model_step_idx_now == 0) {
seq_lens_this_time[tid] = 0;
not_stop_flag = 0;
} else if (seq_lens_encoder[tid] > 0) {
if (seq_lens_encoder[tid] > 0) {
// Can be extended to first few tokens
int seq_len_encoder = seq_lens_encoder[tid];
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
pre_ids_now[0] = base_model_first_token;
int position = seq_len_encoder;
if (TRCUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token;
@@ -149,24 +156,24 @@ __global__ void draft_model_preprocess_kernel(
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder + 1;
}
} else if (accept_num_now <=
max_draft_token) /*Accept partial draft tokens*/ {
// Base Model reject stop
} else {
if (stop_flags[tid]) {
stop_flags[tid] = false;
seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
step_idx[tid] = base_model_step_idx[tid];
// TODO: check
seq_lens_decoder[tid] = base_model_seq_len_decoder - base_model_seq_len_this_time;
step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time;
} else {
seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
step_idx[tid] -= max_draft_token - accept_num_now;
// 2: Last base model generated token and first MTP token
seq_lens_decoder[tid] -= num_model_step - 1;
step_idx[tid] -= num_model_step - 1;
}
int64_t modified_token = accept_tokens_now[accept_num_now - 1];
draft_tokens_now[0] = modified_token;
seq_lens_this_time[tid] = 1;
} else /*Accept all draft tokens*/ {
draft_tokens_now[1] = accept_tokens_now[max_draft_token];
seq_lens_this_time[tid] = 2;
for (int i = 0; i < accept_num_now; i++) {
draft_tokens_now[i] = accept_tokens_now[i];
const int pre_id_pos = base_model_step_idx[tid] - (accept_num_now - i);
const int64_t accept_token = accept_tokens_now[i];
pre_ids_now[pre_id_pos] = accept_token;
}
seq_lens_this_time[tid] = accept_num_now;
}
} else {
stop_flags[tid] = true;
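In the rewritten accept path above, every accepted base-model token is copied into draft_tokens and also recorded in the new pre_ids buffer; the slot for the i-th accepted token is the base model's step index minus the number of accepted tokens that still follow it, and seq_lens_this_time becomes accept_num. A host-side restatement of that bookkeeping for a single request (illustrative only; the kernel above is authoritative):

#include <cstdint>

// Sketch: draft_tokens / pre_ids bookkeeping for one request.
inline void accept_tokens_for_request(const int64_t* accept_tokens, int accept_num,
                                      int64_t base_model_step_idx,
                                      int64_t* draft_tokens, int64_t* pre_ids,
                                      int* seq_lens_this_time) {
  for (int i = 0; i < accept_num; ++i) {
    draft_tokens[i] = accept_tokens[i];
    // The i-th accepted token sits (accept_num - i) steps before base_model_step_idx.
    const int pre_id_pos = static_cast<int>(base_model_step_idx) - (accept_num - i);
    pre_ids[pre_id_pos] = accept_tokens[i];
  }
  *seq_lens_this_time = accept_num;  // the draft model replays all accepted tokens
}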
@@ -194,8 +201,10 @@ void DispatchRunner(
int64_t* step_idx,
bool* not_need_stop,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
@@ -203,11 +212,12 @@ void DispatchRunner(
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool splitwise_prefill) {
constexpr int BlockSize = 512;
if (splitwise_prefill) {
@@ -222,8 +232,10 @@ void DispatchRunner(
step_idx,
not_need_stop,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
@@ -231,11 +243,12 @@ void DispatchRunner(
base_model_is_block_step,
base_model_draft_tokens,
bsz,
max_draft_token,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len);
base_model_draft_tokens_len,
pre_ids_len);
} else {
draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
<<<1, BlockSize, 0, stream>>>(
@@ -248,8 +261,10 @@ void DispatchRunner(
step_idx,
not_need_stop,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
@@ -257,11 +272,12 @@ void DispatchRunner(
base_model_is_block_step,
base_model_draft_tokens,
bsz,
max_draft_token,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len);
base_model_draft_tokens_len,
pre_ids_len);
}
}
@@ -276,8 +292,10 @@ void DispatchTokenMode(
int64_t* step_idx,
bool* not_need_stop,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
@@ -285,11 +303,12 @@ void DispatchTokenMode(
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int max_draft_token,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill) {
if (truncate_first_token) {
@@ -304,8 +323,10 @@ void DispatchTokenMode(
step_idx,
not_need_stop,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
@@ -313,11 +334,12 @@ void DispatchTokenMode(
base_model_is_block_step,
base_model_draft_tokens,
bsz,
max_draft_token,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
splitwise_prefill
);
} else {
@@ -332,8 +354,10 @@ void DispatchTokenMode(
step_idx,
not_need_stop,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
@@ -341,11 +365,12 @@ void DispatchTokenMode(
base_model_is_block_step,
base_model_draft_tokens,
bsz,
max_draft_token,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
splitwise_prefill
);
}
@@ -363,21 +388,24 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& base_model_is_block_step,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const int num_model_step,
const bool truncate_first_token,
const bool splitwise_prefill) {
int real_bsz = seq_lens_this_time.shape()[0];
int accept_tokens_len = accept_tokens.shape()[1];
int input_ids_len = input_ids.shape()[1];
int draft_tokens_len = draft_tokens.shape()[1];
int pre_ids_len = pre_ids.shape()[1];
auto cu_stream = seq_lens_this_time.stream();
constexpr int BlockSize = 512;
int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
@@ -395,8 +423,10 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<bool*>(batch_drop.data<bool>()),
const_cast<int64_t*>(pre_ids.data<int64_t>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
base_model_seq_lens_this_time.data<int>(),
base_model_seq_lens_encoder.data<int>(),
base_model_seq_lens_decoder.data<int>(),
base_model_step_idx.data<int64_t>(),
@@ -404,11 +434,12 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
base_model_is_block_step.data<bool>(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
real_bsz,
max_draft_token,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
truncate_first_token,
splitwise_prefill);
@@ -429,8 +460,10 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"step_idx",
"not_need_stop",
"batch_drop",
"pre_ids",
"accept_tokens",
"accept_num",
"base_model_seq_lens_this_time",
"base_model_seq_lens_encoder",
"base_model_seq_lens_decoder",
"base_model_step_idx",
@@ -445,8 +478,9 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"seq_lens_decoder_out",
"step_idx_out",
"not_need_stop_out",
"batch_drop_out"})
.Attrs({"max_draft_token: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
"batch_drop_out",
"pre_ids_out"})
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"input_ids", "input_ids_out"},
{"stop_flags", "stop_flags_out"},
@@ -455,5 +489,6 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_idx", "step_idx_out"},
{"not_need_stop", "not_need_stop_out"},
{"batch_drop", "batch_drop_out"}})
{"batch_drop", "batch_drop_out"},
{"pre_ids", "pre_ids_out"}})
.SetKernelFn(PD_KERNEL(DraftModelPreprocess));
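As a reading aid, the per-request bookkeeping of the updated kernel (normal decode path only) can be restated on the host roughly as below. This is a simplified sketch: it skips splitwise prefill, the encoder/first-token branch and the TRCUNCATE_FIRST_TOKEN handling, and the struct and function names are ours, not part of the op.

#include <cstdint>

// Sketch of the accept/rollback bookkeeping in draft_model_preprocess_kernel.
struct DraftState {
    int32_t seq_len_decoder;
    int64_t step_idx;
    int32_t seq_len_this_time;
};

void update_draft_state(DraftState &s, int32_t accept_num, int32_t num_model_step,
                        bool was_stopped, int32_t base_seq_len_decoder,
                        int64_t base_step_idx, int32_t base_seq_len_this_time) {
    if (was_stopped) {
        // Re-sync the draft model's counters from the base model.
        s.seq_len_decoder = base_seq_len_decoder - base_seq_len_this_time;
        s.step_idx = base_step_idx - base_seq_len_this_time;
    } else {
        // Roll back the speculative steps taken in the previous round.
        s.seq_len_decoder -= num_model_step - 1;
        s.step_idx -= num_model_step - 1;
    }
    // Every accepted base-model token is replayed by the draft model this step.
    s.seq_len_this_time = accept_num;
}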

View File

@@ -63,10 +63,9 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
token_this_time = next_tokens_start[seq_len_this_time - 1];
draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
base_model_draft_tokens_now[substep + 1] = token_this_time;
for (int i = 0; i < seq_len_this_time; ++i) {
pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
}
step_idx[tid] += seq_len_this_time;
pre_ids_now[step_idx[tid]] = token_this_time;
} else {
token_this_time = next_tokens_start[0];

View File

@@ -49,9 +49,7 @@ __global__ void ComputeOrderKernel(
for (int j = 0; j < cur_seq_lens_encoder; j++) {
position_map[in_offset++] = out_offset++;
}
// 2. base model encoder. Base step=0
} else if (cur_base_model_seq_lens_encoder != 0) {
// 3. New end
// 2. Base model stopped at the last verify step.
} else if (cur_base_model_seq_lens_this_time != 0 && cur_seq_lens_this_time == 0) {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: base=0. draft !=0 \n", i);
@@ -61,20 +59,25 @@ __global__ void ComputeOrderKernel(
// 4. stopped
} else if (cur_base_model_seq_lens_this_time == 0 && cur_seq_lens_this_time == 0) /* end */ {
} else {
if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: accept_num <= actual_draft_token_num \n", i);
#endif
position_map[in_offset + accept_num - 1] = out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
} else /*Accept all draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: accept_num > actual_draft_token_num \n", i);
#endif
position_map[in_offset + accept_num - 2] = out_offset++;
position_map[in_offset + accept_num - 1] = out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
for (int i = 0; i < accept_num; i++) {
position_map[in_offset++] = out_offset++;
}
in_offset += cur_base_model_seq_lens_this_time - accept_num;
// (liuzichang): Temporarily reserved for debugging
// if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
// #ifdef DEBUG_EAGLE_KERNEL
// printf("batch %d: accept_num <= actual_draft_token_num \n", i);
// #endif
// position_map[in_offset + accept_num - 1] = out_offset++;
// in_offset += cur_base_model_seq_lens_this_time;
// } else /*Accept all draft tokens*/ {
// #ifdef DEBUG_EAGLE_KERNEL
// printf("batch %d: accept_num > actual_draft_token_num \n", i);
// #endif
// position_map[in_offset + accept_num - 2] = out_offset++;
// position_map[in_offset + accept_num - 1] = out_offset++;
// in_offset += cur_base_model_seq_lens_this_time;
// }
}
}
output_token_num[0] = out_offset;

View File

@@ -0,0 +1,214 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
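// Inclusive running sum: adds value[0] through value[num]. Callers pass the current
// batch index so the result counts the tokens already scheduled up to and including it.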
int sum_mixed(const int *value, int num) {
int sum_value = 0;
for (int i = 0; i <= num; i++) {
sum_value += value[i];
}
return sum_value;
}
void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *draft_token_num,
int64_t *draft_tokens,
int32_t *seq_lens_this_time,
int32_t *seq_lens_decoder,
int64_t *max_dec_len,
int64_t input_ids_stride,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
int64_t max_batch_size,
int max_ngram_size = 3,
int min_ngram_size = 1,
const int max_draft_tokens = 10) {
int threshold = 1024;
// Total token budget for one step; may be made dynamic in the future (override via SPEC_TOKENUM_THRESHOLD below)
char *env_var = getenv("SPEC_TOKENUM_THRESHOLD");
if (env_var) {
threshold = std::stoi(env_var);
}
int unprocessed_batch_size = 0;
for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
if (seq_lens_decoder[batch_idx] > 0) {
unprocessed_batch_size++;
}
}
for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
const int ori_seq_len_this_time = seq_lens_this_time[batch_idx];
int max_draft_tokens_query = std::min(static_cast<int64_t>(
max_draft_tokens - ori_seq_len_this_time + 1), max_dec_len[batch_idx] - step_idx[batch_idx] - 1);
if (ori_seq_len_this_time == 0 || max_draft_tokens_query <= 0) {
continue;
}
const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
const int64_t cur_step_idx = step_idx[batch_idx];
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
unprocessed_batch_size--;
auto sum_token_num = sum_mixed(seq_lens_this_time, batch_idx);
int left_min_token_num = unprocessed_batch_size;
if (sum_token_num + max_draft_tokens_query + left_min_token_num > threshold) {
int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num;
max_draft_tokens_query = std::min(max_draft_tokens_query, tmp_max_draft_tokens);
}
if (sum_token_num + left_min_token_num >= threshold - 1) {
continue;
}
bool match_global = false;
// apply ngram_match in input_ids
for (int ngram_size = max_ngram_size; ngram_size >= min_ngram_size && !match_global; --ngram_size) {
// Extract the last n tokens as our search ngram
if (cur_step_idx < ngram_size) {
continue;
}
const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
// Iterate through sliding windows of size ngram_size
// bool match_input = false;
for (int64_t i = 0; i <= cur_input_ids_len - ngram_size && !match_global; ++i) {
// Check if the current window matches the ngram
bool match_local = true;
for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != cur_input_ids[i + j]) {
match_local = false;
break;
}
}
if (match_local) {
int64_t start_idx = i + ngram_size;
int64_t end_idx = std::min(start_idx + max_draft_tokens_query, cur_input_ids_len);
if (start_idx >= end_idx)
continue;
int64_t cur_draft_token_num = end_idx - start_idx;
seq_lens_this_time[batch_idx] = ori_seq_len_this_time + cur_draft_token_num;
memcpy(cur_draft_tokens + ori_seq_len_this_time, cur_input_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
// To break the current batch_idx for-loop
match_global = true;
break;
}
}
// apply ngram_match in generated tokens
if (!match_global) {
for (int64_t i = 0; i <= cur_step_idx - ngram_size && !match_global; ++i) {
// Check if the current window matches the ngram
bool match_local = true;
for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != cur_pre_ids[i + j]) {
match_local = false;
break;
}
}
if (match_local) {
int64_t start_idx = i + ngram_size;
int64_t end_idx = std::min(start_idx + max_draft_tokens_query, cur_step_idx);
int64_t cur_draft_token_num = end_idx - start_idx;
if (start_idx >= end_idx)
continue;
// printf("match in Output with Ngram_size %d. %lld:[%lld,%lld]\n",ngram_size, cur_draft_token_num, start_idx, end_idx);
seq_lens_this_time[batch_idx] = ori_seq_len_this_time + cur_draft_token_num;
memcpy(cur_draft_tokens + ori_seq_len_this_time, cur_pre_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
match_global = true;
break;
}
}
}
}
}
}
void HybridMtpNgram(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int min_ngram_size,
const int max_draft_tokens) {
auto input_ids_shape = input_ids.shape();
const int64_t input_ids_stride = input_ids_shape[1];
auto pre_ids_shape = pre_ids.shape();
const int64_t pre_ids_stride = pre_ids_shape[1];
auto draft_tokens_shape = draft_tokens.shape();
const int64_t draft_tokens_stride = draft_tokens_shape[1];
const int64_t max_batch_size = seq_lens_this_time.shape()[0];
find_candidate_pred_tokens_mixed(input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
pre_ids_stride,
draft_tokens_stride,
max_batch_size,
max_ngram_size,
min_ngram_size,
max_draft_tokens);
}
PD_BUILD_STATIC_OP(hybrid_mtp_ngram)
.Inputs({"input_ids",
"input_ids_len",
"pre_ids",
"step_idx",
"draft_token_num",
"draft_tokens",
"seq_lens_this_time",
"seq_lens_decoder",
"max_dec_len"})
.Attrs({"max_ngram_size: int", "min_ngram_size: int", "max_draft_tokens: int"})
.Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
.SetKernelFn(PD_KERNEL(HybridMtpNgram))
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, {"seq_lens_this_time", "seq_lens_this_time_out"}});

View File

@@ -23,14 +23,7 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6
struct msgdata {
int64_t mtype;
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
2]; // stop_flag, bsz, accept_num*bsz, tokens...
};
#include "speculate_msg.h"
void SpeculateGetOutput(const paddle::Tensor& x,
int64_t rank_id,
@@ -54,7 +47,7 @@ void SpeculateGetOutput(const paddle::Tensor& x,
msg_queue_id = inference_msg_queue_id_from_env;
}
static struct msgdata msg_rcv;
static struct speculate_msgdata msg_rcv;
static key_t key = ftok("./", msg_queue_id);

View File

@@ -1,69 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
__global__ void SpeculateHydraSetScoreThresholdKernel(
float* threshold,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* accept_num,
const int real_bsz,
const float default_threshold = 0.3,
const float upper_threshold = 0.8,
const float lower_threshold = 0.0,
const float threshold_step = 0.1,
const float threshold_step_fac = 0.5) {
for (int bid = threadIdx.x; bid < real_bsz; bid += blockDim.x) {
if (seq_lens_encoder[bid] > 0) {
threshold[bid] = default_threshold;
} else if (seq_lens_this_time[bid] <= 1) {
continue;
} else if (accept_num[bid] >= seq_lens_this_time[bid] &&
threshold[bid] >
lower_threshold + threshold_step * threshold_step_fac) {
threshold[bid] -= threshold_step * threshold_step_fac;
} else if (accept_num[bid] < seq_lens_this_time[bid] &&
threshold[bid] < upper_threshold - threshold_step) {
threshold[bid] += threshold_step;
}
}
}
void SpeculateHydraSetScoreThreshold(const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& accept_num,
const paddle::Tensor& threshold) {
auto cu_stream = seq_lens_this_time.stream();
std::vector<int64_t> seq_lens_this_time_shape = seq_lens_this_time.shape();
const int bsz = seq_lens_this_time_shape[0];
SpeculateHydraSetScoreThresholdKernel<<<1, 256, 0, cu_stream>>>(
const_cast<float*>(threshold.data<float>()),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
accept_num.data<int>(),
bsz);
}
PD_BUILD_STATIC_OP(speculate_hydra_set_score_threshold)
.Inputs(
{"seq_lens_this_time", "seq_lens_encoder", "accept_num", "threshold"})
.Outputs({"threshold_out"})
.SetInplaceMap({{"threshold", "threshold_out"}})
.SetKernelFn(PD_KERNEL(SpeculateHydraSetScoreThreshold));

View File

@@ -1,68 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
__global__ void hydra_update_this_time(int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const float* topk_scores,
const float* score_threshold,
int real_bsz,
int idx) {
int linear_idx = threadIdx.x;
// verify and set stop flags
for (; linear_idx < real_bsz; linear_idx += blockDim.x) {
if (seq_lens_encoder[linear_idx] == 0 &&
seq_lens_decoder[linear_idx] != 0) {
if (topk_scores[linear_idx] > score_threshold[linear_idx] &&
seq_lens_this_time[linear_idx] == idx + 1) {
seq_lens_this_time[linear_idx]++;
}
} else if (seq_lens_encoder[linear_idx] == 0 &&
seq_lens_decoder[linear_idx] == 0) {
seq_lens_this_time[linear_idx] = 0;
}
}
}
void HydraUpdateThisTime(const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& topk_scores,
const paddle::Tensor& score_threshold,
const int real_bsz,
const int idx) {
constexpr int BlockSize = 512;
hydra_update_this_time<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
const_cast<int*>(seq_lens_this_time.data<int>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
topk_scores.data<float>(),
score_threshold.data<float>(),
real_bsz,
idx);
}
PD_BUILD_STATIC_OP(speculate_hydra_update_seqlens_this_time)
.Inputs({"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"topk_scores",
"score_threshold"})
.Outputs({"seq_lens_this_time_out"})
.Attrs({"real_bsz: int", "idx: int"})
.SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"}})
.SetKernelFn(PD_KERNEL(HydraUpdateThisTime));

View File

@@ -1,149 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "helper.h"
template <typename T, int VecSize>
__global__ void RebuildAppendPaddingKernel(
T *out,
const T *full_hidden_states,
const int *cum_offset,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int *output_padding_offset,
const int seq_len,
const int dim_embed,
const size_t elem_nums) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = global_idx * VecSize; i < elem_nums; i += gridDim.x * blockDim.x * VecSize) {
const int out_token_id = i / dim_embed;
const int ori_token_id = out_token_id + output_padding_offset[out_token_id];
const int bi = ori_token_id / seq_len;
int seq_id = 0;
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
else if (seq_len_encoder[bi] != 0) {
seq_id = seq_len_encoder[bi] - 1;
}
const int input_token_id = ori_token_id - cum_offset[bi] + seq_id;
const int bias_idx = i % dim_embed;
Load<T, VecSize>(&full_hidden_states[input_token_id * dim_embed + bias_idx], &src_vec);
Store<T, VecSize>(src_vec, &out[i]);
}
}
template <paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
const paddle::Tensor& full_hidden_states,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& output_padding_offset,
const int max_seq_len) {
// src: [token_num, dim_embed]
// dst: [batch_size, 1, dim_embed]
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
int dim_embed = full_hidden_states.shape()[1];
int output_token_num = output_padding_offset.shape()[0];
int elem_nums = output_token_num * dim_embed;
constexpr int PackSize = VEC_16B / sizeof(DataType_);
assert(elem_nums % PackSize == 0);
auto out = paddle::full({output_token_num, dim_embed}, 0, full_hidden_states.dtype(), full_hidden_states.place());
int pack_num = elem_nums / PackSize;
const int threads_per_block = 128;
int grid_size = 1;
GetNumBlocks(pack_num, &grid_size);
RebuildAppendPaddingKernel<DataType_, PackSize><<<grid_size, threads_per_block, 0, full_hidden_states.stream()>>>(
reinterpret_cast<DataType_*>(out.data<data_t>()),
reinterpret_cast<const DataType_*>(full_hidden_states.data<data_t>()),
cum_offsets.data<int32_t>(),
seq_len_encoder.data<int32_t>(),
seq_len_decoder.data<int32_t>(),
output_padding_offset.data<int32_t>(),
max_seq_len,
dim_embed,
elem_nums);
return {out};
}
std::vector<paddle::Tensor> RebuildAppendPadding(
const paddle::Tensor& full_hidden_states,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& output_padding_offset,
const int max_seq_len) {
switch (full_hidden_states.dtype()) {
case paddle::DataType::BFLOAT16:
return DispatchDtype<paddle::DataType::BFLOAT16>(
full_hidden_states, cum_offsets, seq_len_encoder, seq_len_decoder, output_padding_offset, max_seq_len);
case paddle::DataType::FLOAT16:
return DispatchDtype<paddle::DataType::FLOAT16>(
full_hidden_states, cum_offsets, seq_len_encoder, seq_len_decoder, output_padding_offset, max_seq_len);
default:
PD_THROW("Unsupported data type.");
}
}
std::vector<std::vector<int64_t>> RebuildAppendPaddingInferShape(
const std::vector<int64_t>& full_hidden_states_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& seq_len_encoder_shape,
const std::vector<int64_t>& seq_len_decoder_shape,
const std::vector<int64_t>& output_padding_offset_shape) {
const int64_t output_token_num = output_padding_offset_shape[0];
const int64_t dim_embed = full_hidden_states_shape[1];
std::vector<int64_t> out_shape = {output_token_num, dim_embed};
return {out_shape};
}
std::vector<paddle::DataType> RebuildAppendPaddingInferDtype(
const paddle::DataType& full_hidden_states_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& seq_len_encoder_dtype,
const paddle::DataType& seq_len_decoder_dtype,
const paddle::DataType& output_padding_offset_dtype) {
return {full_hidden_states_dtype};
}
PD_BUILD_STATIC_OP(speculate_rebuild_append_padding)
.Inputs({"full_hidden_states",
"cum_offsets",
"seq_len_encoder",
"seq_len_decoder",
"output_padding_offset"})
.Attrs({"max_seq_len: int"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(RebuildAppendPadding))
.SetInferShapeFn(PD_INFER_SHAPE(RebuildAppendPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildAppendPaddingInferDtype));

View File

@@ -23,14 +23,7 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6
struct msgdata {
long mtype;
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
2]; // stop_flag, bsz, tokens
};
#include "speculate_msg.h"
void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
@@ -62,7 +55,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static struct msgdata msg_sed;
static struct speculate_msgdata msg_sed;
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);

View File

@@ -15,7 +15,7 @@
#include "helper.h"
template <int THREADBLOCK_SIZE>
__global__ void speculate_update_v3(int *seq_lens_encoder,
__global__ void speculate_update(int *seq_lens_encoder,
int *seq_lens_decoder,
bool *not_need_stop,
int64_t *draft_tokens,
@@ -90,7 +90,7 @@ __global__ void speculate_update_v3(int *seq_lens_encoder,
}
}
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
@@ -108,7 +108,7 @@ void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
constexpr int BlockSize = 512;
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
speculate_update_v3<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
speculate_update<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
@@ -130,7 +130,7 @@ void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(speculate_update_v3)
PD_BUILD_STATIC_OP(speculate_update)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"not_need_stop",
@@ -152,4 +152,4 @@ PD_BUILD_STATIC_OP(speculate_update_v3)
{"not_need_stop", "not_need_stop_out"},
{"draft_tokens", "draft_tokens_out"},
{"actual_draft_token_nums", "actual_draft_token_nums_out"}})
.SetKernelFn(PD_KERNEL(SpeculateUpdateV3));
.SetKernelFn(PD_KERNEL(SpeculateUpdate));

View File

@@ -1,55 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h" // NOLINT
__global__ void update_this_time(int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
int real_bsz,
int value) {
int linear_idx = threadIdx.x;
// verify and set stop flags
for (; linear_idx < real_bsz; linear_idx += blockDim.x) {
if (seq_lens_encoder[linear_idx] == 0 &&
seq_lens_decoder[linear_idx] != 0) {
seq_lens_this_time[linear_idx] = value;
} else if (seq_lens_encoder[linear_idx] == 0 &&
seq_lens_decoder[linear_idx] == 0) {
seq_lens_this_time[linear_idx] = 0;
}
}
}
void UpdateThisTime(const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const int real_bsz,
const int value) {
constexpr int BlockSize = 512;
update_this_time<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
const_cast<int*>(seq_lens_this_time.data<int>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
real_bsz,
value);
}
PD_BUILD_STATIC_OP(speculate_update_seq_lens_this_time)
.Inputs({"seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder"})
.Outputs({"seq_lens_this_time_out"})
.Attrs({"real_bsz: int", "value: int"})
.SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"}})
.SetKernelFn(PD_KERNEL(UpdateThisTime));

View File

@@ -1,146 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h" // NOLINT
template <int THREADBLOCK_SIZE>
__global__ void speculate_update(int *seq_lens_encoder,
int *seq_lens_decoder,
bool *not_need_stop,
int64_t *draft_tokens,
int *actual_draft_token_nums,
const int64_t *accept_tokens,
const int *accept_num,
const bool *stop_flags,
const int *seq_lens_this_time,
const bool *is_block_step,
const int real_bsz,
const int max_draft_tokens) {
const int bid = threadIdx.x;
const int accept_num_now = accept_num[bid];
int stop_flag_now_int = 0;
if (!(is_block_step[bid] || bid >= real_bsz)) {
if (stop_flags[bid]) {
stop_flag_now_int = 1;
}
if (seq_lens_encoder[bid] == 0) {
seq_lens_decoder[bid] += accept_num_now;
}
if (seq_lens_this_time[bid] > 1 &&
seq_lens_encoder[bid] ==
0) {  // In append mode, decide from the acceptance result whether to
      // reduce the number of draft tokens for the next step
auto current_actual_draft_token_num = actual_draft_token_nums[bid];
if (accept_num_now - 1 == current_actual_draft_token_num) {
if (current_actual_draft_token_num + 2 <=
max_draft_tokens - 1) {
actual_draft_token_nums[bid] =
current_actual_draft_token_num + 2;
} else if (current_actual_draft_token_num + 1 <=
max_draft_tokens - 1) {
actual_draft_token_nums[bid] =
current_actual_draft_token_num + 1;
} else {
actual_draft_token_nums[bid] = max_draft_tokens - 1;
}
} else {
actual_draft_token_nums[bid] =
actual_draft_token_nums[bid] - 1 >= 1
? actual_draft_token_nums[bid] - 1
: 1;
}
}
if (seq_lens_encoder[bid] != 0) {
seq_lens_decoder[bid] += seq_lens_encoder[bid];
seq_lens_encoder[bid] = 0;
}
draft_tokens[bid * max_draft_tokens] =
accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
if (stop_flag_now_int) {
seq_lens_decoder[bid] = 0;
}
}
__syncthreads();
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
if (threadIdx.x == 0) {
// printf("stop_sum %d \n", stop_sum);
not_need_stop[0] = stop_sum < real_bsz;
}
}
void SpeculateUpdateV2(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &is_block_step) {
int real_bsz = seq_lens_this_time.shape()[0];
auto max_draft_tokens = draft_tokens.shape()[1];
constexpr int BlockSize = 512;
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
speculate_update<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int *>(actual_draft_token_nums.data<int>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
stop_flags.data<bool>(),
seq_lens_this_time.data<int>(),
is_block_step.data<bool>(),
real_bsz,
max_draft_tokens);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(speculate_update_v2)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"not_need_stop",
"draft_tokens",
"actual_draft_token_nums",
"accept_tokens",
"accept_num",
"stop_flags",
"seq_lens_this_time",
"is_block_step"})
.Outputs({"seq_lens_encoder_out",
"seq_lens_decoder_out",
"not_need_stop_out",
"draft_tokens_out",
"actual_draft_token_nums_out"})
.SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"not_need_stop", "not_need_stop_out"},
{"draft_tokens", "draft_tokens_out"},
{"actual_draft_token_nums", "actual_draft_token_nums_out"}})
.SetKernelFn(PD_KERNEL(SpeculateUpdateV2));

View File

@@ -0,0 +1,154 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
using namespace cute;
template <int kStages, class GemmType, class OutputType, class SmemLayoutA,
class SmemLayoutB, class SmemLayoutC>
struct SharedStorage {
union {
struct {
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
};
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
};
struct {
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline;
};
};
template<int kBlockM_, int kBlockN_, int kBlockK_,
int kNWarps_, int kStages_,
int kTiles_, int M_,
int TokenPackSize_,
int TAIL_N_ = 0,
int kClusterM_ = 1,
typename elem_type=cutlass::float_e4m3_t,
typename OutputType = cutlass::bfloat16_t>
struct Kernel_traits {
using Element = elem_type;
using ElementAccum = float;
using ElementOutput = OutputType;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
static_assert(kNWarps_ == 12 || kNWarps_ == 16);
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kBlockK = kBlockK_;
static constexpr int kTiles = kTiles_;
static constexpr int TokenPackSize = TokenPackSize_;
static constexpr int M = M_;
static constexpr int TAIL_N = TAIL_N_;
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>;
static constexpr int kClusterM = kClusterM_;
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
static constexpr int kStages = kStages_;
static_assert(kStages > 1);
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{}));
using TiledMma_TAIL = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK_TAIL>(),
AtomLayoutMNK{}));
using SmemLayoutAtomA = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, Element, Int<kBlockM>, Int<kBlockK / 2>>());
using SmemLayoutA = decltype(
tile_to_shape(SmemLayoutAtomA{},
make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{})));
using SmemLayoutAtomB = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutB = decltype(
tile_to_shape(SmemLayoutAtomB{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomB_TAIL = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());
using SmemLayoutB_TAIL = decltype(
tile_to_shape(SmemLayoutAtomB_TAIL{},
make_shape(
shape<1>(TileShape_MNK_TAIL{}),
shape<2>(TileShape_MNK_TAIL{}),
Int<kStages>{})
));
using SmemLayoutAtomC = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, ElementOutput,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>;
using SharedStorage = SharedStorage<
kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC>;
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
// static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
using TiledCopyCThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{}));
using TiledCopyCValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}),
LayoutRight{}));
using TiledCopyC = decltype(make_tiled_copy(
TiledCopyCAtom{},
TiledCopyCThrLayout{}, // Thr layout
TiledCopyCValLayout{} // Val layout
));
};

View File

@@ -0,0 +1,405 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
// #include "named_barrier.hpp"
#include "utils.hpp"
using namespace cute;
template <typename Ktraits>
struct CollectiveMainloopFwd {
using Element = typename Ktraits::Element;
using ElementOutput = typename Ktraits::ElementOutput;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
using ElementAccum = typename Ktraits::ElementAccum;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int TAIL_N = Ktraits::TAIL_N;
static constexpr int kBlockK = Ktraits::kBlockK;
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kTiles = Ktraits::kTiles;
static constexpr int M = Ktraits::M;
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
using GmemTiledCopy = cute::SM90_TMA_LOAD;
using SmemLayoutA = typename Ktraits::SmemLayoutA;
using SmemLayoutB = typename Ktraits::SmemLayoutB;
using SmemLayoutC = typename Ktraits::SmemLayoutC;
using SmemLayoutB_TAIL = typename Ktraits::SmemLayoutB_TAIL;
using ShapeT = cute::Shape<int64_t, int64_t, int64_t>;
using StrideT = cute::Shape<int64_t, _1, int64_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using TMA_A = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
ShapeT{},
StrideT{}
),
SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
size<0>(ClusterShape{})));
using TMA_B = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
ShapeT{},
StrideT{}
),
take<0, 2>(SmemLayoutB{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})));
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using SmemCopyAtomAB = typename Ktraits::SmemCopyAtomAB;
using SmemCopyAtomC = typename Ktraits::SmemCopyAtomC;
using TiledCopyC = typename Ktraits::TiledCopyC;
static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
struct Arguments {
Element const* ptr_A;
LayoutT layout_A;
Element const* ptr_B;
LayoutT layout_B;
ElementOutput * ptr_C;
LayoutT layout_C;
const float *weight_scale;
const float *input_row_sum;
const int64_t * tokens;
};
struct Params {
LayoutT layout_A;
LayoutT layout_B;
TMA_A tma_load_A;
TMA_B tma_load_B;
ElementOutput * ptr_C;
const float *weight_scale;
const float *input_row_sum;
const int64_t * tokens;
};
Params static
to_underlying_arguments(Arguments const& args) {
Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A);
TMA_A tma_load_A = make_tma_copy(
GmemTiledCopy{},
mA,
SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
size<0>(ClusterShape{}));
Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B);
TMA_B tma_load_B = make_tma_copy(
GmemTiledCopy{},
mB,
SmemLayoutB{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}));
return {args.layout_A, args.layout_B, tma_load_A, tma_load_B,
args.ptr_C, args.weight_scale, args.input_row_sum, args.tokens};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_A.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_B.get_tma_descriptor());
}
template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
CUTLASS_DEVICE void
store(Params const& mainloop_params,
FrgTensorO & tOrO,
SharedStorage& shared_storage,
TiledMma tiled_mma,
const float *input_row_sum,
const float *weight_scale,
const int64_t tokens,
const int64_t pre_fix_tokens,
const int bidm,
const int bidn,
const int bidb,
const int tidx) {
using packHalf = typename PackedHalf<ElementOutput>::Type;
Tensor tOrO_out = make_tensor<ElementOutput>(tOrO.layout());
#pragma unroll
for (int i = 0; i < size(tOrO); i+=4) {
const int sum_idx = i * 2;
tOrO[i] = (tOrO[i] + input_row_sum[sum_idx]) * weight_scale[0];
tOrO[i + 1] = (tOrO[i + 1] + input_row_sum[sum_idx + 1]) * weight_scale[0];
tOrO[i + 2] = (tOrO[i + 2] + input_row_sum[sum_idx]) * weight_scale[1];
tOrO[i + 3] = (tOrO[i + 3] + input_row_sum[sum_idx + 1]) * weight_scale[1];
*reinterpret_cast<packHalf*>(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]);
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]);
}
uint16_t *smem_c = reinterpret_cast<uint16_t *>(shared_storage.smem_c.data());
uint32_t * reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data());
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
constexpr int k_copy_times = CUR_N / 16;
#pragma unroll
for (int i = 0; i < k_copy_times; i++) {
uint32_t smem_ptr = cast_smem_ptr_to_uint(reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
asm volatile (
"stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
:: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3]));
#endif
}
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
const int batch_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
ElementOutput * store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM;
const int reamin_tokens = tokens - bidn * kBlockN;
const int col = tidx % 2;
constexpr int kPackSize = 16 / sizeof(ElementOutput);
constexpr int kNumVecElem = kBlockM / kPackSize;
constexpr int copy_len = CUR_N * kNumVecElem;
#pragma unroll
for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) {
const int idx_div2 = idx / 2;
const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 + idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8;
const int store_global_idx = store_idx * 2 + col;
const int row = store_global_idx / kNumVecElem;
const int col = store_global_idx % kNumVecElem;
if (row >= reamin_tokens) {
continue;
}
const int offset = row * (M / kPackSize) + col;
reinterpret_cast<uint4*>(store_c)[offset] = reinterpret_cast<uint4*>(smem_c)[idx];
}
}
template <typename MTensor>
CUTLASS_DEVICE auto get_local_no_packed_tensor(
const MTensor &mB,
const int pre_fix_token,
const int actual_token,
const int bidn) const {
auto g_offset = local_tile(
mB(_, _, 0),
cute::make_shape(1, size<1>(mB)),
make_coord(pre_fix_token, _0{}));
auto g_tensor = make_tensor(
g_offset.data(),
make_layout(
cute::make_shape(actual_token, size<2>(mB)),
g_offset.stride()
));
Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
return gB;
}
template <typename SharedStorage>
CUTLASS_DEVICE void
load(Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState& smem_pipe_write,
SharedStorage &shared_storage,
const int tokens,
const int pre_fix_tokens,
const int bidm,
const int bidn,
const int bidb,
const int tidx) {
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape());
Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape());
Tensor gA = local_tile(mA(_, _, bidb), select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}), make_coord(bidm, _));
auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sA), group_modes<0, 2>(gA));
const int kIters = kTiles / kStages;
if constexpr (TokenPackSize == 0) {
Tensor gB = get_local_no_packed_tensor(
mB,
pre_fix_tokens,
tokens,
bidn);
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
if (tidx == 0) {
#pragma unroll
for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
const int i = kiter * kStages + s;
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
#pragma unroll
for (int i = kIters * kStages; i < kTiles; ++i) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
} else {
auto mB_this_batch = make_tensor(
mB(_, _, bidb).data(),
make_layout(
cute::make_shape(tokens, size<1>(mB)),
mB.stride()
));
Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
if (tidx == 0) {
#pragma unroll
for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
const int i = kiter * kStages + s;
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
#pragma unroll
for (int i = kIters * kStages; i < kTiles; ++i) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
}
}
template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
CUTLASS_DEVICE void
mma(Params const& mainloop_params,
TiledMma tiled_mma,
MainloopPipeline pipeline,
PipelineState& smem_pipe_read,
SharedStorage& shared_storage,
FrgTensorO &tSrS,
const int tidx) {
using sMemBLayout = std::conditional_t<
CUR_N == kBlockN,
SmemLayoutB,
SmemLayoutB_TAIL
>;
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{});
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
auto threadMma = tiled_mma.get_thread_slice(tidx);
auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0));
Tensor tSrB = threadMma.partition_fragment_B(sB);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
const int kIters = kTiles / kStages;
constexpr int B_STEPS = CUR_N == 0 ? 1 : (kBlockN / CUR_N);
#pragma unroll
for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, s));
consumer_wait(pipeline, smem_pipe_read);
gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, s * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read;
}
}
#pragma unroll
for (int i = 0; i < kTiles % kStages; ++i) {
Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, i));
consumer_wait(pipeline, smem_pipe_read);
gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, i * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read;
}
}
};

View File

@@ -0,0 +1,114 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
using namespace cute;
template<typename T>
struct PackedHalf;
template<>
struct PackedHalf<cutlass::half_t> {
using Type = __half2;
};
template<>
struct PackedHalf<cutlass::bfloat16_t> {
using Type = nv_bfloat162;
};
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
template <int numel>
__forceinline__ __device__ void convert_c4_2_fp8(const int32_t * src, int32_t * dst1, int32_t * dst2) {
#pragma unroll
for (int i = 0; i < numel; ++i) {
dst1[i] = (src[i] >> 4) & 0x0f0f0f0f;
dst2[i] = src[i] & 0x0f0f0f0f;
}
}
template <int wg_wait=0, bool arrive=true,
bool commit=true, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename TiledMma,
typename ThrCopyA, typename TiledCopyA>
__forceinline__ __device__ void gemm(
TiledMma &tiled_mma,
Tensor0 &tCrA,
Tensor1 &tCsA,
Tensor2 const &tCrB,
Tensor3 &tCrC,
TiledCopyA const &tiled_copy_A,
ThrCopyA const &thr_copy_A) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
Tensor tCrA1 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
Tensor tCrA2 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4;
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
if (k_block < size<2>(tCrA) - 1) {
cute::copy(tiled_copy_A, tCsA(_, _, k_block + 1), tCrA_copy_view(_, _, k_block + 1));
}
int32_t * tCrA_data = reinterpret_cast<int32_t *>(tCrA(_,_,k_block).data());
int32_t * tCrA1_data = reinterpret_cast<int32_t *>(tCrA1(_,_,k_block).data());
int32_t * tCrA2_data = reinterpret_cast<int32_t *>(tCrA2(_,_,k_block).data());
convert_c4_2_fp8<numel>(tCrA_data, tCrA1_data, tCrA2_data);
cute::gemm(tiled_mma, tCrA1(_,_,k_block), tCrB(_,_,2 * k_block), tCrC);
cute::gemm(tiled_mma, tCrA2(_,_,k_block), tCrB(_,_, 2 * k_block + 1), tCrC);
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
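
The nibble handling above is the core of the W4A-FP8 trick: every byte of the packed A operand carries two 4-bit weights, and convert_c4_2_fp8 peels them apart so each half feeds its own FP8 MMA (hence the two cute::gemm calls per k_block). A minimal plain-Python sketch of the same bit manipulation, assuming one 32-bit word holding four packed bytes:

def convert_c4_2_fp8_scalar(word):
    hi = (word >> 4) & 0x0F0F0F0F  # high nibble of every byte -> first FP8 operand
    lo = word & 0x0F0F0F0F         # low nibble of every byte  -> second FP8 operand
    return hi, lo

hi, lo = convert_c4_2_fp8_scalar(0xABCDEF12)
assert (hi, lo) == (0x0A0C0E01, 0x0B0D0F02)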

View File

@@ -0,0 +1,351 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#include "helper.h"
#include "paddle/extension.h"
#include "w4afp8_gemm_template.h"
#include "w4afp8_gemm.h"
void weight_convert(const uint8_t *weight, uint8_t *weight_new, int batch, int M, int K) {
assert(K % 64 == 0);
for (int b = 0; b < batch; ++b) {
for (int m = 0; m < M; ++m) {
for (int k = 0; k < K; k+=64) {
for (int k_inner = 0; k_inner < 32; ++k_inner) {
uint8_t temp = 0;
uint8_t left = weight[b * M * K + m * K + k + k_inner];
uint8_t right = weight[b * M * K + m * K + k + k_inner + 32];
temp |= left << 4;
temp |= right;
weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] = *reinterpret_cast<uint8_t*>(&temp);
}
}
}
}
}
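
weight_convert pairs column k+i with column k+i+32 inside one byte (high nibble first), which is the layout the kernel-side nibble split expects: after unpacking, the two operands correspond to the two contiguous 32-column halves of each 64-column chunk. A hedged pure-Python sketch of one chunk (pack_chunk is a hypothetical helper, not part of this patch):

def pack_chunk(chunk64):
    assert len(chunk64) == 64
    # column i -> high nibble, column i + 32 -> low nibble of byte i
    return [((chunk64[i] & 0xF) << 4) | (chunk64[i + 32] & 0xF) for i in range(32)]

packed = pack_chunk([v & 0xF for v in range(64)])
assert packed[1] == ((1 & 0xF) << 4) | (33 & 0xF)  # 0x11: column 1 and column 33 share one byte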
template <typename T> class NVTraits;
template <> class NVTraits<__nv_fp8_e4m3> {
public:
typedef cutlass::float_e4m3_t data_t;
};
template <> class NVTraits<__nv_bfloat16>{
public:
typedef cutlass::bfloat16_t data_t;
};
template <> class NVTraits<half>{
public:
typedef cutlass::half_t data_t;
};
template <typename OutputType>
void DisPatchW4AFp8Gemm(
const cutlass::float_e4m3_t* input,
const cutlass::float_e4m3_t* weight,
const int64_t * tokens,
const float * input_row_sum,
const float * weight_scale,
OutputType * out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int batch_size,
const int64_t M,
const int64_t K,
cudaStream_t stream) {
int kBlockN = (max_tokens + 15) / 16 * 16;
int TailN = 0;
if (kBlockN > 256) {
TailN = kBlockN % 256;
kBlockN = 256;
}
if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
GEMM_SWITCH_BF16(
M, K, batch_size, token_padding_size, kBlockN, TailN,
weight,
input,
out,
weight_scale,
input_row_sum,
tokens,
max_tokens,
stream)
} else {
GEMM_SWITCH_FP16(
M, K, batch_size, token_padding_size, kBlockN, TailN,
weight,
input,
out,
weight_scale,
input_row_sum,
tokens,
max_tokens,
stream)
}
}
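
The block-N selection above rounds max_tokens up to a multiple of 16, caps the main tile at 256 columns, and hands any remainder to a TAIL_N kernel variant. A worked example of that logic in plain Python (token counts are illustrative):

def pick_block_n(max_tokens):
    k_block_n = (max_tokens + 15) // 16 * 16  # round up to a multiple of 16
    tail_n = 0
    if k_block_n > 256:
        tail_n = k_block_n % 256              # remainder handled by the TAIL_N kernel
        k_block_n = 256
    return k_block_n, tail_n

assert pick_block_n(100) == (112, 0)
assert pick_block_n(300) == (256, 48)
assert pick_block_n(2048) == (256, 0)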
std::vector<paddle::Tensor> W4AFp8Gemm(
const paddle::Tensor& input,
const paddle::Tensor& weight,
const paddle::Tensor& tokens, // If token_padding_size == 0, this tensor holds the prefix sum of token counts across groups; otherwise it holds the number of tokens in each group
const paddle::Tensor& input_row_sum,
const paddle::Tensor& weight_scale,
const int64_t token_padding_size,
const int64_t max_tokens,
const bool is_bfloat16) {
const int batch_size = weight.dims()[0];
const int M = weight.dims()[1];
const int K = weight.dims()[2] * 2;
if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) {
PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN'].");
}
if (token_padding_size == 0) {
const int all_tokens = input.dims()[0];
if (is_bfloat16) {
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::BFLOAT16, input.place());
phi::dtype::bfloat16 *out_data = out.data<phi::dtype::bfloat16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
tokens.data<int64_t>(),
input_row_sum.data<float>(),
weight_scale.data<float>(),
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
token_padding_size,
max_tokens,
batch_size,
M,
K,
input.stream());
return {out};
} else {
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::FLOAT16, input.place());
phi::dtype::float16 *out_data = out.data<phi::dtype::float16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
tokens.data<int64_t>(),
input_row_sum.data<float>(),
weight_scale.data<float>(),
reinterpret_cast<cutlass::half_t*>(out_data),
token_padding_size,
max_tokens,
batch_size,
M,
K,
input.stream());
return {out};
}
} else {
if (is_bfloat16) {
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place());
phi::dtype::bfloat16 * out_data = out.data<phi::dtype::bfloat16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
tokens.data<int64_t>(),
input_row_sum.data<float>(),
weight_scale.data<float>(),
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
token_padding_size,
max_tokens,
batch_size,
M,
K,
input.stream());
return {out};
} else {
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::FLOAT16, input.place());
phi::dtype::float16 * out_data = out.data<phi::dtype::float16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
tokens.data<int64_t>(),
input_row_sum.data<float>(),
weight_scale.data<float>(),
reinterpret_cast<cutlass::half_t*>(out_data),
token_padding_size,
max_tokens,
batch_size,
M,
K,
input.stream());
return {out};
}
}
}
template <typename InputType, typename OutputType>
void DisPatchW4AFp8GemmWrapper(
const InputType* input,
const InputType* weight,
const int64_t* total_rows_before_expert,
const float* input_row_sum,
const float* row_scale,
const float* weight_scale,
OutputType * out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int num_experts,
const int64_t M,
const int64_t K,
cudaStream_t stream) {
using InType = typename NVTraits<InputType>::data_t;
using OutType = typename NVTraits<OutputType>::data_t;
DisPatchW4AFp8Gemm(
reinterpret_cast<const InType*>(input),
reinterpret_cast<const InType*>(weight),
total_rows_before_expert,
input_row_sum,
weight_scale,
reinterpret_cast<OutType*>(out),
token_padding_size,
max_tokens,
num_experts,
M,
K,
stream);
}
std::vector<paddle::Tensor> W4AFp8GemmWeightConvert(const paddle::Tensor& weight) {
const int batch_size = weight.dims()[0];
const int M = weight.dims()[1];
const int K = weight.dims()[2];
paddle::Tensor weight_new = paddle::empty({batch_size, M, K / 2}, paddle::DataType::UINT8, weight.place());
weight_convert(weight.data<uint8_t>(), weight_new.data<uint8_t>(), batch_size, M, K);
return {weight_new};
}
template <typename T, int kPackSize>
__global__ void permute_scale_kernel(
T* input_data,
const int numel) {
using LoadT = AlignedVector<T, kPackSize>;
LoadT input_vec;
LoadT dst_vec;
const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize;
if (load_idx >= numel) {
return;
}
Load<T, kPackSize>(&input_data[load_idx], &input_vec);
for (int i = 0; i < kPackSize; i+=2) {
dst_vec[i] = input_vec[i / 2];
dst_vec[i + 1] = input_vec[i / 2 + 8];
}
Store<T, kPackSize>(dst_vec, &input_data[load_idx]);
}
void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1;
const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0];
if (col % 16 != 0) {
PD_THROW("Only supported when col is divisible by 16.");
}
const int numel = row * col;
const int threads = 128;
const int kPackSize = 16;
const int grid_size = (numel / kPackSize + threads - 1) / threads;
if (scale.dtype() == paddle::DataType::BFLOAT16) {
permute_scale_kernel<phi::dtype::bfloat16, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
const_cast<phi::dtype::bfloat16*>(scale.data<phi::dtype::bfloat16>()),
numel
);
} else if (scale.dtype() == paddle::DataType::FLOAT16) {
permute_scale_kernel<phi::dtype::float16, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
const_cast<phi::dtype::float16*>(scale.data<phi::dtype::float16>()),
numel
);
} else if (scale.dtype() == paddle::DataType::FLOAT32) {
permute_scale_kernel<float, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
const_cast<float*>(scale.data<float>()),
numel
);
}
}
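
permute_scale_kernel reorders every 16-wide pack of scales so that elements i and i+8 become adjacent, presumably so each MMA thread's float2 load of weight_scale covers the row pair it owns in the accumulator layout. A plain-Python sketch of one pack (permute_scale_group is a hypothetical stand-in for the kernel body):

def permute_scale_group(group16):
    out = [None] * 16
    for i in range(0, 16, 2):
        out[i] = group16[i // 2]          # element from the first half of the pack
        out[i + 1] = group16[i // 2 + 8]  # interleaved with the second half
    return out

assert permute_scale_group(list(range(16))) == [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]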
PD_BUILD_STATIC_OP(w4afp8_gemm_scale_permute)
.Inputs({"weight_scale"})
.Outputs({"permute_scale"})
.SetInplaceMap({{"weight_scale", "permute_scale"}})
.SetKernelFn(PD_KERNEL(W4AFp8GemmScalePermute));
PD_BUILD_STATIC_OP(w4afp8_gemm)
.Inputs({"input",
"weight",
"tokens",
"input_row_sum",
"weight_scale"})
.Outputs({"out"})
.Attrs({"token_padding_size: int64_t",
"max_tokens: int64_t",
"is_bfloat16: bool"})
.SetKernelFn(PD_KERNEL(W4AFp8Gemm));
PD_BUILD_STATIC_OP(w4afp8_gemm_weight_convert)
.Inputs({"weight"})
.Outputs({"converted_weight"})
.SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert));
template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>(
const __nv_fp8_e4m3* input,
const __nv_fp8_e4m3* weight,
const int64_t * tokens,
const float * input_row_sum,
const float * row_scale,
const float * weight_scale,
__nv_bfloat16 * out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int num_experts,
const int64_t M,
const int64_t K,
cudaStream_t stream
);
template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>(
const __nv_fp8_e4m3* input,
const __nv_fp8_e4m3* weight,
const int64_t * tokens,
const float * input_row_sum,
const float * row_scale,
const float * weight_scale,
half * out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int num_experts,
const int64_t M,
const int64_t K,
cudaStream_t stream
);

View File

@@ -0,0 +1,47 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "helper.h"
std::vector<paddle::Tensor> W4AFp8Gemm(
const paddle::Tensor& input,
const paddle::Tensor& weight,
const paddle::Tensor& tokens, // If token_padding_size == 0, this tensor holds the prefix sum of token counts across groups; otherwise it holds the number of tokens in each group
const paddle::Tensor& input_row_sum,
const paddle::Tensor& weight_scale,
const int64_t token_padding_size,
const int64_t max_tokens,
const bool is_bfloat16);
template <typename InputType, typename OutputType>
void DisPatchW4AFp8GemmWrapper(
const InputType* input,
const InputType* weight,
const int64_t * tokens,
const float * input_row_sum,
const float * row_scale,
const float * weight_scale,
OutputType * out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int num_experts,
const int64_t M,
const int64_t K,
cudaStream_t stream);

View File

@@ -0,0 +1,252 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "kernel_traits.h"
#include "mainloop_fwd.h"
template <typename Ktraits>
void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) w4afp8_gemm_kernel(
CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {
using Element = typename Ktraits::Element;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int M = Ktraits::M;
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
static constexpr int TAIL_N = Ktraits::TAIL_N;
using CollectiveMainloop = CollectiveMainloopFwd<Ktraits>;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using ElementOutput = typename Ktraits::ElementOutput;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
const int bidm = blockIdx.x;
const int bidn = blockIdx.y;
const int bidb = blockIdx.z;
const int tidx = threadIdx.x;
if (tidx == 0) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
}
// Obtain warp index
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
MainloopPipeline pipeline(shared_storage.pipeline, pipeline_params, ClusterShape{});
CollectiveMainloop collective_mainloop;
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
} else {
__syncthreads();
}
const int pre_fix_tokens = TokenPackSize == 0 ? (bidb == 0 ? 0 : mainloop_params.tokens[bidb - 1]) : 0;
const int tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb] - pre_fix_tokens : mainloop_params.tokens[bidb];
if (bidn * kBlockN >= tokens) {
return;
}
float* input_row_sum = reinterpret_cast<float*>(
shared_memory + sizeof(typename Ktraits::SharedStorage));
if (warp_group_idx == 0) {
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 40 : 32>();
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
collective_mainloop.load(
mainloop_params,
pipeline,
smem_pipe_write,
shared_storage,
tokens,
pre_fix_tokens,
bidm,
bidn,
bidb,
tidx);
} else {
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 232 : 160>();
PipelineState smem_pipe_read;
typename Ktraits::TiledMma tiled_mma;
typename Ktraits::TiledMma_TAIL tiled_mma_tail;
const int mma_tidx = tidx - NumCopyThreads;
const int lane_id = mma_tidx % 4 * 2;
const float2 weight_scale = reinterpret_cast<const float2*>(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4];
if constexpr (TokenPackSize == 0) {
const int input_sum_idx = pre_fix_tokens + bidn * kBlockN;
if (mma_tidx < kBlockN) {
reinterpret_cast<float*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
}
} else {
const int input_sum_idx = bidb * TokenPackSize + bidn * kBlockN;
if (mma_tidx < kBlockN / 4) {
reinterpret_cast<float4*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float4*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
}
}
const int remain_tokens = tokens - bidn * kBlockN;
if (TAIL_N > 0 && remain_tokens < kBlockN) {
Tensor tSrS_tail = partition_fragment_C(tiled_mma_tail, select<0, 1>(TileShape_MNK_TAIL{}));
collective_mainloop.mma<TAIL_N>(
mainloop_params,
tiled_mma_tail,
pipeline,
smem_pipe_read,
shared_storage,
tSrS_tail,
mma_tidx);
collective_mainloop.store<TAIL_N>(
mainloop_params,
tSrS_tail,
shared_storage,
tiled_mma_tail,
input_row_sum + lane_id,
reinterpret_cast<const float*>(&weight_scale),
tokens,
pre_fix_tokens,
bidm,
bidn,
bidb,
mma_tidx);
} else {
Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}));
collective_mainloop.mma<kBlockN>(
mainloop_params,
tiled_mma,
pipeline,
smem_pipe_read,
shared_storage,
tSrS,
mma_tidx);
collective_mainloop.store<kBlockN>(
mainloop_params,
tSrS,
shared_storage,
tiled_mma,
input_row_sum + lane_id,
reinterpret_cast<const float*>(&weight_scale),
tokens,
pre_fix_tokens,
bidm,
bidn,
bidb,
mma_tidx);
}
}
}
template <int Batch>
auto get_gmem_layout(const int Rows, const int Cols) {
return make_layout(
make_shape(
static_cast<int64_t>(Rows),
static_cast<int64_t>(Cols),
static_cast<int64_t>(Batch)),
make_stride(
static_cast<int64_t>(Cols),
cute::_1{},
static_cast<int64_t>(Rows * Cols)));
}
template <typename InputType, typename OutputType, typename Kernel_traits, int M, int K, int Batch, int TokenPackSize>
void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale,
const float *input_row_sum, const int64_t * tokens, const int64_t max_tokens, cudaStream_t stream) {
using ElementOutput = typename Kernel_traits::ElementOutput;
using Element = typename Kernel_traits::Element;
using CollectiveMainloop = CollectiveMainloopFwd<Kernel_traits>;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
constexpr int M_nums = (M + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
const int N_nums = (max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({
static_cast<Element const*>(A),
get_gmem_layout<Batch>(M, K / 2),
static_cast<Element const*>(B),
get_gmem_layout<Batch>(TokenPackSize == 0 ? max_tokens * Batch : TokenPackSize, K),
static_cast<ElementOutput*>(C),
get_gmem_layout<Batch>(M, TokenPackSize == 0 ? max_tokens : TokenPackSize),
weight_scale,
input_row_sum,
tokens
});
void *kernel;
kernel = (void *)w4afp8_gemm_kernel<Kernel_traits>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage) + sizeof(float) * Kernel_traits::kBlockN;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
dim3 grid_dims;
grid_dims.x = M_nums;
grid_dims.y = N_nums;
grid_dims.z = Batch;
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(
launch_params, kernel, mainloop_params);
}
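
For a rough feel of the launch geometry in run_gemm, using the eb45T ffn1 shape listed in the generator below (M=8192, kBlockM=128, kBlockN=256, Batch=8) and an assumed max_tokens of 2048:

M, kBlockM = 8192, 128
max_tokens, kBlockN = 2048, 256
batch = 8
grid = ((M + kBlockM - 1) // kBlockM, (max_tokens + kBlockN - 1) // kBlockN, batch)
assert grid == (64, 8, 8)
# dynamic shared memory adds kBlockN floats (1 KiB here) for input_row_sum on top of SharedStorage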

View File

@@ -293,6 +293,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu",
"gpu_ops/custom_all_reduce/all_reduce.cu",
"gpu_ops/limit_content_len.cu",
]
# pd_disaggregation
@@ -494,6 +495,8 @@ elif paddle.is_compiled_with_cuda():
if cc >= 90 and nvcc_version >= 12.0:
# Hopper optimized mla
sources += find_end_files("gpu_ops/mla_attn", ".cu")
os.system("python utils/auto_gen_w4afp8_gemm_kernel.py")
sources += find_end_files("gpu_ops/w4afp8_gemm", ".cu")
setup(
name="fastdeploy_ops",

View File

@@ -0,0 +1,207 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
file_dir = "./gpu_ops/w4afp8_gemm/"
gemm_template_head = """
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
"""
gemm_template_case = """
void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
const cutlass::float_e4m3_t * weight,
const cutlass::float_e4m3_t * input,
{cutlass_type} * out,
const float *weight_scale,
const float *input_row_sum,
const int64_t *tokens,
const int64_t max_tokens,
cudaStream_t stream);
"""
gemm_template_cu_head = """
#include "paddle/extension.h"
#include "w4afp8_gemm_template.h"
#include "w4afp8_gemm_kernel.hpp"
"""
gemm_template_cu_template = """
void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
const cutlass::float_e4m3_t * weight,
const cutlass::float_e4m3_t * input,
{cutlass_type} * out,
const float *weight_scale,
const float *input_row_sum,
const int64_t *tokens,
const int64_t max_tokens,
cudaStream_t stream) {{
constexpr static int M = {M};
constexpr static int K = {K};
constexpr static int Batch = {BATCH};
constexpr static int TokenPackSize = {PADDING};
constexpr static int kBlockN = {N};
constexpr static int kBlockN_TAIL = {TAILN};
constexpr static int kBlockM = 128;
constexpr static int kBlockK = 128;
constexpr static int kNWarps = 4 + kBlockM / 16;
constexpr static int kStages = 5;
constexpr int kCluster = 1;
static_assert(K % kBlockK == 0);
constexpr int kTiles = K / kBlockK;
using Kernel_traits = Kernel_traits<
kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles,
M, TokenPackSize, kBlockN_TAIL, kCluster, cutlass::float_e4m3_t,
{cutlass_type}>;
run_gemm<cutlass::float_e4m3_t, {cutlass_type},
Kernel_traits, M, K, Batch, TokenPackSize>
(weight, input, out, weight_scale,
input_row_sum, tokens, max_tokens, stream);
}}
"""
gemm_case = [
[8192, 3584, 8, 0], # eb45T ffn1
[8192, 3584, 8, 2048], # eb45T ffn1
[7168, 8192, 8, 0], # eb45T ffn2
[7168, 8192, 8, 2048], # eb45T ffn2
]
dtype = ["BF16", "FP16"]
def get_cutlass_type(type):
if type == "BF16":
return "cutlass::bfloat16_t"
elif type == "FP16":
return "cutlass::half_t"
template_head_file = open(f"{file_dir}w4afp8_gemm_template.h", "w")
template_head_file.write(gemm_template_head)
for type in dtype:
for case in gemm_case:
for n in range(16, 257, 16):
template_head_file.write(
gemm_template_case.format(
M=case[0],
K=case[1],
N=n,
BATCH=case[2],
TYPE=type,
PADDING=case[3],
TAILN=0,
cutlass_type=get_cutlass_type(type),
)
)
template_head_file.write(
gemm_template_case.format(
M=case[0],
K=case[1],
N=256,
BATCH=case[2],
TYPE=type,
PADDING=case[3],
TAILN=n - 16,
cutlass_type=get_cutlass_type(type),
)
)
template_cu_file = open(
f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_TAILN{0}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
)
template_cu_file.write(gemm_template_cu_head)
template_cu_file.write(
gemm_template_cu_template.format(
M=case[0],
K=case[1],
N=n,
BATCH=case[2],
TYPE=type,
PADDING=case[3],
TAILN=0,
cutlass_type=get_cutlass_type(type),
)
)
template_cu_file.close()
template_cu_file = open(
f"{file_dir}w4afp8_gemm_M{case[0]}_N{256}_TAILN{n-16}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
)
template_cu_file.write(gemm_template_cu_head)
template_cu_file.write(
gemm_template_cu_template.format(
M=case[0],
K=case[1],
N=256,
BATCH=case[2],
TYPE=type,
PADDING=case[3],
TAILN=n - 16,
cutlass_type=get_cutlass_type(type),
)
)
template_cu_file.close()
for type in dtype:
template_head_file.write("\n")
template_head_file.write(
"""#define GEMM_SWITCH_{TYPE}(_M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN, ...) {{ \\
if (_M == 0 && _K == 0 && _BATCH == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _TailN == 0) {{ \\""".format(
TYPE=type
)
)
template_head_file.write("\n")
for case in gemm_case:
for n in range(16, 257, 16):
template_head_file.write(
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
M=case[0], K=case[1], N=n, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=0
)
)
template_head_file.write("\n")
template_head_file.write(
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
M=case[0], K=case[1], N=256, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=n - 16
)
)
template_head_file.write("\n")
template_head_file.write(
""" } else { \\
PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d batch=%d token_padding_size=%d kBlockN=%d tailN=%d\\n", _M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN)); \\
} \\
}"""
)
template_head_file.close()
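
The GEMM_SWITCH_* macros emitted above dispatch purely on shape to one of the generated symbols; the name mangling can be summarized in one line (kernel_symbol is a hypothetical helper mirroring gemm_template_case):

def kernel_symbol(M, K, batch, padding, k_block_n, tail_n, dtype):
    return f"w4afp8_gemm_M{M}_N{k_block_n}_TAILN{tail_n}_K{K}_B{batch}_P{padding}_{dtype}"

assert kernel_symbol(8192, 3584, 8, 0, 256, 48, "BF16") == "w4afp8_gemm_M8192_N256_TAILN48_K3584_B8_P0_BF16"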

View File

@@ -294,16 +294,24 @@ class SpeculativeConfig:
self,
args,
):
# speculative method, choose in [None, "ngram_match", "mtp"]
self.method_list = ["ngram_match", "mtp"]
self.mtp_strategy_list = ["default", "with_ngram"]
# speculative method, choose in [None, "ngram_match", "mtp", "hybrid_mtp_ngram"]
self.method: Optional[str] = None
# mtp strategy in mtp-method
self.mtp_strategy = "default"
# the max length of speculative tokens
self.num_speculative_tokens: int = 1
# the number of model runner steps for the draft model / MTP
self.num_model_steps: int = 1
# the max length of candidate tokens for speculative method
self.max_candidate_len: int = 5
# the max length of verify window for speculative method
self.verify_window: int = 2
# ngram match
self.max_ngram_size: int = 5
self.min_ngram_size: int = 2
# model for mtp/eagle/draft_model
self.model: Optional[str] = None
# quantization of model
@@ -390,6 +398,33 @@ class SpeculativeConfig:
logger.info("{:<20}:{:<6}{}".format(k, "", v))
logger.info("=============================================================")
def check_legality_parameters(
self,
) -> None:
"""Check the legality of parameters passed in from the command line"""
if self.method is not None:
assert (
self.method in self.method_list
), f"speculative method only support {self.method_list} now, but get {self.method}."
assert (
self.num_speculative_tokens >= 1 and self.num_speculative_tokens <= 5
), f"num_speculative_tokens only support in range[1, 5], but get {self.num_speculative_tokens}."
assert (
self.num_model_steps >= 1 and self.num_model_steps <= 5
), f"num_model_steps only support in range[1, 5], but get {self.num_model_steps}."
if self.method in ["mtp", "hybrid_mtp_ngram"]:
if self.num_speculative_tokens < self.num_model_steps:
logger.warning(
f"Get num_model_steps > num_speculative_tokens. Reset num_speculative_tokens to {self.num_model_steps}"
)
self.num_speculative_tokens = self.num_model_steps
assert (
self.mtp_strategy in self.mtp_strategy_list
), f"mtp_strategy_list only support {self.mtp_strategy_list}, but get {self.mtp_strategy}"
def __str__(self) -> str:
return self.to_json_string()
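
The reconciliation rule in check_legality_parameters can be read as a two-liner (values below are hypothetical):

num_speculative_tokens, num_model_steps = 2, 4
if num_speculative_tokens < num_model_steps:
    num_speculative_tokens = num_model_steps  # the real code logs a warning before overriding
assert num_speculative_tokens == 4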

View File

@@ -283,6 +283,7 @@ class FlashAttentionBackend(AttentionBackend):
metadata.kv_token_num_cpu[0].item(),
self.max_seq_len,
getattr(layer, "cache_quant_type_str", "none"),
self.rope_3d,
)
res = self.flash_attn_func(

View File

@@ -49,6 +49,7 @@ def gqa_rope_write_cache(
kv_token_num: int = 1,
max_seq_len: int = 0,
cache_quant_type: str = "none",
rope_3d: bool = False,
):
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
@@ -81,6 +82,7 @@ def gqa_rope_write_cache(
kv_token_num,
max_seq_len,
cache_quant_type,
rope_3d,
)
return q, k, v, qkv_
else:

View File

@@ -46,7 +46,7 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Triton MoE create weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
assert self.quant_method.name() == "wint8"

View File

@@ -51,7 +51,7 @@ class GCUFusedMoeMethod(MoEMethodBase):
Paddle gcu create weight process.
"""
# bf16
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
for idx, weight_tensor in enumerate([stacked_up_gate_proj_weights, stacked_down_proj_weights]):
@@ -276,7 +276,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
up_gate_proj_weights, down_proj_weights, _, _ = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
@@ -312,7 +312,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
"""
Paddle cutlass create weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)
def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):

View File

@@ -12,13 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .fused_moe_cutlass_backend import CutlassW4A8MoEMethod, CutlassWeightOnlyMoEMethod
from .fused_moe_cutlass_backend import (
CutlassW4A8MoEMethod,
CutlassW4AFP8MoEMethod,
CutlassWeightOnlyMoEMethod,
)
from .fused_moe_triton_backend import TritonWeightOnlyMoEMethod
from .moe import FusedMoE
__all__ = [
CutlassWeightOnlyMoEMethod,
CutlassW4A8MoEMethod,
CutlassW4AFP8MoEMethod,
FusedMoE,
TritonWeightOnlyMoEMethod,
]

View File

@@ -355,7 +355,7 @@ class EPPrefillRunner(EPRunner):
):
(
num_tokens_per_rank,
_,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
is_token_in_rank,
_,
@@ -365,6 +365,7 @@ class EPPrefillRunner(EPRunner):
dispatch_args = {
"x": (x, x_scale_tensor) if x_scale_tensor is not None else x,
"num_tokens_per_rank": num_tokens_per_rank,
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
"is_token_in_rank": is_token_in_rank,
"num_tokens_per_expert": num_tokens_per_expert,
"config": self.ep_engine.ep_config,

View File

@@ -31,6 +31,7 @@ if current_platform.is_cuda():
moe_expert_dispatch,
moe_expert_reduce,
noaux_tc,
w4afp8_gemm_scale_permute,
)
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import (
@@ -75,7 +76,9 @@ class CutlassMoEMethod(MoEMethodBase):
Paddle cutlass create weight process.
"""
# bf16
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.extract_moe_ffn_weights(state_dict)
)
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
for idx, weight_tensor in enumerate([stacked_up_gate_proj_weights, stacked_down_proj_weights]):
@@ -98,6 +101,7 @@ class CutlassMoEMethod(MoEMethodBase):
token_nums_per_expert: paddle.Tensor,
expert_idx_per_token: paddle.Tensor,
used_in_ep_low_latency: bool = False,
estimate_total_token_nums: int = -1,
):
"""
Paddle Cutlass compute Fused MoE.
@@ -115,6 +119,7 @@ class CutlassMoEMethod(MoEMethodBase):
expert_idx_per_token,
self.moe_quant_type,
used_in_ep_low_latency,
estimate_total_token_nums,
)
return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
permute_input,
@@ -128,6 +133,7 @@ class CutlassMoEMethod(MoEMethodBase):
expert_idx_per_token,
self.moe_quant_type,
used_in_ep_low_latency,
estimate_total_token_nums,
)
def apply_ep_prefill(
@@ -167,13 +173,13 @@ class CutlassMoEMethod(MoEMethodBase):
recv_x,
recv_topk_idx,
recv_topk_weights,
(self.up_gate_proj_in_scale if hasattr(self, "up_gate_proj_in_scale") else None),
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
recv_num_tokens_per_expert_list,
token_all_num,
self.moe_quant_type,
)
if self.moe_quant_type != "w4a8":
# only w4a8 need expert_idx_per_token
if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
# only w4a8 and w4afp8 need expert_idx_per_token
# Other quant types do not need this tensor, so we set it to None.
expert_idx_per_token = None
else:
@@ -211,15 +217,17 @@ class CutlassMoEMethod(MoEMethodBase):
"""
Apply the EP decoder method.
"""
estimate_total_token_nums = gate_out.shape[0] * layer.top_k
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts")
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
use_fp8 = self.moe_quant_type == "w4afp8"
# 2. EP Dispatch
permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
x, topk_idx, topk_weights, expertwise_scale=expertwise_scale
x, topk_idx, topk_weights, expertwise_scale=expertwise_scale, use_fp8=use_fp8
)
# 3. Compute ffn
if self.moe_quant_type == "w4a8":
if self.moe_quant_type == "w4a8" or self.moe_quant_type == "w4afp8":
num_local_experts, max_num, _ = permute_input.shape
expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile([1, max_num])
elif self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]:
@@ -233,6 +241,7 @@ class CutlassMoEMethod(MoEMethodBase):
token_nums_per_expert.cast("int64"),
expert_idx_per_token,
True,
estimate_total_token_nums,
)
# 4. EP combine
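
For w4a8/w4afp8 decode, expert_idx_per_token is simply each local expert's index broadcast over its padded token slots; a pure-Python stand-in for paddle.arange(num_local_experts)[:, None].tile([1, max_num]):

num_local_experts, max_num = 4, 3  # illustrative sizes
expert_idx_per_token = [[e] * max_num for e in range(num_local_experts)]
assert expert_idx_per_token == [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]]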
@@ -295,7 +304,7 @@ class CutlassMoEMethod(MoEMethodBase):
topk_only_mode=False,
)
if self.moe_quant_type != "w4a8":
if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
# only w4a8 and w4afp8 need expert_idx_per_token
# Other quant types do not need this tensor, so we set it to None.
expert_idx_per_token = None
@@ -332,7 +341,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
self.moe_quant_type = "w4a8"
self.pack_num = 2
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
"""
Paddle cutlass process prequanted weights.
"""
@@ -343,10 +352,10 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
)
)
up_gate_proj_weight_scale = []
@@ -355,22 +364,62 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
up_gate_proj_in_scale = []
down_proj_in_scale = []
if isinstance(state_dict, list):
state_dict = dict(state_dict)
if layer.ep_size > 1:
for expert_idx in range(layer.num_experts):
scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
for expert_idx in ep_rank_to_expert_id_list:
scale_tensor = get_tensor(
(
state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)]
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
else up_gate_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
up_gate_proj_in_scale_all_experts.append(scale_tensor)
for expert_idx in logical_expert_ids:
up_gate_proj_weight_scale.append(
get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
get_tensor(
(
state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx))
if up_gate_proj_expert_weight_scale_key.format(expert_idx) in state_dict
else up_gate_proj_expert_weight_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
down_proj_weight_scale.append(
get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
get_tensor(
(
state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx))
if down_proj_expert_weight_scale_key.format(expert_idx) in state_dict
else down_proj_expert_weight_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
up_gate_proj_in_scale.append(
get_tensor(state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx)))
get_tensor(
(
state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
else up_gate_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
down_proj_in_scale.append(
get_tensor(
(
state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
else down_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
down_proj_in_scale.append(get_tensor(state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))))
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
@@ -396,7 +445,9 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
"""
Paddle cutlass create weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.extract_moe_ffn_weights(state_dict)
)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
@@ -407,9 +458,13 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
quanted_weight = paddle.stack(weight_list, axis=0)
create_and_set_parameter(layer, weight_name, quanted_weight)
self.create_w4a8_scale_weights(layer, layer.weight_key_map, state_dict)
self.create_w4a8_scale_weights(
layer, layer.weight_key_map, state_dict, logical_expert_ids, ep_rank_to_expert_id_list
)
def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict):
def create_w4a8_scale_weights(
self, layer: nn.Layer, weight_key_map: dict, state_dict: dict, logical_expert_ids, ep_rank_to_expert_id_list
):
"""
Get w4a8 weights from state dict and process them.
Args:
@@ -418,8 +473,15 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
state_dict (dict): The state dict.
"""
def _extract_scale_tensor(state_dict, key_template, expert_idx):
return get_tensor(state_dict.pop(key_template.format(expert_idx)))
def _extract_scale_tensor(layer: nn.Layer, state_dict, key_template, expert_idx):
return get_tensor(
(
state_dict.pop(key_template.format(expert_idx))
if key_template.format(expert_idx) in state_dict
else key_template.format(expert_idx)
),
layer.fd_config.model_config.model,
)
def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
processed_in_scale = 1 / paddle.concat(in_scales)
@@ -461,17 +523,249 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
# 2. Extract scale tensor from state dict
if layer.ep_size > 1:
for expert_idx in range(layer.num_experts):
scale_tensor = get_tensor(state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)])
for expert_idx in ep_rank_to_expert_id_list:
scale_tensor = get_tensor(
(
state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)]
if scale_key_map["up_gate_proj_in_scale"].format(expert_idx) in state_dict
else scale_key_map["up_gate_proj_in_scale"].format(expert_idx)
),
layer.fd_config.model_config.model,
)
up_gate_proj_in_scales_all_experts.append(1 / scale_tensor)
create_and_set_parameter(
layer, "up_gate_proj_in_scale_all_experts", paddle.concat(up_gate_proj_in_scales_all_experts)
)
for local_expert_idx in range(layer.num_local_experts):
expert_idx = local_expert_idx + layer.expert_id_offset
for expert_idx in logical_expert_ids:
for name, scale_key_template in scale_key_map.items():
scale_tensor = _extract_scale_tensor(state_dict, scale_key_template, expert_idx)
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
scale_weight_map[name].append(scale_tensor)
# 3. Process scale tensor and set to layer
in_scales = []
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
in_scales.append(_process_in_scale(in_scale_name, scale_weight_map[in_scale_name]))
for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
_process_weight_scale(
weight_scale_name,
scale_weight_map[weight_scale_name],
in_scales[i],
)
class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
"""
w4afp8 MoE Method
"""
def __init__(self, quant_config):
super().__init__(quant_config)
self.quant_config = quant_config
self.moe_quant_type = "w4afp8"
self.pack_num = 2
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
"""
Paddle cutlass process prequanted weights.
"""
up_gate_proj_expert_weight_key = layer.weight_key_map.get("up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
)
)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
up_gate_proj_in_scale_all_experts = []
up_gate_proj_in_scale = []
down_proj_in_scale = []
if isinstance(state_dict, list):
state_dict = dict(state_dict)
if layer.ep_size > 1:
for expert_idx in ep_rank_to_expert_id_list:
scale_tensor = get_tensor(
(
state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)]
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
else up_gate_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
up_gate_proj_in_scale_all_experts.append(scale_tensor)
for expert_idx in logical_expert_ids:
up_gate_proj_weight_scale.append(
get_tensor(
(
state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx))
if up_gate_proj_expert_weight_scale_key.format(expert_idx) in state_dict
else up_gate_proj_expert_weight_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
down_proj_weight_scale.append(
get_tensor(
(
state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx))
if down_proj_expert_weight_scale_key.format(expert_idx) in state_dict
else down_proj_expert_weight_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
up_gate_proj_in_scale.append(
get_tensor(
(
state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
else up_gate_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
down_proj_in_scale.append(
get_tensor(
(
state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
else down_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0)
up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0)
down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0)
name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale,
"up_gate_proj_in_scale_all_experts": up_gate_proj_in_scale_all_experts,
"up_gate_proj_in_scale": up_gate_proj_in_scale,
"down_proj_in_scale": down_proj_in_scale,
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
def create_weights(self, layer: nn.Layer, state_dict):
"""
Paddle cutlass create weight process.
"""
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.extract_moe_ffn_weights(state_dict)
)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
weight_list = []
for i in range(layer.num_local_experts):
quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80)
weight_list.append(quant_weight)
quanted_weight = paddle.stack(weight_list, axis=0)
create_and_set_parameter(layer, weight_name, quanted_weight)
self.create_w4afp8_scale_weights(
layer, layer.weight_key_map, state_dict, logical_expert_ids, ep_rank_to_expert_id_list
)
def create_w4afp8_scale_weights(
self, layer: nn.Layer, weight_key_map: dict, state_dict: dict, logical_expert_ids, ep_rank_to_expert_id_list
):
"""
Get w4afp8 weights from state dict and process them.
Args:
layer (nn.Layer): The layer to add parameters to.
weight_key_map (dict): The weight key map.
state_dict (dict): The state dict.
"""
def _extract_scale_tensor(layer: nn.Layer, state_dict, key_template, expert_idx):
return get_tensor(
(
state_dict.pop(key_template.format(expert_idx))
if key_template.format(expert_idx) in state_dict
else key_template.format(expert_idx)
),
layer.fd_config.model_config.model,
)
def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
processed_in_scale = 1 / paddle.concat(in_scales)
create_and_set_parameter(layer, name, processed_in_scale)
return processed_in_scale
def _permute_weight_scale(weight_scale: paddle.Tensor):
weight_scale = w4afp8_gemm_scale_permute(weight_scale)
return weight_scale
def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], processed_in_scale: paddle.Tensor):
processed_weight_scale = (
paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None]
)
processed_weight_scale = _permute_weight_scale(processed_weight_scale)
create_and_set_parameter(layer, name, processed_weight_scale)
# 1. Init scale containers and maps
up_gate_proj_weight_scales = []
down_proj_weight_scales = []
up_gate_proj_in_scales_all_experts = []
up_gate_proj_in_scales = []
down_proj_in_scales = []
scale_weight_map = {
"up_gate_proj_weight_scale": up_gate_proj_weight_scales,
"down_proj_weight_scale": down_proj_weight_scales,
"up_gate_proj_in_scale": up_gate_proj_in_scales,
"down_proj_in_scale": down_proj_in_scales,
}
scale_key_map = {
"up_gate_proj_weight_scale": weight_key_map.get("up_gate_proj_expert_weight_scale_key", None),
"down_proj_weight_scale": weight_key_map.get("down_proj_expert_weight_scale_key", None),
"up_gate_proj_in_scale": weight_key_map.get("up_gate_proj_expert_in_scale_key", None),
"down_proj_in_scale": weight_key_map.get("down_proj_expert_in_scale_key", None),
}
for name, value in scale_key_map.items():
if value is None:
raise ValueError(f"scale {name} should not be none in w4a8 mode.")
# 2. Extract scale tensor from state dict
if layer.ep_size > 1:
for expert_idx in ep_rank_to_expert_id_list:
scale_tensor = get_tensor(
(
state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)]
if scale_key_map["up_gate_proj_in_scale"].format(expert_idx) in state_dict
else scale_key_map["up_gate_proj_in_scale"].format(expert_idx)
),
layer.fd_config.model_config.model,
)
up_gate_proj_in_scales_all_experts.append(1 / scale_tensor)
create_and_set_parameter(
layer, "up_gate_proj_in_scale_all_experts", paddle.concat(up_gate_proj_in_scales_all_experts)
)
for expert_idx in logical_expert_ids:
for name, scale_key_template in scale_key_map.items():
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
scale_weight_map[name].append(scale_tensor)
# 3. Process scale tensor and set to layer
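
The constant in _process_weight_scale above evaluates to 6.125. A hedged reading: 448 is the float8_e4m3 maximum and 7 * 2**-9 is the largest int4 nibble reinterpreted as an e4m3 subnormal, so dividing by their product (and by the input scale) restores the dequantization factor:

FP8_E4M3_MAX = 448.0
INT4_NIBBLE_MAX_AS_E4M3 = 7 * 2 ** -9  # low nibble 0b0111 read as an e4m3 subnormal
assert FP8_E4M3_MAX * INT4_NIBBLE_MAX_AS_E4M3 == 6.125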
@@ -498,7 +792,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
self.moe_quant_type = self.quant_config.algo
self.pack_num = 1
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
"""
Paddle cutlass process prequanted weights.
"""
@@ -507,7 +801,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
@@ -541,7 +835,9 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
"""
Paddle cutlass create weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.extract_moe_ffn_weights(state_dict)
)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):

View File

@@ -37,7 +37,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
deepgemm create weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.extract_moe_ffn_weights(state_dict)
)
self.check(layer, up_gate_proj_weights, down_proj_weights)
@@ -62,7 +64,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous()
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
"""
Paddle cutlass process prequanted weights.
"""
@@ -71,10 +73,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
)
# self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scale = []

View File

@@ -143,7 +143,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
"""
Marlin MoE create weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
assert up_gate_proj_weights[0].shape == [

View File

@@ -56,7 +56,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Triton MoE create weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
@@ -267,7 +267,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
"""process_prequanted_weights"""
up_gate_proj_tensor, down_proj_tensor = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_tensor, down_proj_tensor, _, _ = layer.extract_moe_ffn_weights(state_dict)
assert up_gate_proj_tensor[0].shape == [
layer.hidden_size,
layer.moe_intermediate_size * 2,
@@ -534,7 +534,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
"""
Triton MoE create weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)

View File

@@ -88,7 +88,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
up_gate_proj_expert_code_zp_key = layer.weight_key_map.get("up_gate_proj_expert_code_zp_key", None)
down_proj_expert_code_zp_key = layer.weight_key_map.get("down_proj_expert_code_zp_key", None)
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
up_gate_proj_weights, down_proj_weights, _, _ = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,

View File

@@ -36,7 +36,7 @@ class XPUMoEMethod(MoEMethodBase):
Paddle cutlass create weight process.
"""
# bf16
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
for weights in [up_gate_proj_weights, down_proj_weights]:
for idx, weight in enumerate(weights):
weights[idx] = weight.transpose([1, 0])
@@ -130,7 +130,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
"""
Paddle cutlass create weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
assert up_gate_proj_weights[0].shape == [

View File

@@ -63,6 +63,7 @@ class FusedMoE(nn.Layer):
routed_scaling_factor: float = 1.0,
layer_idx: int = -1,
moe_tag: str = "",
redundant_table_manger: RedundantExpertManger = None,
weight_key_map: dict = {},
):
"""
@@ -118,15 +119,8 @@ class FusedMoE(nn.Layer):
# now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
self.quant_method = get_moe_method()
self.redundant_table_manger = None
self.redundant_table_manger = redundant_table_manger
if self.ep_size > 1:
if fd_config.model_config.enable_redundant_experts is True:
self.redundant_table_manger = RedundantExpertManger(
n_routed_experts=fd_config.model_config.moe_num_experts,
num_hidden_layers=fd_config.model_config.num_hidden_layers,
redundant_experts_num=fd_config.model_config.redundant_experts_num,
ep_size=self.ep_size,
)
self.quant_method.init_ep(self)
if fd_config.load_config.dynamic_load_weight:
@@ -223,6 +217,7 @@ class FusedMoE(nn.Layer):
state_dict: dict,
up_gate_proj_expert_weight_key: str,
down_proj_expert_weight_key: str,
is_rearrange: bool = False,
):
"""
Load experts weight from state_dict.
@@ -238,6 +233,7 @@ class FusedMoE(nn.Layer):
self.expert_id_offset + self.num_local_experts,
)
]
ep_rank_to_expert_id_list = [i for i in range(self.num_experts)]
if self.redundant_table_manger is not None:
(
ep_rank_to_expert_id_list,
@@ -250,7 +246,13 @@ class FusedMoE(nn.Layer):
]
up_gate_proj_weights = []
down_proj_weights = []
is_ffn_merged = up_gate_proj_expert_weight_key.format(self.expert_id_offset) in state_dict
if isinstance(state_dict, list):
state_dict = dict(state_dict)
is_ffn_merged = (
up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else self.expert_id_offset)
in state_dict
)
if is_ffn_merged:
for expert_idx in logical_expert_ids:
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
@@ -309,7 +311,7 @@ class FusedMoE(nn.Layer):
self.fd_config.model_config.model,
)
)
return up_gate_proj_weights, down_proj_weights, logical_expert_ids
return up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list
def extract_moe_ffn_weights(self, state_dict: dict):
"""
@@ -332,10 +334,12 @@ class FusedMoE(nn.Layer):
assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
up_gate_proj_weights, down_proj_weights, logical_expert_ids = self.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
self.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
)
)
assert (
len(up_gate_proj_weights) == self.num_local_experts
@@ -344,7 +348,7 @@ class FusedMoE(nn.Layer):
len(down_proj_weights) == self.num_local_experts
), "down_proj_weights length should be equal to num_local_experts."
return up_gate_proj_weights, down_proj_weights
return up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list
def extract_gate_correction_bias(self, gate_correction_bias_key, state_dict):
"""
@@ -386,7 +390,7 @@ class FusedMoE(nn.Layer):
if self.fd_config.model_config.is_quantized:
if getattr(self.fd_config.quant_config, "is_permuted", False):
self.quant_method.process_prequanted_weights(self, state_dict)
self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
else:
self.quant_method.create_weights(self, state_dict)
else:

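For reference, when no redundant-expert table is configured, the partition that load_experts_weight in the FusedMoE changes above falls back to is a contiguous slice of experts per EP rank plus an identity rank-to-expert mapping. A minimal sketch with illustrative sizes (the even split and contiguous offsets are assumptions made for this illustration, not taken from a real config):

# Sketch of the default expert partition used when redundant_table_manger is None.
# num_experts / ep_size / ep_rank are illustrative values only.
num_experts = 64
ep_size = 4
ep_rank = 1

num_local_experts = num_experts // ep_size            # assumes an even split across ranks
expert_id_offset = ep_rank * num_local_experts        # assumes contiguous expert blocks per rank

logical_expert_ids = list(range(expert_id_offset, expert_id_offset + num_local_experts))
ep_rank_to_expert_id_list = list(range(num_experts))  # identity mapping, as in the code above

print(logical_expert_ids[:3], "...", logical_expert_ids[-1])   # [16, 17, 18] ... 31
print(ep_rank_to_expert_id_list[:3])                           # [0, 1, 2]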
View File

@@ -36,7 +36,7 @@ class MixQuantConfig(QuantConfigBase):
image_moe_quant_type: str = None,
is_channel_wise: bool = False,
has_zero_point: bool = False,
is_permuted: bool = False,
is_permuted: bool = True,
) -> None:
super().__init__()
self.dense_quant_type = dense_quant_type
@@ -65,7 +65,7 @@ class MixQuantConfig(QuantConfigBase):
config.get("image_moe_quant_type", None),
config.get("is_channel_wise", False),
config.get("has_zero_point", False),
config.get("is_permuted", False),
config.get("is_permuted", True),
)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -73,13 +73,13 @@ class MixQuantConfig(QuantConfigBase):
if layer.moe_tag == "Image":
return (
get_quantization_config(self.image_moe_quant_type)
.from_config(layer.fd_config.quant_config)
.from_config({"is_permuted": self.is_permuted})
.get_quant_method(layer)
)
else:
return (
get_quantization_config(self.moe_quant_type)
.from_config(layer.fd_config.quant_config)
.from_config({"is_permuted": self.is_permuted})
.get_quant_method(layer)
)
elif isinstance(layer, Attention):

View File

@@ -34,7 +34,7 @@ class W4A8Config(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "W4A8Config":
is_permuted = getattr(config, "is_permuted", False)
is_permuted = config.get("is_permuted", True)
return cls(is_permuted)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:

View File

@@ -20,6 +20,7 @@ import paddle
import fastdeploy
from ..moe import FusedMoE
from .quant_base import QuantConfigBase, QuantMethodBase
QUANT_SCALING_FACTOR = 448
@@ -30,24 +31,32 @@ class W4AFP8Config(QuantConfigBase):
quantization config for weight 4bits and activation fp8
"""
def __init__(self, weight_scale_dict, act_scale_dict) -> None:
def __init__(self, weight_scale_dict, act_scale_dict, is_permuted) -> None:
super().__init__()
self.weight_scale_dict = weight_scale_dict
self.act_scale_dict = act_scale_dict
self.quant_max_bound = 448
self.quant_min_bound = -448
self.quant_round_type = 1
self.is_permuted = is_permuted
def name(self) -> str:
return "w4afp8"
@classmethod
def from_config(cls, config: dict) -> "W4AFP8Config":
weight_scale_dict = config["weight_scale_dict"]
act_scale_dict = config["act_scale_dict"]
return cls(weight_scale_dict, act_scale_dict)
weight_scale_dict = config.get("weight_scale_dict", None)
act_scale_dict = config.get("act_scale_dict", None)
is_permuted = config.get("is_permuted", True)
return cls(weight_scale_dict, act_scale_dict, is_permuted)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
if isinstance(layer, FusedMoE):
from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
CutlassW4AFP8MoEMethod,
)
return CutlassW4AFP8MoEMethod(self)
return W4AFP8LinearMethod(self)

View File

@@ -33,6 +33,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
min_p_sampling,
top_k_top_p_sampling,
)
from fastdeploy.model_executor.ops.gpu import limit_content_len
from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
@@ -304,6 +305,7 @@ class SpeculativeSampler(nn.Layer):
self.speculative_verify_window = fd_config.speculative_config.verify_window
self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
self.fd_config = fd_config
def pre_process(self, skip_idx_list: List[int] = []):
"""pre process before running"""
@@ -382,6 +384,22 @@ class SpeculativeSampler(nn.Layer):
self.speculative_benchmark_mode,
)
if hasattr(self.fd_config.model_config, "think_end_id") and self.fd_config.model_config.think_end_id > 0:
limit_content_len(
share_inputs["accept_tokens"],
self.fd_config.model_config.think_end_id,
share_inputs["max_content_len"],
share_inputs["max_think_len"],
share_inputs["step_idx"],
sampling_metadata.eos_token_ids,
share_inputs["max_dec_len"],
share_inputs["limit_content_status"],
share_inputs["enable_thinking"],
share_inputs["accept_num"],
share_inputs["seq_lens_decoder"],
share_inputs["stop_flags"],
)
return None
@@ -429,8 +447,8 @@ class MTPSampler(nn.Layer):
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
max_model_len,
)
probs = F.softmax(logits)

View File

@@ -112,7 +112,11 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
num_local_ffn_keys.append(down_proj_in_scale_key)
# for EP w4a8, we need all experts' activation_scale for up_gate_proj
for j in range(fd_config.model_config.moe_num_experts):
num_experts = fd_config.model_config.moe_num_experts
if isinstance(num_experts, list):
num_experts = num_experts[0]
for j in range(num_experts):
up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
num_local_ffn_keys.append(up_gate_proj_in_scale_key)

View File

@@ -46,6 +46,7 @@ from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
from fastdeploy.model_executor.models.utils import WeightMeta
from fastdeploy.worker.experts_manager import RedundantExpertManger
class Ernie4_5_MLP(nn.Layer):
@@ -94,13 +95,15 @@ class Ernie4_5_MLP(nn.Layer):
class Ernie4_5_MoE(nn.Layer):
def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None:
def __init__(
self, fd_config: FDConfig, layer_id: int, prefix: str, redundant_table_manger: RedundantExpertManger = None
) -> None:
super().__init__()
moe_quant_type = ""
if hasattr(fd_config.quant_config, "moe_quant_type"):
moe_quant_type = fd_config.quant_config.moe_quant_type
if moe_quant_type == "w4a8":
if moe_quant_type == "w4a8" or moe_quant_type == "w4afp8":
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
@@ -154,6 +157,7 @@ class Ernie4_5_MoE(nn.Layer):
top_k=fd_config.model_config.moe_k,
layer_idx=layer_id,
weight_key_map=weight_key_map,
redundant_table_manger=redundant_table_manger,
)
self.num_shared_experts = fd_config.model_config.moe_num_shared_experts
@@ -170,6 +174,9 @@ class Ernie4_5_MoE(nn.Layer):
if self.num_shared_experts > 0:
self.shared_experts.load_state_dict(state_dict)
def update_state_dict(self, state_dict):
self.fused_moe.load_state_dict(state_dict, True)
def forward(self, hidden_states: paddle.Tensor):
out = self.fused_moe(hidden_states)
if self.num_shared_experts > 0:
@@ -226,6 +233,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
def __init__(
self,
fd_config: FDConfig,
redundant_table_manger: RedundantExpertManger = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -244,6 +252,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
self.mlp = Ernie4_5_MoE(
fd_config=fd_config,
layer_id=layer_id,
redundant_table_manger=redundant_table_manger,
prefix=f"{prefix}.mlp",
)
else:
@@ -273,6 +282,9 @@ class Ernie4_5_DecoderLayer(nn.Layer):
self.input_layernorm.load_state_dict(state_dict)
self.post_attention_layernorm.load_state_dict(state_dict)
def update_state_dict(self, state_dict):
self.mlp.update_state_dict(state_dict)
def forward(
self,
forward_meta: ForwardMeta,
@@ -313,6 +325,16 @@ class Ernie4_5_Model(nn.Layer):
self.num_layers = fd_config.model_config.num_hidden_layers
fd_config.model_config.pretrained_config.prefix_name = "ernie"
self.fd_config = fd_config
self.redundant_table_manger = None
if fd_config.model_config.enable_redundant_experts is True:
self.redundant_table_manger = RedundantExpertManger(
n_routed_experts=fd_config.model_config.moe_num_experts,
num_hidden_layers=fd_config.model_config.num_hidden_layers,
redundant_experts_num=fd_config.model_config.redundant_experts_num,
ep_size=fd_config.parallel_config.expert_parallel_size,
)
self.embed_tokens = VocabParallelEmbedding(
fd_config=fd_config,
@@ -326,6 +348,7 @@ class Ernie4_5_Model(nn.Layer):
[
Ernie4_5_DecoderLayer(
fd_config=fd_config,
redundant_table_manger=self.redundant_table_manger,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}",
)
for i in range(self.num_layers)
@@ -354,6 +377,22 @@ class Ernie4_5_Model(nn.Layer):
logger.info(f"Start load layer {i}")
self.layers[i].load_state_dict(state_dict)
def update_state_dict(self, state_dict):
"""
Update model parameters from a given state dictionary.
Args:
state_dict (dict[str, np.ndarray | paddle.Tensor]):
A dictionary containing model parameters, where keys are parameter names
and values are NumPy arrays or PaddlePaddle tensors.
"""
for i in range(
self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers,
):
logger.info(f"Start update layer {i}")
self.layers[i].update_state_dict(state_dict)
def forward(
self,
ids_remove_padding: paddle.Tensor,

View File

@@ -244,6 +244,7 @@ class Ernie4_5_MTPModel(nn.Layer):
self.num_layers = fd_config.model_config.num_hidden_layers
self.embed_tokens = fd_config.speculative_config.sharing_model.ernie.embed_tokens
self.norm = fd_config.speculative_config.sharing_model.ernie.norm
self.layers = nn.LayerList(
[
@@ -314,6 +315,8 @@ class Ernie4_5_MTPModel(nn.Layer):
hidden_states = hidden_states + residual
hidden_states = self.norm(hidden_states)
return hidden_states

View File

@@ -46,7 +46,6 @@ from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
extract_text_token_output,
text_image_gather_scatter,
text_image_index_out,
)
@@ -99,8 +98,8 @@ class Ernie4_5_VLMoE(nn.Layer):
assert text_moe_layer_start_index <= text_moe_layer_end_index
moe_quant_type = ""
if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
if hasattr(fd_config.quant_config, "moe_quant_type"):
moe_quant_type = fd_config.quant_config.moe_quant_type
if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index:
if moe_quant_type == "tensor_wise_fp8" or (
@@ -472,26 +471,6 @@ class Ernie4_5_VLModel(nn.Layer):
)
hidden_states = hidden_states + residual
# -----------------------
hidden_states = hidden_states.cast("float32")
score_text = hidden_states
if image_input is not None:
token_type_ids = token_type_ids.reshape([-1])
text_pos_shifted = token_type_ids[:token_num] == 0
score_text = hidden_states[text_pos_shifted.reshape([-1])]
max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time.squeeze(-1), k=1)
hidden_states = extract_text_token_output(
max_seq_len,
max_seq_len_index.cast("int32"),
image_token_num.cast("int32"),
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
score_text,
).cast(self._dtype)
# -----------------------
out = self.norm(hidden_states)
return out

View File

@@ -59,7 +59,7 @@ else:
speculate_set_value_by_flags_and_idx,
speculate_step_paddle,
speculate_step_system_cache,
speculate_update_v3,
speculate_update,
step_paddle,
step_system_cache,
update_inputs,
@@ -288,7 +288,7 @@ def post_process_normal(
def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False):
""""""
speculate_update_v3(
speculate_update(
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
model_output.not_need_stop,

View File

@@ -281,12 +281,13 @@ class TokenProcessor:
def _compute_speculative_status(self):
# TODO(liuzichang): Supplement more statistics
interval = 50
interval = 10
if self.speculative_stats_step % interval == 0:
accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens
spec_logger.info(
f"Speculate global accept ratio(Accept draft_tokens/Generated tokens): {accept_ratio}"
f" total step: {self.total_step}. total output token num: {self.number_of_output_tokens}"
f" avarage accept len: {self.number_of_output_tokens / self.total_step}"
)
if self.cfg.speculative_config.method in ["mtp"]:

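As a quick sanity check on the statistics logged above, with purely illustrative counts:

# Illustrative numbers only; in the real code these come from TokenProcessor's counters.
total_step = 100                  # decode steps executed
number_of_output_tokens = 180     # tokens actually emitted

accept_ratio = 1 - total_step * 1.0 / number_of_output_tokens    # ~0.444: share of tokens accepted from drafts
average_accept_len = number_of_output_tokens / total_step        # 1.8 tokens emitted per step
print(round(accept_ratio, 3), average_accept_len)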
View File

@@ -45,6 +45,10 @@ class Proposer(ABC):
self.max_model_len = self.parallel_config.max_model_len
self.speculative_method = self.speculative_config.method
self.max_draft_token_num = self.speculative_config.num_speculative_tokens
self.num_model_steps = self.speculative_config.num_model_steps
self.max_ngram_size = self.speculative_config.max_ngram_size
self.min_ngram_size = self.speculative_config.min_ngram_size
spec_logger.info(f"Speculate config: {self.speculative_config}")

View File

@@ -35,6 +35,7 @@ from fastdeploy.model_executor.ops.gpu import (
draft_model_update,
eagle_get_hidden_states,
eagle_get_self_hidden_states,
hybrid_mtp_ngram,
mtp_save_first_token,
mtp_step_paddle,
share_external_data,
@@ -57,6 +58,8 @@ class MTPProposer(Proposer):
self._update_cfg(main_model)
self._load_model()
self.main_model_inputs = main_model_inputs
self.mtp_strategy = self.speculative_config.mtp_strategy
self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
# [mixed, prefill, decoder]
self.role = "mixed"
@@ -266,12 +269,19 @@ class MTPProposer(Proposer):
)
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
self.model_inputs["rope_emb"] = get_rope(
rotary_dim=self.model_config.head_dim,
position_ids=tmp_position_ids,
base=self.model_config.rope_theta,
model_config=self.model_config,
)
if len(self.main_model_inputs["rope_emb"].shape) == 5:
self.model_inputs["rope_emb"] = get_rope(
rotary_dim=self.model_config.head_dim,
position_ids=tmp_position_ids,
base=self.model_config.rope_theta,
model_config=self.model_config,
)
else:
self.model_inputs["max_content_len"] = paddle.clone(self.main_model_inputs["max_content_len"])
self.model_inputs["max_think_len"] = paddle.clone(self.main_model_inputs["max_think_len"])
self.model_inputs["limit_content_status"] = paddle.clone(self.main_model_inputs["limit_content_status"])
self.model_inputs["enable_thinking"] = paddle.clone(self.main_model_inputs["enable_thinking"])
self.model_inputs["rope_emb"] = paddle.clone(self.main_model_inputs["rope_emb"])
# self.model_inputs["caches"] = self.cache_kvs
# Inherit generation hyperparameters from the main model for consistency
self.model_inputs["top_p"] = self.main_model_inputs["top_p"]
@@ -291,9 +301,12 @@ class MTPProposer(Proposer):
# Integrate the updated results in model forward
self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"]
self.model_inputs["substep"] = 0
self.max_num_seqs = self.main_model_inputs["draft_tokens"].shape[0]
# Input tokens
self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")
self.model_inputs["draft_tokens"] = paddle.full(
shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
)
self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"])
@@ -311,10 +324,11 @@ class MTPProposer(Proposer):
self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
if self.max_draft_token_num > 1:
if self.num_model_steps > 1:
self.last_seq_lens_this_time = paddle.full_like(
self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
)
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
def insert_prefill_inputs(self, req_dicts: List[Request]):
"""
@@ -339,6 +353,7 @@ class MTPProposer(Proposer):
request = req_dicts[i]
idx = request.idx
length = len(request.prompt_token_ids)
self.input_ids_len[idx] = length
if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
length = len(request.prompt_token_ids)
@@ -432,15 +447,17 @@ class MTPProposer(Proposer):
self.model_inputs["step_idx"],
self.model_inputs["not_need_stop"],
self.model_inputs["batch_drop"],
self.model_inputs["pre_ids"],
self.main_model_inputs["accept_tokens"],
self.main_model_inputs["accept_num"],
self.main_model_inputs["seq_lens_this_time"],
self.main_model_inputs["seq_lens_encoder"],
self.main_model_inputs["seq_lens_decoder"],
self.main_model_inputs["step_idx"],
self.main_model_inputs["stop_flags"],
self.main_model_inputs["is_block_step"],
self.main_model_inputs["draft_tokens"],
self.max_draft_token_num,
self.num_model_steps,
self.speculative_method in ["eagle", "mtp"],
self.role == "prefill",
)
@@ -454,7 +471,7 @@ class MTPProposer(Proposer):
self.main_model_inputs["accept_num"],
self.main_model_inputs["seq_lens_this_time"],
self.main_model_inputs["seq_lens_encoder"],
self.max_draft_token_num,
self.num_model_steps,
)
if isinstance(target_hidden_states, list):
target_hidden_states = target_hidden_states[0]
@@ -494,7 +511,7 @@ class MTPProposer(Proposer):
"""
Main process for MTP inference
"""
for substep in range(self.max_draft_token_num):
for substep in range(self.num_model_steps):
if self.model_inputs["not_need_stop"]:
self.model_inputs["substep"] = substep
# Remove padding
@@ -514,6 +531,7 @@ class MTPProposer(Proposer):
self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"],
)
# Initialize forward meta data
self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
self.model_inputs["cum_offsets"].copy_(cum_offsets, False)
@@ -540,7 +558,7 @@ class MTPProposer(Proposer):
eos_token_ids=self.model_inputs["eos_token_id"],
)
if self.max_draft_token_num > 1:
if self.num_model_steps > 1:
self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
model_output = self.model(
@@ -574,7 +592,7 @@ class MTPProposer(Proposer):
self._post_process(sampled_token_ids)
if substep != self.max_draft_token_num - 1:
if substep != self.num_model_steps - 1:
target_hidden_states = self._get_self_hidden_states(hidden_states)
def _get_self_hidden_states(self, hidden_states):
@@ -646,11 +664,37 @@ class MTPProposer(Proposer):
self.max_draft_token_num,
)
def _extend_draft_token_with_ngram_match(self):
# TODO(liuzichang): Port this kernel to a CUDA kernel to reduce latency
device = paddle.CUDAPinnedPlace()
draft_tokens = self.main_model_inputs["draft_tokens"].cpu()
seq_lens_this_time = self.main_model_inputs["seq_lens_this_time"].cpu()
seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
hybrid_mtp_ngram(
self.model_inputs["input_ids"]._copy_to(device, True),
self.input_ids_len,
self.model_inputs["pre_ids"]._copy_to(device, True),
self.model_inputs["step_idx"].cpu(),
self.main_model_inputs["actual_draft_token_num"].cpu(),
draft_tokens,
seq_lens_this_time,
seq_lens_decoder,
self.model_inputs["max_dec_len"].cpu(),
self.max_ngram_size,
self.min_ngram_size,
self.max_draft_token_num,
)
self.main_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
self.main_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
def _run_impl(self, full_hidden_states):
""""""
target_hidden_states = self._prepare_inputs(full_hidden_states)
self._propose(target_hidden_states=target_hidden_states)
self._update_status()
if self.hybrid_mode:
self._extend_draft_token_with_ngram_match()
def is_chunk_prefill_enabled(self):
""""""

View File

@@ -1210,21 +1210,20 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["image_features"],
self.forward_meta,
)
hidden_states = model_output
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta,
)
hidden_states = rebuild_padding(
model_output,
self.share_inputs["cum_offsets"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"],
(self.share_inputs["output_padding_offset"] if self.speculative_decoding else None),
self.parallel_config.max_model_len,
)
hidden_states = rebuild_padding(
model_output,
self.share_inputs["cum_offsets"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"],
(self.share_inputs["output_padding_offset"] if self.speculative_decoding else None),
self.parallel_config.max_model_len,
)
# 4. Compute logits, Sample
logits = self.model.compute_logits(hidden_states)

193
scripts/offline_w4a8.py Normal file
View File

@@ -0,0 +1,193 @@
import argparse
import json
import os
import re
import time
import paddle
from paddleformers.trainer import strtobool
from paddleformers.transformers.configuration_utils import PretrainedConfig
from paddleformers.transformers.model_utils import shard_checkpoint
from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from paddleformers.utils.log import logger
from safetensors.numpy import save_file as safe_save_file
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.load_weight_utils import (
get_all_safetensors,
safetensors_weights_iterator,
)
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
def parse_arguments():
"""
parse_arguments
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name_or_path",
default=None,
required=True,
help="The directory of model.",
)
parser.add_argument(
"--output_dir",
default="merged_output",
required=True,
help="The directory of merged model output.",
)
parser.add_argument(
"--safe_serialization",
type=strtobool,
default="True",
help="Whether merge the model into safetensors format.",
)
parser.add_argument(
"--moe_quant_type",
default="w4a8",
choices=["w4a8", "w4afp8"],
help="The moe quant type of the model.",
)
return parser.parse_args()
def reorder():
def fn(weight, moe_quant_type):
from paddle.nn.quant import weight_quantize
quant_weight, _ = weight_quantize(weight.cuda(), algo=moe_quant_type, arch=80)
return quant_weight.cpu()
return fn
def deal_in_scale():
def fn(in_scale):
processed_in_scale = 1 / in_scale
return processed_in_scale
return fn
def deal_weight_scale():
def fn(weight_scale, processed_in_scale, moe_quant_type):
if moe_quant_type == "w4a8":
processed_weight_scale = weight_scale / (127 * 112) / processed_in_scale
return processed_weight_scale
elif moe_quant_type == "w4afp8":
processed_weight_scale = weight_scale / (448 * 7 * 2 ** (-9)) / processed_in_scale
processed_weight_scale = w4afp8_gemm_scale_permute(processed_weight_scale.cuda())
return processed_weight_scale
return fn
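The scale arithmetic above boils down to a few scalar operations; a standalone sketch with made-up values (the extra w4afp8_gemm_scale_permute layout step for w4afp8 is omitted here):

# Illustrative scalar values; real checkpoints store per-channel tensors.
act_scale = 0.05            # raw activation_scale from the checkpoint
weight_scale = 3.2          # raw weight_scale from the checkpoint

processed_in_scale = 1 / act_scale                                   # deal_in_scale

# w4a8: int4 weights with int8 activations, 127 * 112 folds in the integer ranges.
w4a8_weight_scale = weight_scale / (127 * 112) / processed_in_scale

# w4afp8: fp8 activations, 448 * 7 * 2**-9 folds in the fp8 / int4 ranges.
w4afp8_weight_scale = weight_scale / (448 * 7 * 2 ** (-9)) / processed_in_scale

print(processed_in_scale, w4a8_weight_scale, w4afp8_weight_scale)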
# Temporary: only MoE expert quantization (w4a8 / w4afp8) is handled here
def deal_quant(state_dict, save_state_dict, moe_quant_type):
param_mapping = [
# pattern,fn
(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.activation_scale", deal_in_scale()),
(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.weight_scale", deal_weight_scale()),
(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.quant_weight", reorder()),
]
for pattern, fn in param_mapping:
for key in list(state_dict.keys()):
# print(f"deal {key}")
match = re.search(pattern, key)
if match:
# print(f"{key} is match")
weight_or_scale = state_dict.pop(key)
if "weight_scale" in key:
in_scale_key = key.replace("weight_scale", "activation_scale")
in_scale = save_state_dict[in_scale_key]
save_state_dict[key] = fn(weight_or_scale, in_scale, moe_quant_type)
elif "activation_scale" in key:
save_state_dict[key] = fn(weight_or_scale)
else:
save_state_dict[key] = fn(weight_or_scale, moe_quant_type)
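The regex patterns above key off the per-expert parameter names used elsewhere in this PR; a quick demonstration of what one of them captures (the sample key below is illustrative):

import re

pattern = r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.activation_scale"
sample_key = "ernie.layers.3.mlp.experts.12.up_gate_proj.activation_scale"  # illustrative key

match = re.search(pattern, sample_key)
print(match.groups())   # ('3', '12', 'up_gate_proj') -> layer id, expert id, projection name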
def save_safetensors(state_dict, args):
"""
save_safetensors
"""
logger.info("Move to numpy.")
for k in list(state_dict.keys()):
if isinstance(state_dict[k], paddle.Tensor):
state_dict[k] = state_dict.pop(k).cpu().numpy()
logger.info("Save safetensors files.")
shards, index = shard_checkpoint(
state_dict,
max_shard_size="5GB",
weights_name=SAFE_WEIGHTS_NAME,
shard_format="naive",
)
for shard_file, shard in shards.items():
save_file = os.path.join(args.output_dir, shard_file)
logger.info(f"Saving {save_file}")
safe_save_file(shard, save_file, metadata={"format": "np"})
save_index_file = os.path.join(args.output_dir, SAFE_WEIGHTS_INDEX_NAME)
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2) + "\n"
f.write(content)
def main():
"""
main
"""
args = parse_arguments()
pretrained_config, _ = PretrainedConfig.get_config_dict(args.model_name_or_path)
pretrained_config = PretrainedConfig.from_dict(pretrained_config)
vocab_file_names = [
"tokenizer.model",
"spm.model",
"ernie_token_100k.model",
]
for i in range(len(vocab_file_names)):
if os.path.exists(os.path.join(args.model_name_or_path, vocab_file_names[i])):
ErnieBotTokenizer.resource_files_names["vocab_file"] = vocab_file_names[i]
break
tokenizer = ErnieBotTokenizer.from_pretrained(args.model_name_or_path)
_, safetensor_files = get_all_safetensors(args.model_name_or_path)
weights_iterator = safetensors_weights_iterator(safetensor_files)
state_dict = {}
save_state_dict = {}
start = time.perf_counter()
for k, v in weights_iterator:
state_dict[k] = get_tensor(v).cpu()
end = time.perf_counter()
logger.info("Finish Quantize.")
logger.info(f"load and quantize took : {end - start:.6f} seconds")
deal_quant(state_dict, save_state_dict, args.moe_quant_type)
for key in list(state_dict.keys()):
save_state_dict[key] = state_dict.pop(key)
logger.info("Begin to save model")
os.makedirs(args.output_dir, exist_ok=True)
start = time.perf_counter()
if not args.safe_serialization:
paddle.save(
save_state_dict,
os.path.join(args.output_dir, "model_state.pdparams"),
)
else:
save_safetensors(save_state_dict, args)
pretrained_config.is_permuted = True
pretrained_config.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
end = time.perf_counter()
logger.info(f"save model took: {end - start:.6f} seconds")
logger.info("Finish.")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,36 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -ex
rm -rf log
rm -f core*
export devices=0
export CUDA_VISIBLE_DEVICES=${devices}
model_path=${1:-"/PATH/MODEL_PATH"}
output_path=${2:-"/PATH/OUTPUT_MODEL"}
moe_quant_type=${3:-"w4a8"}
for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do
unset ${name}
done
export PADDLE_TRAINER_ID=0
export PADDLE_TRAINERS_NUM=1
export TRAINER_INSTANCES_NUM=1
export TRAINER_INSTANCES=`hostname -i`
self_ip=`hostname -i`
python offline_w4a8.py \
--model_name_or_path ${model_path} \
--output_dir ${output_path} \
--safe_serialization "True" \
--moe_quant_type ${moe_quant_type}

View File

@@ -0,0 +1,75 @@
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import hybrid_mtp_ngram
class TestNgramMatchMixed(unittest.TestCase):
def setUp(self):
self.max_bsz = 2
self.max_draft_tokens = 5
self.max_len = 32
self.max_dec_len = 10
self.max_ngram_size = 5
self.min_ngram_size = 2
# Initialize the input tensors
self.input_ids = paddle.full(shape=[self.max_bsz, self.max_len], fill_value=-1, dtype="int64").cpu()
self.input_ids_len = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int64").cpu()
self.pre_ids = paddle.full(shape=[self.max_bsz, self.max_len], fill_value=-1, dtype="int64").cpu()
self.step_idx = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int64").cpu()
self.draft_token_num = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32").cpu()
self.draft_tokens = paddle.full(
shape=[self.max_bsz, self.max_draft_tokens + 1],
fill_value=-1,
dtype="int64",
).cpu()
self.seq_lens_this_time = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32").cpu()
self.seq_lens_decoder = paddle.full(shape=[self.max_bsz, 1], fill_value=-1, dtype="int32").cpu()
self.max_dec_len = paddle.full(
shape=[self.max_bsz, 1],
fill_value=self.max_dec_len,
dtype="int64",
).cpu()
# Set up the concrete test data
self.input_ids[:, :10] = np.arange(0, 10)
self.input_ids_len[:] = 10
pre_ids_np = np.array([10, 9, 8, 7, 6, 10, 9, 8, 7], dtype="int32")
self.pre_ids[:, : pre_ids_np.shape[0]] = pre_ids_np
self.step_idx[:] = 8
self.draft_token_num[:] = 5
self.draft_tokens[:, :2] = np.array([8, 7])
self.seq_lens_this_time[:] = 2
self.seq_lens_decoder[:] = 12
self.max_dec_len[:] = 512
# Expected results
self.ref_seq_lens_this_time = np.array([[6], [6]], dtype="int32")
self.ref_draft_tokens = np.array([[8, 7, 6, 10, 9, 8], [8, 7, 6, 10, 9, 8]], dtype="int64")
def test_ngram_match_mixed(self):
hybrid_mtp_ngram(
self.input_ids,
self.input_ids_len,
self.pre_ids,
self.step_idx,
self.draft_token_num,
self.draft_tokens,
self.seq_lens_this_time,
self.seq_lens_decoder,
self.max_dec_len,
self.max_ngram_size,
self.min_ngram_size,
self.max_draft_tokens,
)
np.testing.assert_allclose(self.seq_lens_this_time.numpy(), self.ref_seq_lens_this_time)
np.testing.assert_allclose(self.draft_tokens.numpy(), self.ref_draft_tokens)
if __name__ == "__main__":
unittest.main()
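To make the expected values above easier to follow, here is a minimal pure-Python sketch of n-gram suffix matching against the generation history. It is an assumption about what hybrid_mtp_ngram computes, not the kernel itself, but it reproduces the reference outputs of this test:

def ngram_extend(history, draft_tokens, seq_len_this_time, max_ngram, min_ngram, max_draft):
    """Extend draft_tokens with the continuation of a matched history suffix (sketch only)."""
    for n in range(max_ngram, min_ngram - 1, -1):
        if len(history) <= n:
            continue
        suffix = history[-n:]
        # look for an earlier occurrence of the suffix, most recent first
        for start in range(len(history) - n - 1, -1, -1):
            if history[start:start + n] == suffix:
                room = max_draft + 1 - seq_len_this_time
                cont = history[start + n:start + n + room]
                return draft_tokens[:seq_len_this_time] + cont, seq_len_this_time + len(cont)
    return draft_tokens, seq_len_this_time


# Values mirror the test above: the suffix [10, 9, 8, 7] recurs at the start of the
# history, so its continuation [6, 10, 9, 8] is appended to the draft tokens.
tokens, new_len = ngram_extend([10, 9, 8, 7, 6, 10, 9, 8, 7], [8, 7], 2, 5, 2, 5)
print(tokens, new_len)   # [8, 7, 6, 10, 9, 8] 6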

View File

@@ -0,0 +1,103 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm, w4afp8_gemm_weight_convert
def w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N):
all_tokens = int(tokens.sum())
out = paddle.zeros([all_tokens, N], dtype="bfloat16")
pre_fix_token = 0
for i in range(BATCH):
input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i]
out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
pre_fix_token += tokens[i]
return out
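The reference above dequantizes with (q - 7) * scale, so the unsigned 4-bit codes 0..14 map to a symmetric range around zero. A quick check of that mapping with an illustrative scale:

import numpy as np

scale = 0.1                          # illustrative per-channel dequant scale
codes = np.arange(15)                # valid codes after the +7 offset and clip to [0, 14]
dequant = (codes - 7.0) * scale
print(dequant.min(), dequant.max())  # -0.7 0.7 -> symmetric around zero, code 7 maps to 0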
def permute_scale(weight_scale):
weight_scale = weight_scale.reshape([BATCH, N])
temp = paddle.zeros([16])
for b in range(BATCH):
for n in range(0, N, 16):
temp[:] = weight_scale[b, n : n + 16]
for j in range(0, 16, 2):
weight_scale[b, n + j] = temp[j // 2]
weight_scale[b, n + j + 1] = temp[j // 2 + 8]
return weight_scale
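The loop above interleaves the two halves of each 16-wide scale group. A compact way to see the resulting order (pure NumPy, illustrative only):

import numpy as np

group = np.arange(16)           # original order of one 16-wide scale group
permuted = np.empty_like(group)
permuted[0::2] = group[:8]      # even slots take the first half
permuted[1::2] = group[8:]      # odd slots take the second half
print(permuted)                 # [ 0  8  1  9  2 10  3 11  4 12  5 13  6 14  7 15]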
paddle.seed(0)
tokens_per_group = 32
N = 8192
K = 3584
BATCH = 8
TokenPadding = 0
tokens = [tokens_per_group] * BATCH
tokens_perfix_sum = np.cumsum(tokens)
tokens = paddle.to_tensor(tokens, dtype="int64")
tokens_perfix_sum = paddle.to_tensor(tokens_perfix_sum, dtype="int64")
all_tokens = int(tokens.sum())
input_fp8 = paddle.randn([all_tokens, K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
input_bf16 = input_fp8.astype("bfloat16")
weight = paddle.randn([BATCH, N, K], dtype="bfloat16") / 10
weight_scale = 7 / weight.abs().max(axis=-1).reshape([BATCH, N, 1])
weight_quant = (weight * weight_scale).astype("int") + 7
weight_quant = paddle.clip(weight_quant, 0, 14)
weight_quant = weight_quant.astype("bfloat16")
weight_dequant_scale = 1 / weight_scale.astype("float32")
input_row_sum = input_bf16.sum(axis=1) * -7 / 512
max_tokens = int(tokens.max())
out_naive = w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N)
weight_dequant_scale = paddle.to_tensor(permute_scale(weight_dequant_scale) * 512)
weight_int4 = w4afp8_gemm_weight_convert(weight_quant.astype("uint8").cpu())
if TokenPadding == 0:
out_cuda = w4afp8_gemm(
input_fp8,
weight_int4.cuda(),
tokens_perfix_sum,
input_row_sum.astype("float32"),
weight_dequant_scale.astype("float32"),
int(TokenPadding),
max_tokens,
True,
)
else:
out_cuda = w4afp8_gemm(
input_fp8,
weight_int4.cuda(),
tokens,
input_row_sum.astype("float32"),
weight_dequant_scale.astype("float32"),
int(TokenPadding),
max_tokens,
True,
)
gap = (out_cuda - out_naive).abs()
assert float(gap.mean()) < 0.07