Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 16:48:03 +08:00
[Feature][MTP] Support a new speculative decoding method: hybrid MTP with n-gram matching (#3610)
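Hybrid MTP + n-gram drafting combines two sources of draft tokens: the MTP draft model proposes one token per model step, and an n-gram matcher then tries to extend that draft "for free" by searching the prompt and the generated history for an earlier occurrence of the current suffix and copying whatever followed it. The sketch below is a minimal, self-contained illustration of the n-gram half of the idea; ngram_extend and its arguments are invented for this note, while the production logic lives in find_candidate_pred_tokens_mixed in the new file at the bottom of this diff.

    // Minimal sketch (illustrative only): if the last n tokens of `context`
    // also occur earlier in it, speculate that the continuation repeats too.
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> ngram_extend(const std::vector<int64_t>& context,
                                      int max_ngram, int min_ngram,
                                      int max_draft) {
      const int64_t len = static_cast<int64_t>(context.size());
      // Prefer longer n-grams: a longer suffix match is stronger evidence.
      for (int n = max_ngram; n >= min_ngram; --n) {
        if (len < n + 1) continue;
        const int64_t* suffix = context.data() + (len - n);  // last n tokens
        for (int64_t i = 0; i < len - n; ++i) {              // earlier windows
          if (std::equal(suffix, suffix + n, context.data() + i)) {
            const int64_t start = i + n;   // first token after the match
            const int64_t stop = std::min<int64_t>(start + max_draft, len);
            return {context.begin() + start, context.begin() + stop};
          }
        }
      }
      return {};  // no match: keep only the MTP-proposed draft tokens
    }

For example, ngram_extend({1, 2, 3, 4, 2, 3}, /*max_ngram=*/3, /*min_ngram=*/1, /*max_draft=*/2) matches the bigram (2, 3) at position 1 and returns {4, 2}.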
@@ -26,6 +26,7 @@ __global__ void process_splitwise_prefill(
     int64_t* step_idx,
     bool* not_need_stop,
     bool* batch_drop,
+    int64_t* pre_ids,
     const int64_t* accept_tokens,
     const int* accept_num,
     const int* base_model_seq_lens_this_time,
@@ -36,11 +37,12 @@ __global__ void process_splitwise_prefill(
     const bool* base_model_is_block_step,
     int64_t* base_model_draft_tokens,
     const int bsz,
-    const int max_draft_token,
+    const int num_model_step,
     const int accept_tokens_len,
     const int draft_tokens_len,
     const int input_ids_len,
-    const int base_model_draft_tokens_len) {
+    const int base_model_draft_tokens_len,
+    const int pre_ids_len) {
   typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   int64_t not_stop_flag = 0;
@@ -93,6 +95,7 @@ __global__ void draft_model_preprocess_kernel(
     int64_t* step_idx,
     bool* not_need_stop,
     bool* batch_drop,
+    int64_t* pre_ids,
     const int64_t* accept_tokens,
     const int* accept_num,
     const int* base_model_seq_lens_this_time,
@@ -103,11 +106,12 @@ __global__ void draft_model_preprocess_kernel(
     const bool* base_model_is_block_step,
     int64_t* base_model_draft_tokens,
     const int bsz,
-    const int max_draft_token,
+    const int num_model_step,
     const int accept_tokens_len,
     const int draft_tokens_len,
     const int input_ids_len,
-    const int base_model_draft_tokens_len) {
+    const int base_model_draft_tokens_len,
+    const int pre_ids_len) {
   typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   int64_t not_stop_flag = 0;
@@ -124,6 +128,7 @@ __global__ void draft_model_preprocess_kernel(
         base_model_draft_tokens + tid * base_model_draft_tokens_len;
     auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
     const int32_t base_model_seq_len_this_time = base_model_seq_lens_this_time[tid];
+    auto* pre_ids_now = pre_ids + tid * pre_ids_len;
 #pragma unroll
     for (int i = 1; i < base_model_draft_tokens_len; i++) {
       base_model_draft_tokens_now[i] = -1;
@@ -137,14 +142,12 @@ __global__ void draft_model_preprocess_kernel(
     if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
       not_stop_flag = 1;
       // 1. first token
-      if (base_model_step_idx_now == 0) {
-        seq_lens_this_time[tid] = 0;
-        not_stop_flag = 0;
-      } else if (seq_lens_encoder[tid] > 0) {
+      if (seq_lens_encoder[tid] > 0) {
         // Can be extended to first few tokens
         int seq_len_encoder = seq_lens_encoder[tid];
         stop_flags[tid] = false;
         int64_t base_model_first_token = accept_tokens_now[0];
+        pre_ids_now[0] = base_model_first_token;
         int position = seq_len_encoder;
         if (TRCUNCATE_FIRST_TOKEN) {
           input_ids_now[position - 1] = base_model_first_token;
@@ -161,34 +164,17 @@ __global__ void draft_model_preprocess_kernel(
           step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time;
         } else {
           // 2: Last base model generated token and first MTP token
-          seq_lens_decoder[tid] -= (base_model_seq_len_this_time - 2);
-          step_idx[tid] -= (base_model_seq_len_this_time - 2);
+          seq_lens_decoder[tid] -= num_model_step - 1;
+          step_idx[tid] -= num_model_step - 1;
         }
         for (int i = 0; i < accept_num_now; i++) {
           draft_tokens_now[i] = accept_tokens_now[i];
+          const int pre_id_pos = base_model_step_idx[tid] - (accept_num_now - i);
+          const int64_t accept_token = accept_tokens_now[i];
+          pre_ids_now[pre_id_pos] = accept_token;
         }
         seq_lens_this_time[tid] = accept_num_now;
       }
-      // (liuzichang): Temperary Reserved for debug
-      // else if (accept_num_now <=
-      //     max_draft_token) /*Accept partial draft tokens*/ {
-      //   // Base Model reject stop
-      //   if (stop_flags[tid]) {
-      //     stop_flags[tid] = false;
-      //     seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
-      //     step_idx[tid] = base_model_step_idx[tid];
-      //   } else {
-      //     seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
-      //     step_idx[tid] -= max_draft_token - accept_num_now;
-      //   }
-      //   int64_t modified_token = accept_tokens_now[accept_num_now - 1];
-      //   draft_tokens_now[0] = modified_token;
-      //   seq_lens_this_time[tid] = 1;
-
-      //  } else /*Accept all draft tokens*/ {
-      //    draft_tokens_now[1] = accept_tokens_now[max_draft_token];
-      //    seq_lens_this_time[tid] = 2;
-      //  }
     } else {
       stop_flags[tid] = true;
       seq_lens_this_time[tid] = 0;
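The pre_ids bookkeeping added above records every accepted token at the absolute step it was produced, so the history buffer stays aligned with step_idx for the n-gram matcher introduced later in this patch. A small worked example (values invented):

    // With base_model_step_idx[tid] = 10 and accept_num_now = 3:
    //   i = 0 -> pre_id_pos = 10 - (3 - 0) = 7
    //   i = 1 -> pre_id_pos = 10 - (3 - 1) = 8
    //   i = 2 -> pre_id_pos = 10 - (3 - 2) = 9
    // pre_ids_now[7..9] receive the three accepted tokens, i.e. the slots for
    // the steps at which those tokens were generated.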
@@ -215,6 +201,7 @@ void DispatchRunner(
     int64_t* step_idx,
     bool* not_need_stop,
     bool* batch_drop,
+    int64_t* pre_ids,
     const int64_t* accept_tokens,
     const int* accept_num,
     const int* base_model_seq_lens_this_time,
@@ -225,11 +212,12 @@ void DispatchRunner(
     const bool* base_model_is_block_step,
     int64_t* base_model_draft_tokens,
     const int bsz,
-    const int max_draft_token,
+    const int num_model_step,
     const int accept_tokens_len,
     const int draft_tokens_len,
     const int input_ids_len,
     const int base_model_draft_tokens_len,
+    const int pre_ids_len,
     const bool splitwise_prefill) {
   constexpr int BlockSize = 512;
   if (splitwise_prefill) {
@@ -244,6 +232,7 @@ void DispatchRunner(
             step_idx,
             not_need_stop,
             batch_drop,
+            pre_ids,
             accept_tokens,
             accept_num,
             base_model_seq_lens_this_time,
@@ -254,11 +243,12 @@ void DispatchRunner(
             base_model_is_block_step,
             base_model_draft_tokens,
             bsz,
-            max_draft_token,
+            num_model_step,
             accept_tokens_len,
             draft_tokens_len,
             input_ids_len,
-            base_model_draft_tokens_len);
+            base_model_draft_tokens_len,
+            pre_ids_len);
   } else {
     draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
         <<<1, BlockSize, 0, stream>>>(
@@ -271,6 +261,7 @@ void DispatchRunner(
             step_idx,
             not_need_stop,
             batch_drop,
+            pre_ids,
             accept_tokens,
             accept_num,
             base_model_seq_lens_this_time,
@@ -281,11 +272,12 @@ void DispatchRunner(
             base_model_is_block_step,
             base_model_draft_tokens,
             bsz,
-            max_draft_token,
+            num_model_step,
             accept_tokens_len,
             draft_tokens_len,
             input_ids_len,
-            base_model_draft_tokens_len);
+            base_model_draft_tokens_len,
+            pre_ids_len);
   }
 }
 
@@ -300,6 +292,7 @@ void DispatchTokenMode(
     int64_t* step_idx,
     bool* not_need_stop,
     bool* batch_drop,
+    int64_t* pre_ids,
     const int64_t* accept_tokens,
     const int* accept_num,
     const int* base_model_seq_lens_this_time,
@@ -310,11 +303,12 @@ void DispatchTokenMode(
     const bool* base_model_is_block_step,
     int64_t* base_model_draft_tokens,
     const int bsz,
-    const int max_draft_token,
+    const int num_model_step,
     const int accept_tokens_len,
     const int draft_tokens_len,
     const int input_ids_len,
     const int base_model_draft_tokens_len,
+    const int pre_ids_len,
     const bool truncate_first_token,
     const bool splitwise_prefill) {
   if (truncate_first_token) {
@@ -329,6 +323,7 @@ void DispatchTokenMode(
         step_idx,
         not_need_stop,
         batch_drop,
+        pre_ids,
         accept_tokens,
         accept_num,
         base_model_seq_lens_this_time,
@@ -339,11 +334,12 @@ void DispatchTokenMode(
         base_model_is_block_step,
         base_model_draft_tokens,
         bsz,
-        max_draft_token,
+        num_model_step,
         accept_tokens_len,
         draft_tokens_len,
         input_ids_len,
         base_model_draft_tokens_len,
+        pre_ids_len,
         splitwise_prefill
     );
   } else {
@@ -358,6 +354,7 @@ void DispatchTokenMode(
         step_idx,
         not_need_stop,
         batch_drop,
+        pre_ids,
         accept_tokens,
         accept_num,
         base_model_seq_lens_this_time,
@@ -368,11 +365,12 @@ void DispatchTokenMode(
         base_model_is_block_step,
         base_model_draft_tokens,
         bsz,
-        max_draft_token,
+        num_model_step,
         accept_tokens_len,
         draft_tokens_len,
         input_ids_len,
         base_model_draft_tokens_len,
+        pre_ids_len,
         splitwise_prefill
     );
   }
@@ -390,6 +388,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                           const paddle::Tensor& step_idx,
                           const paddle::Tensor& not_need_stop,
                           const paddle::Tensor& batch_drop,
+                          const paddle::Tensor& pre_ids,
                           const paddle::Tensor& accept_tokens,
                           const paddle::Tensor& accept_num,
                           const paddle::Tensor& base_model_seq_lens_this_time,
@@ -399,13 +398,14 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                           const paddle::Tensor& base_model_stop_flags,
                           const paddle::Tensor& base_model_is_block_step,
                           const paddle::Tensor& base_model_draft_tokens,
-                          const int max_draft_token,
+                          const int num_model_step,
                           const bool truncate_first_token,
                           const bool splitwise_prefill) {
   int real_bsz = seq_lens_this_time.shape()[0];
   int accept_tokens_len = accept_tokens.shape()[1];
   int input_ids_len = input_ids.shape()[1];
   int draft_tokens_len = draft_tokens.shape()[1];
+  int pre_ids_len = pre_ids.shape()[1];
   auto cu_stream = seq_lens_this_time.stream();
   constexpr int BlockSize = 512;
   int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
@@ -423,6 +423,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
       const_cast<int64_t*>(step_idx.data<int64_t>()),
       const_cast<bool*>(not_need_stop_gpu.data<bool>()),
       const_cast<bool*>(batch_drop.data<bool>()),
+      const_cast<int64_t*>(pre_ids.data<int64_t>()),
       accept_tokens.data<int64_t>(),
       accept_num.data<int>(),
       base_model_seq_lens_this_time.data<int>(),
@@ -433,11 +434,12 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
       base_model_is_block_step.data<bool>(),
       const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
       real_bsz,
-      max_draft_token,
+      num_model_step,
       accept_tokens_len,
       draft_tokens_len,
       input_ids_len,
       base_model_draft_tokens_len,
+      pre_ids_len,
       truncate_first_token,
       splitwise_prefill);
 
@@ -458,6 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
             "step_idx",
             "not_need_stop",
             "batch_drop",
+            "pre_ids",
             "accept_tokens",
             "accept_num",
             "base_model_seq_lens_this_time",
@@ -475,8 +478,9 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
              "seq_lens_decoder_out",
              "step_idx_out",
              "not_need_stop_out",
-             "batch_drop_out"})
-    .Attrs({"max_draft_token: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
+             "batch_drop_out",
+             "pre_ids_out"})
+    .Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
     .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
                     {"input_ids", "input_ids_out"},
                     {"stop_flags", "stop_flags_out"},
@@ -485,5 +489,6 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
                     {"seq_lens_decoder", "seq_lens_decoder_out"},
                     {"step_idx", "step_idx_out"},
                     {"not_need_stop", "not_need_stop_out"},
-                    {"batch_drop", "batch_drop_out"}})
+                    {"batch_drop", "batch_drop_out"},
+                    {"pre_ids", "pre_ids_out"}})
     .SetKernelFn(PD_KERNEL(DraftModelPreprocess));
 
@@ -63,10 +63,9 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
       token_this_time = next_tokens_start[seq_len_this_time - 1];
       draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
       base_model_draft_tokens_now[substep + 1] = token_this_time;
+      for (int i = 0; i < seq_len_this_time; ++i) {
+        pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
+      }
       step_idx[tid] += seq_len_this_time;
-      pre_ids_now[step_idx[tid]] = token_this_time;
-
-
     } else {
       token_this_time = next_tokens_start[0];
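The rewritten update loop stores the whole freshly generated chunk into pre_ids, not just its last token. A worked example (values invented):

    // With step_idx[tid] = 5 and seq_len_this_time = 3, the loop writes
    //   next_tokens_start[0] -> pre_ids_now[6]
    //   next_tokens_start[1] -> pre_ids_now[7]
    //   next_tokens_start[2] -> pre_ids_now[8]
    // and step_idx[tid] then advances to 8. Recording only the final token,
    // as the old code did, left gaps in the history that the n-gram matcher
    // added by this patch would otherwise scan as stale data.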
@@ -49,9 +49,7 @@ __global__ void ComputeOrderKernel(
         for (int j = 0; j < cur_seq_lens_encoder; j++) {
           position_map[in_offset++] = out_offset++;
         }
-        // 2. base model encoder. Base step=0
-      } else if (cur_base_model_seq_lens_encoder != 0) {
-        // 3. New end
+        // 2. Base model stop at last verify-step.
       } else if (cur_base_model_seq_lens_this_time != 0 && cur_seq_lens_this_time == 0) {
 #ifdef DEBUG_EAGLE_KERNEL
         printf("batch %d: base=0. draft !=0 \n", i);
@@ -0,0 +1,214 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+#include <vector>
+#include <string>
+#include <algorithm>
+#include <chrono>
+#include <cstdlib>
+#include "paddle/extension.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+
+int sum_mixed(const int *value, int num) {
+  int sum_value = 0;
+  for (int i = 0; i <= num; i++) {
+    sum_value += value[i];
+  }
+  return sum_value;
+}
+
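Note the inclusive upper bound (i <= num) in sum_mixed: it totals seq_lens_this_time for batches 0 through num, which is exactly what the per-step token-budget check below relies on. For instance:

    // int lens[] = {4, 1, 2};
    // sum_mixed(lens, 1) == 5;  // batches 0..1 (bound is inclusive)
    // sum_mixed(lens, 2) == 7;  // batches 0..2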
+void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
+                                      const int64_t *input_ids_len,
+                                      const int64_t *pre_ids,
+                                      const int64_t *step_idx,
+                                      const int *draft_token_num,
+                                      int64_t *draft_tokens,
+                                      int32_t *seq_lens_this_time,
+                                      int32_t *seq_lens_decoder,
+                                      int64_t *max_dec_len,
+                                      int64_t input_ids_stride,
+                                      int64_t pre_ids_stride,
+                                      int64_t draft_tokens_stride,
+                                      int64_t max_batch_size,
+                                      int max_ngram_size = 3,
+                                      int min_ngram_size = 1,
+                                      const int max_draft_tokens = 10) {
+  int threshold = 1024;
+  // dynamic in future
+  char *env_var = getenv("SPEC_TOKENUM_THRESHOLD");
+  if (env_var) {
+    threshold = std::stoi(env_var);
+  }
+  int unprocessed_batch_size = 0;
+  for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
+    if (seq_lens_decoder[batch_idx] > 0) {
+      unprocessed_batch_size++;
+    }
+  }
+  for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
+    const int ori_seq_len_this_time = seq_lens_this_time[batch_idx];
+    int max_draft_tokens_query = std::min(static_cast<int64_t>(
+        max_draft_tokens - ori_seq_len_this_time + 1), max_dec_len[batch_idx] - step_idx[batch_idx] - 1);
+
+    if (ori_seq_len_this_time == 0 || max_draft_tokens_query <= 0) {
+      continue;
+    }
+
+    const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
+    int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
+    const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
+    const int64_t cur_step_idx = step_idx[batch_idx];
+    const int64_t cur_input_ids_len = input_ids_len[batch_idx];
+    unprocessed_batch_size--;
+
+    auto sum_token_num = sum_mixed(seq_lens_this_time, batch_idx);
+    int left_min_token_num = unprocessed_batch_size;
+
+    if (sum_token_num + max_draft_tokens_query + left_min_token_num > threshold) {
+      int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num;
+      max_draft_tokens_query = std::min(max_draft_tokens_query, tmp_max_draft_tokens);
+    }
+
+    if (sum_token_num + left_min_token_num >= threshold - 1) {
+      continue;
+    }
+    bool match_global = false;
+    // apply ngram_match in input_ids
+    for (int ngram_size = max_ngram_size; ngram_size >= min_ngram_size && !match_global; --ngram_size) {
+      // Extract the last n tokens as our search ngram
+      if (cur_step_idx < ngram_size) {
+        continue;
+      }
+      const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
+
+      // Iterate through sliding windows of size ngram_size
+      // bool match_input = false;
+      for (int64_t i = 0; i <= cur_input_ids_len - ngram_size && !match_global; ++i) {
+        // Check if the current window matches the ngram
+        bool match_local = true;
+        for (int j = 0; j < ngram_size; j++) {
+          if (ngram[j] != cur_input_ids[i + j]) {
+            match_local = false;
+            break;
+          }
+        }
+        if (match_local) {
+          int64_t start_idx = i + ngram_size;
+          int64_t end_idx = std::min(start_idx + max_draft_tokens_query, cur_input_ids_len);
+          if (start_idx >= end_idx)
+            continue;
+
+          int64_t cur_draft_token_num = end_idx - start_idx;
+
+          seq_lens_this_time[batch_idx] = ori_seq_len_this_time + cur_draft_token_num;
+          memcpy(cur_draft_tokens + ori_seq_len_this_time, cur_input_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
+          // To break the current batch_idx for-loop
+          match_global = true;
+          break;
+        }
+      }
+      // apply ngram_match in generated tokens
+      if (!match_global) {
+        for (int64_t i = 0; i <= cur_step_idx - ngram_size && !match_global; ++i) {
+          // Check if the current window matches the ngram
+          bool match_local = true;
+
+          for (int j = 0; j < ngram_size; j++) {
+            if (ngram[j] != cur_pre_ids[i + j]) {
+              match_local = false;
+              break;
+            }
+          }
+          if (match_local) {
+            int64_t start_idx = i + ngram_size;
+            int64_t end_idx = std::min(start_idx + max_draft_tokens_query, cur_step_idx);
+
+            int64_t cur_draft_token_num = end_idx - start_idx;
+
+            if (start_idx >= end_idx)
+              continue;
+            // printf("match in Output with Ngram_size %d. %lld:[%lld,%lld]\n",ngram_size, cur_draft_token_num, start_idx, end_idx);
+
+            seq_lens_this_time[batch_idx] = ori_seq_len_this_time + cur_draft_token_num;
+            memcpy(cur_draft_tokens + ori_seq_len_this_time, cur_pre_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
+            match_global = true;
+            break;
+          }
+        }
+      }
+    }
+  }
+}
+
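A minimal host-side driver for the matcher above, useful for checking the indexing by hand. All shapes and token values here are invented for the example, and it assumes the function is compiled into a standalone test binary (with paddle/extension.h available, or the op wrapper below stubbed out):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // One request: prompt = [1 2 3 4 2 3 9]; generated history ends in "2 3".
      const int64_t input_ids[16]      = {1, 2, 3, 4, 2, 3, 9};
      const int64_t input_ids_len[1]   = {7};
      const int64_t pre_ids[16]        = {5, 2, 3};  // step_idx points at the last "3"
      const int64_t step_idx[1]        = {2};
      const int     draft_token_num[1] = {4};
      int64_t draft_tokens[16]         = {7, 8};     // [last accepted token, MTP draft]
      int32_t seq_lens_this_time[1]    = {2};
      int32_t seq_lens_decoder[1]      = {32};
      int64_t max_dec_len[1]           = {64};

      find_candidate_pred_tokens_mixed(
          input_ids, input_ids_len, pre_ids, step_idx, draft_token_num,
          draft_tokens, seq_lens_this_time, seq_lens_decoder, max_dec_len,
          /*input_ids_stride=*/16, /*pre_ids_stride=*/16,
          /*draft_tokens_stride=*/16, /*max_batch_size=*/1);

      // The history suffix "2 3" first recurs at input_ids[1..2], so the four
      // tokens after it, {4, 2, 3, 9}, are appended to the draft:
      // draft_tokens becomes {7, 8, 4, 2, 3, 9}.
      printf("seq_lens_this_time = %d\n", seq_lens_this_time[0]);  // prints 6
      return 0;
    }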
+void HybridMtpNgram(const paddle::Tensor &input_ids,
+                    const paddle::Tensor &input_ids_len,
+                    const paddle::Tensor &pre_ids,
+                    const paddle::Tensor &step_idx,
+                    const paddle::Tensor &draft_token_num,
+                    const paddle::Tensor &draft_tokens,
+                    const paddle::Tensor &seq_lens_this_time,
+                    const paddle::Tensor &seq_lens_decoder,
+                    const paddle::Tensor &max_dec_len,
+                    const int max_ngram_size,
+                    const int min_ngram_size,
+                    const int max_draft_tokens) {
+
+  auto input_ids_shape = input_ids.shape();
+  const int64_t input_ids_stride = input_ids_shape[1];
+
+  auto pre_ids_shape = pre_ids.shape();
+  const int64_t pre_ids_stride = pre_ids_shape[1];
+
+  auto draft_tokens_shape = draft_tokens.shape();
+  const int64_t draft_tokens_stride = draft_tokens_shape[1];
+
+  const int64_t max_batch_size = seq_lens_this_time.shape()[0];
+
+  find_candidate_pred_tokens_mixed(input_ids.data<int64_t>(),
+                                   input_ids_len.data<int64_t>(),
+                                   pre_ids.data<int64_t>(),
+                                   step_idx.data<int64_t>(),
+                                   draft_token_num.data<int>(),
+                                   const_cast<int64_t *>(draft_tokens.data<int64_t>()),
+                                   const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
+                                   const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
+                                   const_cast<int64_t *>(max_dec_len.data<int64_t>()),
+                                   input_ids_stride,
+                                   pre_ids_stride,
+                                   draft_tokens_stride,
+                                   max_batch_size,
+                                   max_ngram_size,
+                                   min_ngram_size,
+                                   max_draft_tokens);
+}
+
+PD_BUILD_STATIC_OP(hybrid_mtp_ngram)
+    .Inputs({"input_ids",
+             "input_ids_len",
+             "pre_ids",
+             "step_idx",
+             "draft_token_num",
+             "draft_tokens",
+             "seq_lens_this_time",
+             "seq_lens_decoder",
+             "max_dec_len"})
+    .Attrs({"max_ngram_size: int", "min_ngram_size: int", "max_draft_tokens: int"})
+    .Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
+    .SetKernelFn(PD_KERNEL(HybridMtpNgram))
+    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, {"seq_lens_this_time", "seq_lens_this_time_out"}});
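A note on the registration above: hybrid_mtp_ngram declares draft_tokens and seq_lens_this_time as in-place outputs via SetInplaceMap, so the extended draft is visible to the scheduler without an extra tensor copy, and the three attributes map one-to-one onto the matcher's max_ngram_size, min_ngram_size, and max_draft_tokens parameters.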