[Speculative Decoding][MTP]Support attn mask offset (#4641)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* [MTP]Merge support attn (#4591)

* support mask_offset in speculate decoding

* fix dummpy run output

* add unit test

* fix unit test import

* support attn_mask_offset in mtp mode

* add update_attn_mask op

* fix unit test && fix code-style
This commit is contained in:
freeliuzc
2025-11-03 10:08:01 +08:00
committed by GitHub
parent f44f4bafd1
commit 11398790d3
13 changed files with 638 additions and 111 deletions

View File

@@ -782,7 +782,8 @@ void SpeculateUpdate(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& is_block_step,
const paddle::Tensor& stop_nums);
const paddle::Tensor& stop_nums,
const paddle::Tensor& mask_rollback);
void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all,
const paddle::Tensor& accept_tokens,
@@ -1047,6 +1048,18 @@ void SpeculateGetTargetLogits(const paddle::Tensor& target_logits,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& accept_num);
std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
const paddle::Tensor& ids_remove_padding,
const paddle::Tensor& seq_lens_this_time, // only on cpu
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& attn_mask_offsets_full,
const paddle::Tensor& attn_mask_offsets_decoder,
const paddle::Tensor& is_block_step,
const paddle::Tensor& decode_states,
const paddle::Tensor& mask_rollback);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num",
&GetExpertTokenNum,
@@ -1632,4 +1645,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_get_target_logits",
&SpeculateGetTargetLogits,
"speculate_get_target_logits function");
m.def("update_attn_mask_offsets",
&UpdateAttnMaskOffsets,
"update attention mask");
}

View File

@@ -16,115 +16,116 @@
template <int THREADBLOCK_SIZE>
__global__ void speculate_update(int *seq_lens_encoder,
int *seq_lens_decoder,
bool *not_need_stop,
int64_t *draft_tokens,
int *actual_draft_token_nums,
const int64_t *accept_tokens,
const int *accept_num,
const bool *stop_flags,
const int *seq_lens_this_time,
const bool *is_block_step,
const int64_t *stop_nums,
const int real_bsz,
const int max_bsz,
const int max_draft_tokens) {
const int bid = threadIdx.x;
const int accept_num_now = accept_num[bid];
int stop_flag_now_int = 0;
if (!(is_block_step[bid] || bid >= real_bsz)) {
if (stop_flags[bid]) {
stop_flag_now_int = 1;
}
if (seq_lens_encoder[bid] == 0) {
seq_lens_decoder[bid] += accept_num_now;
}
if (seq_lens_this_time[bid] > 1 &&
seq_lens_encoder[bid] ==
0) { // 对于append模式需要根据接收与否确定是否要降低下次draft
// token的数量
auto current_actual_draft_token_num = actual_draft_token_nums[bid];
if (accept_num_now - 1 == current_actual_draft_token_num) {
if (current_actual_draft_token_num + 2 <=
max_draft_tokens - 1) {
actual_draft_token_nums[bid] =
current_actual_draft_token_num + 2;
} else if (current_actual_draft_token_num + 1 <=
max_draft_tokens - 1) {
actual_draft_token_nums[bid] =
current_actual_draft_token_num + 1;
} else {
actual_draft_token_nums[bid] = max_draft_tokens - 1;
}
} else {
actual_draft_token_nums[bid] =
actual_draft_token_nums[bid] - 1 >= 1
? actual_draft_token_nums[bid] - 1
: 1;
}
}
if (seq_lens_encoder[bid] != 0) {
seq_lens_decoder[bid] += seq_lens_encoder[bid];
seq_lens_encoder[bid] = 0;
}
draft_tokens[bid * max_draft_tokens] =
accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
} else if (bid >= real_bsz && bid < max_bsz) {
stop_flag_now_int = 1;
int *seq_lens_decoder,
bool *not_need_stop,
int64_t *draft_tokens,
int *actual_draft_token_nums,
const int64_t *accept_tokens,
const int *accept_num,
const bool *stop_flags,
const int *seq_lens_this_time,
const bool *is_block_step,
const int64_t *stop_nums,
int *mask_rollback,
const int real_bsz,
const int max_bsz,
const int max_draft_tokens) {
const int bid = threadIdx.x;
const int accept_num_now = accept_num[bid];
int stop_flag_now_int = 0;
if (!(is_block_step[bid] || bid >= real_bsz)) {
if (stop_flags[bid]) {
stop_flag_now_int = 1;
mask_rollback[bid] = 0;
} else if (seq_lens_encoder[bid] == 0) { // decoder
seq_lens_decoder[bid] += accept_num_now;
mask_rollback[bid] = seq_lens_this_time[bid] - accept_num_now;
} else { // encoder
mask_rollback[bid] = 0;
}
__syncthreads();
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
if (threadIdx.x == 0) {
// printf("stop_sum %d \n", stop_sum);
not_need_stop[0] = stop_sum < stop_nums[0];
if (seq_lens_this_time[bid] > 1 &&
seq_lens_encoder[bid] ==
0) { // 对于append模式需要根据接收与否确定是否要降低下次draft
// token的数量
auto current_actual_draft_token_num = actual_draft_token_nums[bid];
if (accept_num_now - 1 == current_actual_draft_token_num) {
if (current_actual_draft_token_num + 2 <= max_draft_tokens - 1) {
actual_draft_token_nums[bid] = current_actual_draft_token_num + 2;
} else if (current_actual_draft_token_num + 1 <= max_draft_tokens - 1) {
actual_draft_token_nums[bid] = current_actual_draft_token_num + 1;
} else {
actual_draft_token_nums[bid] = max_draft_tokens - 1;
}
} else {
actual_draft_token_nums[bid] = actual_draft_token_nums[bid] - 1 >= 1
? actual_draft_token_nums[bid] - 1
: 1;
}
}
if (seq_lens_encoder[bid] != 0) {
seq_lens_decoder[bid] += seq_lens_encoder[bid];
seq_lens_encoder[bid] = 0;
}
draft_tokens[bid * max_draft_tokens] =
accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
} else if (bid >= real_bsz && bid < max_bsz) {
stop_flag_now_int = 1;
}
__syncthreads();
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
if (threadIdx.x == 0) {
// printf("stop_sum %d \n", stop_sum);
not_need_stop[0] = stop_sum < stop_nums[0];
}
}
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &is_block_step,
const paddle::Tensor &stop_nums) {
const int real_bsz = seq_lens_this_time.shape()[0];
const int max_bsz = stop_flags.shape()[0];
auto max_draft_tokens = draft_tokens.shape()[1];
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &is_block_step,
const paddle::Tensor &stop_nums,
const paddle::Tensor &mask_rollback) {
const int real_bsz = seq_lens_this_time.shape()[0];
const int max_bsz = stop_flags.shape()[0];
auto max_draft_tokens = draft_tokens.shape()[1];
constexpr int BlockSize = 512;
constexpr int BlockSize = 512;
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
speculate_update<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int *>(actual_draft_token_nums.data<int>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
stop_flags.data<bool>(),
seq_lens_this_time.data<int>(),
is_block_step.data<bool>(),
stop_nums.data<int64_t>(),
real_bsz,
max_bsz,
max_draft_tokens);
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
speculate_update<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int *>(actual_draft_token_nums.data<int>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
stop_flags.data<bool>(),
seq_lens_this_time.data<int>(),
is_block_step.data<bool>(),
stop_nums.data<int64_t>(),
const_cast<int *>(mask_rollback.data<int>()),
real_bsz,
max_bsz,
max_draft_tokens);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(speculate_update)
@@ -138,15 +139,18 @@ PD_BUILD_STATIC_OP(speculate_update)
"stop_flags",
"seq_lens_this_time",
"is_block_step",
"stop_nums"})
"stop_nums",
"mask_rollback"})
.Outputs({"seq_lens_encoder_out",
"seq_lens_decoder_out",
"not_need_stop_out",
"draft_tokens_out",
"actual_draft_token_nums_out"})
"actual_draft_token_nums_out",
"mask_rollback_out"})
.SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"not_need_stop", "not_need_stop_out"},
{"draft_tokens", "draft_tokens_out"},
{"actual_draft_token_nums", "actual_draft_token_nums_out"}})
{"actual_draft_token_nums", "actual_draft_token_nums_out"},
{"mask_rollback", "mask_rollback_out"}})
.SetKernelFn(PD_KERNEL(SpeculateUpdate));

View File

@@ -0,0 +1,141 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
__global__ void update_attn_mask_offsets_kernel(
int* attn_mask_offsets,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int* cu_seqlens_q,
const int* attn_mask_offsets_full,
int* attn_mask_offsets_decoder,
const bool* is_block_step,
int* decode_states,
const int* mask_rollback,
const int real_bsz,
const int max_model_len,
const int decode_states_len) {
int tid = threadIdx.x;
for (int bid = tid; bid < real_bsz; bid += blockDim.x) {
int seq_len_this_time = seq_lens_this_time[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid];
int query_start_id = cu_seqlens_q[bid];
const int* attn_mask_offsets_full_now =
attn_mask_offsets_full + bid * max_model_len;
int* decode_states_now = decode_states + bid * decode_states_len;
if (!is_block_step[bid]) {
if (seq_len_encoder == 0 && seq_len_decoder == 0) {
// Status: stop
} else if (seq_len_encoder > 0) {
for (int i = 0; i < seq_len_this_time; i++) {
if (*decode_states_now == 2 && seq_len_decoder > 0) {
// Status: vision generate phase
attn_mask_offsets[(query_start_id + i) * 2 + 1] =
seq_len_decoder + seq_len_this_time;
} else {
// Status: prefill -- normal or chunk_prefill
attn_mask_offsets[(query_start_id + i) * 2 + 1] =
attn_mask_offsets_full_now[i] + 1;
}
}
} else if (seq_len_decoder > 0) {
// Status: decoder -- normal or chunk_prefill
// TODO: support speculative decoding.
attn_mask_offsets_decoder[bid] -= mask_rollback[bid];
for (int i = 0; i < seq_len_this_time; i++) {
attn_mask_offsets[(query_start_id + i) * 2 + 1] =
attn_mask_offsets_decoder[bid] + 1 + i;
}
attn_mask_offsets_decoder[bid] += seq_len_this_time;
// Speculative decoding in text_generation
if (seq_len_this_time > 1) {
for (int i = 0; i < decode_states_len; i++) {
if (i < seq_len_this_time) {
decode_states_now[i] = 0;
} else {
decode_states_now[i] = -1;
}
}
}
}
}
}
}
std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
const paddle::Tensor& ids_remove_padding,
const paddle::Tensor& seq_lens_this_time, // only on cpu
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& attn_mask_offsets_full,
const paddle::Tensor& attn_mask_offsets_decoder,
const paddle::Tensor& is_block_step,
const paddle::Tensor& decode_states,
const paddle::Tensor& mask_rollback) {
int max_model_len = attn_mask_offsets_full.shape()[1];
int real_bsz = seq_lens_this_time.shape()[0];
int batch_seq_lens = ids_remove_padding.shape()[0];
int decode_states_len = decode_states.shape()[1];
auto attn_mask_offsets = paddle::full({batch_seq_lens * 2},
0,
paddle::DataType::INT32,
ids_remove_padding.place());
// launch config
int blockSize = 512;
update_attn_mask_offsets_kernel<<<1,
blockSize,
0,
ids_remove_padding.stream()>>>(
attn_mask_offsets.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
attn_mask_offsets_full.data<int>(),
const_cast<int*>(attn_mask_offsets_decoder.data<int>()),
is_block_step.data<bool>(),
const_cast<int*>(decode_states.data<int>()),
mask_rollback.data<int>(),
real_bsz,
max_model_len,
decode_states_len);
return {attn_mask_offsets};
}
PD_BUILD_STATIC_OP(update_attn_mask_offsets)
.Inputs({"ids_remove_padding",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"cu_seqlens_q",
"attn_mask_offsets_full",
"attn_mask_offsets_decoder",
"is_block_step",
"decode_states",
"mask_rollback"})
.Outputs({"attn_mask_offsets", "decode_states_out"})
.SetInplaceMap({{"decode_states", "decode_states_out"}})
.SetKernelFn(PD_KERNEL(UpdateAttnMaskOffsets));