Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-29 11:02:54 +08:00)

Commit: support logprob in mtp
@@ -332,7 +332,9 @@ paddle::Tensor RebuildPaddingFunc(
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_encoder,
     const paddle::optional<paddle::Tensor> &output_padding_offset,
-    int max_input_length);
+    const paddle::optional<paddle::Tensor> &first_token_out,
+    int max_input_length,
+    bool enable_logprob);

 void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
                        const paddle::Tensor &stop_flags,
@@ -891,6 +893,46 @@ void SaveOutMmsgStatic(const paddle::Tensor& x,
                        int64_t rank_id,
                        bool save_each_rank);

+std::vector<paddle::Tensor> SpeculateGetLogits(
+    const paddle::Tensor &logits,
+    const paddle::Tensor &first_token_logits,
+    const paddle::Tensor &cu_seqlens_q,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &seq_lens_encoder);
+
+void SpeculateInsertFirstToken(const paddle::Tensor &token_ids,
+                               const paddle::Tensor &accept_tokens,
+                               const paddle::Tensor &next_tokens,
+                               const paddle::Tensor &cu_seqlens_q,
+                               const paddle::Tensor &cu_batch_token_offset,
+                               const paddle::Tensor &seq_lens_this_time,
+                               const paddle::Tensor &seq_lens_encoder);
+
+void SpeculateGetTargetLogits(const paddle::Tensor &target_logits,
+                              const paddle::Tensor &logits,
+                              const paddle::Tensor &cu_batch_token_offset,
+                              const paddle::Tensor &ori_cu_batch_token_offset,
+                              const paddle::Tensor &seq_lens_this_time,
+                              const paddle::Tensor &seq_lens_encoder,
+                              const paddle::Tensor &accept_num);
+
+void SpeculateSaveOutMmsgTopK(const paddle::Tensor &sampled_token_ids,
+                              const paddle::Tensor &logprob_token_ids,
+                              const paddle::Tensor &logprob_scores,
+                              const paddle::Tensor &logprob_ranks,
+                              const paddle::Tensor &token_num_per_batch,
+                              const paddle::Tensor &cu_batch_token_offset,
+                              const paddle::Tensor &not_need_stop,
+                              int mtype,
+                              int64_t rank_id);
+
+void SpeculateGetOutMmsgTopK(const paddle::Tensor &output_tokens,
+                             const paddle::Tensor &output_scores,
+                             const paddle::Tensor &output_ranks,
+                             int real_k,
+                             int64_t rank_id,
+                             bool wait_flag);
+
 PYBIND11_MODULE(fastdeploy_ops, m) {

     m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -1277,4 +1319,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
     m.def("min_p_sampling", &MinPSamplingFromProbs, "min_p_sampling function");

     m.def("save_output", &SaveOutMmsgStatic, "save_output function");
+
+    m.def("speculate_get_logits", &SpeculateGetLogits, "speculate_get_logits function");
+
+    m.def("speculate_insert_first_token", &SpeculateInsertFirstToken, "speculate_insert_first_token function");
+
+    m.def("speculate_get_target_logits", &SpeculateGetTargetLogits, "speculate_get_target_logits function");
+
+    m.def("speculate_save_output_topk", &SpeculateSaveOutMmsgTopK, "speculate_save_output_topk function");
+
+    m.def("speculate_get_output_topk", &SpeculateGetOutMmsgTopK, "speculate_get_output_topk function");
 }
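Once the module above is built, the five new bindings surface as ordinary Python callables; the Python hunks later in this diff import them under exactly these names. A minimal import sketch (it assumes a compiled FastDeploy with custom GPU ops):

    from fastdeploy.model_executor.ops.gpu import (
        speculate_get_logits,
        speculate_get_output_topk,
        speculate_get_target_logits,
        speculate_insert_first_token,
        speculate_save_output_topk,
    )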
@@ -46,6 +46,7 @@ __global__ void RebuildPaddingKernel(T *output_data,

 template <typename T, int VecSize>
 __global__ void RebuildAppendPaddingKernel(T *output_data,
+                                           T *first_token_out,
                                            const T *input_data,
                                            const int *cu_seqlens_q,
                                            const int *seq_len_this_time,
@@ -55,7 +56,8 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
                                            const int max_input_length,
                                            const int dim_embed,
                                            const int64_t output_elem_nums,
-                                           const int bsz) {
+                                           const int bsz,
+                                           const bool enable_logprob) {
     AlignedVector<T, VecSize> src_vec;
     const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
     for (int64_t i = global_idx * VecSize; i < output_elem_nums;
@@ -77,6 +79,35 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
         Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
                          &src_vec);
         Store<T, VecSize>(src_vec, &output_data[i]);
+
+        // printf(
+        //     "[normal] out_token_id: %d, ori_token_id: %d, input_token_id: %d "
+        //     "bias_idx: %d, bid: %d, seq_id: %d\n",
+        //     out_token_id,
+        //     ori_token_id,
+        //     input_token_id,
+        //     bias_idx,
+        //     bi,
+        //     seq_id);
+
+        if (enable_logprob && seq_len_encoder[bi] > 0) {
+            int first_token_seq_id = seq_len_encoder[bi] - 2;
+            const int first_token_id =
+                ori_token_id - cum_offset_bi + first_token_seq_id;
+            // printf(
+            //     "[first token] out_token_id: %d, ori_token_id: %d, "
+            //     "first_token_id: %d, bias_idx: %d, bid: %d, "
+            //     "first_token_seq_id: %d\n",
+            //     out_token_id,
+            //     ori_token_id,
+            //     first_token_id,
+            //     bias_idx,
+            //     bi,
+            //     first_token_seq_id);
+            Load<T, VecSize>(&input_data[first_token_id * dim_embed + bias_idx],
+                             &src_vec);
+            Store<T, VecSize>(src_vec, &first_token_out[i]);
+        }
     }
 }
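A NumPy rendering of the extra gather added above, on invented toy sizes: ori_token_id - cum_offset_bi is the batch's first row in input_data, and for a batch that has just finished prefill the kernel additionally copies the row at prompt position seq_len_encoder[bi] - 2 into first_token_out. The sketch only mirrors the index arithmetic, not the vectorized load/store:

    import numpy as np

    dim_embed = 4
    seq_len_this_time = [3, 1]      # batch 0 is prefilling, batch 1 is decoding
    seq_len_encoder = [3, 0]
    input_data = np.arange(sum(seq_len_this_time) * dim_embed,
                           dtype=np.float32).reshape(-1, dim_embed)

    # per-batch base row, i.e. ori_token_id - cum_offset_bi in the kernel
    base = np.cumsum([0] + seq_len_this_time)
    first_token_rows = {
        bi: input_data[base[bi] + seq_len_encoder[bi] - 2]
        for bi in range(len(seq_len_encoder))
        if seq_len_encoder[bi] > 0
    }
    print(first_token_rows)  # {0: hidden state of prompt position 1}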
@@ -89,7 +120,9 @@ std::vector<paddle::Tensor> rebuild_padding(
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_encoder,
     const paddle::optional<paddle::Tensor> &output_padding_offset,
-    int max_input_length) {
+    const paddle::optional<paddle::Tensor> &first_token_out,
+    int max_input_length,
+    bool enable_logprob) {
     typedef PDTraits<D> traits_;
     typedef typename traits_::DataType DataType_;
     typedef typename traits_::data_t data_t;
@@ -120,6 +153,9 @@ std::vector<paddle::Tensor> rebuild_padding(
                           0,
                           D,
                           tmp_out.place());
+        // printf("token_num: %d, need_delete_token_num: %d\n",
+        //        token_num,
+        //        need_delete_token_num);
     } else {
         out =
             paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place());
@@ -130,11 +166,20 @@ std::vector<paddle::Tensor> rebuild_padding(
     int pack_num = elem_nums / PackSize;
     const int blocksize = 128;
     const int grid_size = (pack_num + blocksize - 1) / blocksize;
+    printf("elem_nums: %d\n", elem_nums);

     if (output_padding_offset) {
+        // if (first_token_out.is_initialized()) {
+        //     printf("first_token_out is initialized, enable_logprob: %d\n",
+        //            enable_logprob);
+        // }
         RebuildAppendPaddingKernel<DataType_, PackSize>
             <<<grid_size, blocksize, 0, cu_stream>>>(
                 reinterpret_cast<DataType_ *>(out.data<data_t>()),
+                first_token_out.is_initialized()
+                    ? reinterpret_cast<DataType_ *>(const_cast<data_t *>(
+                          first_token_out.get_ptr()->data<data_t>()))
+                    : nullptr,
                 reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
                 cu_seqlens_q.data<int>(),
                 seq_len_this_time.data<int>(),
@@ -144,7 +189,8 @@ std::vector<paddle::Tensor> rebuild_padding(
                 max_input_length,
                 dim_embed,
                 elem_nums,
-                bsz);
+                bsz,
+                enable_logprob);
     } else {
         RebuildPaddingKernel<DataType_, PackSize>
             <<<grid_size, blocksize, 0, cu_stream>>>(
@@ -169,7 +215,9 @@ paddle::Tensor RebuildPaddingFunc(
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_encoder,
     const paddle::optional<paddle::Tensor> &output_padding_offset,
-    int max_input_length) {
+    const paddle::optional<paddle::Tensor> &first_token_out,
+    int max_input_length,
+    bool enable_logprob) {
     switch (tmp_out.type()) {
         case paddle::DataType::BFLOAT16: {
             return rebuild_padding<paddle::DataType::BFLOAT16>(
@@ -179,7 +227,9 @@ paddle::Tensor RebuildPaddingFunc(
                 seq_lens_decoder,
                 seq_lens_encoder,
                 output_padding_offset,
-                max_input_length)[0];
+                first_token_out,
+                max_input_length,
+                enable_logprob)[0];
         }
         case paddle::DataType::FLOAT16: {
             return rebuild_padding<paddle::DataType::FLOAT16>(
@@ -189,7 +239,9 @@ paddle::Tensor RebuildPaddingFunc(
                 seq_lens_decoder,
                 seq_lens_encoder,
                 output_padding_offset,
-                max_input_length)[0];
+                first_token_out,
+                max_input_length,
+                enable_logprob)[0];
         }
         case paddle::DataType::FLOAT32: {
             return rebuild_padding<paddle::DataType::FLOAT32>(
@@ -199,7 +251,9 @@ paddle::Tensor RebuildPaddingFunc(
                 seq_lens_decoder,
                 seq_lens_encoder,
                 output_padding_offset,
-                max_input_length)[0];
+                first_token_out,
+                max_input_length,
+                enable_logprob)[0];
         }
         default: {
             PD_THROW(
@@ -217,14 +271,18 @@ std::vector<paddle::Tensor> RebuildPadding(
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_encoder,
     const paddle::optional<paddle::Tensor> &output_padding_offset,
-    int max_input_length) {
+    const paddle::optional<paddle::Tensor> &first_token_out,
+    int max_input_length,
+    bool enable_logprob) {
     return {RebuildPaddingFunc(tmp_out,
                                cu_seqlens_q,
                                seq_len_this_time,
                                seq_lens_decoder,
                                seq_lens_encoder,
                                output_padding_offset,
-                               max_input_length)};
+                               first_token_out,
+                               max_input_length,
+                               enable_logprob)};
 }

 std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
@@ -260,9 +318,10 @@ PD_BUILD_STATIC_OP(rebuild_padding)
              "seq_len_this_time",
              "seq_lens_decoder",
              "seq_lens_encoder",
-             paddle::Optional("output_padding_offset")})
+             paddle::Optional("output_padding_offset"),
+             paddle::Optional("first_token_out")})
     .Outputs({"out"})
-    .Attrs({"max_input_length: int"})
+    .Attrs({"max_input_length: int", "enable_logprob: bool"})
     .SetKernelFn(PD_KERNEL(RebuildPadding))
     .SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype));
@@ -0,0 +1,141 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#define MAX_BSZ 512
#define K 20
#define MAX_DRAFT_TOKEN_NUM 6
#define SPECULATE_GET_WITH_OUTPUT_DEBUG

struct batch_msgdata {
    int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)];
    float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)];
    int ranks[MAX_DRAFT_TOKEN_NUM];
};

struct msgdata {
    int meta[3 + MAX_BSZ];  // stop_flag, mtype, bsz, batch_token_nums
    batch_msgdata mtext[MAX_BSZ];
};

void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
                             const paddle::Tensor& output_scores,
                             const paddle::Tensor& output_ranks,
                             int real_k,
                             int64_t rank_id,
                             bool wait_flag) {
    static struct msgdata msg_rcv;
    int msg_queue_id = 1;

    if (const char* inference_msg_queue_id_env_p =
            std::getenv("INFERENCE_MSG_QUEUE_ID")) {
        std::string inference_msg_queue_id_env_str(
            inference_msg_queue_id_env_p);
        int inference_msg_queue_id_from_env =
            std::stoi(inference_msg_queue_id_env_str);
#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG
        std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
                  << inference_msg_queue_id_from_env << std::endl;
#endif
        msg_queue_id = inference_msg_queue_id_from_env;
    }
    static key_t key = ftok("/dev/shm", msg_queue_id);

    static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG
    std::cout << "get_output_key: " << key << std::endl;
    std::cout << "get_output msgid: " << msgid << std::endl;
#endif

    int64_t* output_tokens_data =
        const_cast<int64_t*>(output_tokens.data<int64_t>());
    float* output_scores_data = const_cast<float*>(output_scores.data<float>());
    int64_t* output_ranks_data =
        const_cast<int64_t*>(output_ranks.data<int64_t>());
    int ret = -1;
    if (!wait_flag) {
        ret = msgrcv(
            msgid,
            &msg_rcv,
            (3 + MAX_BSZ) * sizeof(int) +
                MAX_BSZ * ((MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(int) +
                           (MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(float) +
                           MAX_DRAFT_TOKEN_NUM * sizeof(int)),
            0,
            IPC_NOWAIT);
    } else {
        ret = msgrcv(
            msgid,
            &msg_rcv,
            (3 + MAX_BSZ) * sizeof(int) +
                MAX_BSZ * ((MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(int) +
                           (MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(float) +
                           MAX_DRAFT_TOKEN_NUM * sizeof(int)),
            0,
            0);
    }
    if (ret == -1) {
        // Nothing was read.
        output_tokens_data[0] = -2;               // stop_flag
        output_tokens_data[1] = msg_rcv.meta[1];  // mtype, Target: 3, Draft: 4
        output_tokens_data[2] = 0;                // bsz
        return;
    }

    int bsz = msg_rcv.meta[2];  // batch size lives in meta[2]; meta[1] is mtype
    output_tokens_data[0] = msg_rcv.meta[0];
    output_tokens_data[1] = msg_rcv.meta[1];
    output_tokens_data[2] = msg_rcv.meta[2];

    int output_tokens_offset = 3 + MAX_BSZ;
    for (int i = 0; i < bsz; i++) {
        int cur_token_num = msg_rcv.meta[3 + i];
        output_tokens_data[3 + i] = cur_token_num;  // batch_token_nums

        auto* cur_output_token = output_tokens_data + output_tokens_offset +
                                 i * (MAX_DRAFT_TOKEN_NUM * (K + 1));
        auto* cur_output_score =
            output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (K + 1));
        auto* cur_batch_msg_rcv = &msg_rcv.mtext[i];
        for (int j = 0; j < cur_token_num; j++) {
            for (int k = 0; k < real_k + 1; k++) {
                // Read with the same j * (K + 1) stride the sender used.
                cur_output_token[j * (K + 1) + k] =
                    cur_batch_msg_rcv->tokens[j * (K + 1) + k];
                cur_output_score[j * (K + 1) + k] =
                    cur_batch_msg_rcv->scores[j * (K + 1) + k];
            }
            output_ranks_data[i * MAX_DRAFT_TOKEN_NUM + j] =
                cur_batch_msg_rcv->ranks[j];
        }
    }
}

PD_BUILD_STATIC_OP(speculate_get_output_topk)
    .Inputs({"output_tokens", "output_scores", "output_ranks"})
    .Attrs({"real_k: int", "rank_id: int64_t", "wait_flag: bool"})
    .Outputs({"output_tokens_out", "output_scores_out", "output_ranks_out"})
    .SetInplaceMap({{"output_tokens", "output_tokens_out"},
                    {"output_scores", "output_scores_out"},
                    {"output_ranks", "output_ranks_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateGetOutMmsgTopK));
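The op fills caller-provided tensors; their shapes follow directly from the constants above and appear verbatim in the TokenProcessor hunk later in this diff. A sketch of the receive-side allocation:

    import paddle

    MAX_BSZ, K, MAX_DRAFT_TOKENS = 512, 20, 6
    output_tokens = paddle.full(
        shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], fill_value=2, dtype="int64"
    )
    output_scores = paddle.full(
        shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], fill_value=0.0, dtype="float32"
    )
    output_ranks = paddle.full(shape=[MAX_BSZ * MAX_DRAFT_TOKENS], fill_value=0, dtype="int64")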
custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu (new file, 285 lines)
@@ -0,0 +1,285 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

template <int THREADBLOCK_SIZE>
__global__ void get_token_num_per_batch_kernel(int* batch_token_num,
                                               int* total_token_num,
                                               const int* seq_lens_this_time,
                                               const int* seq_lens_encoder,
                                               const int real_bsz) {
    int bid = threadIdx.x;
    typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    int token_num_now = 0;
    if (bid < real_bsz) {
        token_num_now = seq_lens_encoder[bid] > 0 ? 2 : seq_lens_this_time[bid];
        batch_token_num[bid] = token_num_now;
    }

    __syncthreads();
    int token_num_sum = BlockReduce(temp_storage).Sum(token_num_now);
    if (bid == 0) {
        total_token_num[0] = token_num_sum;
    }
}
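The counting rule above, in plain Python: a batch that just ran prefill (seq_lens_encoder > 0) contributes exactly two logprob rows — which matches how speculate_get_logits_kernel below fills draft_logits with the first-token row plus one sampled row — while a decode batch contributes one row per draft position:

    def batch_token_num(seq_lens_this_time, seq_lens_encoder):
        return [2 if enc > 0 else this
                for this, enc in zip(seq_lens_this_time, seq_lens_encoder)]

    print(batch_token_num([5, 3, 1], [5, 0, 0]))  # [2, 3, 1]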
template <int VecSize>
__global__ void speculate_get_logits_kernel(float* draft_logits,
                                            const float* logits,
                                            const float* first_token_logits,
                                            const int* cu_seqlens_q,
                                            const int* cu_batch_token_offset,
                                            const int* seq_lens_this_time,
                                            const int* seq_lens_encoder,
                                            const int vocab_size,
                                            const int real_bsz) {
    AlignedVector<float, VecSize> src_vec;
    const int bid = blockIdx.x;
    const int tid = threadIdx.x;
    if (bid < real_bsz) {
        auto* draft_logits_now =
            draft_logits + cu_batch_token_offset[bid] * vocab_size;
        auto* logits_now = logits + cu_seqlens_q[bid] * vocab_size;
        for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) {
            if (seq_lens_encoder[bid] > 0) {
                Load<float, VecSize>(&first_token_logits[bid * vocab_size + i],
                                     &src_vec);
                Store<float, VecSize>(src_vec, &draft_logits_now[i]);

                Load<float, VecSize>(&logits_now[i], &src_vec);
                Store<float, VecSize>(src_vec,
                                      &draft_logits_now[vocab_size + i]);
            } else {
                for (int j = 0; j < seq_lens_this_time[bid]; j++) {
                    Load<float, VecSize>(&logits_now[j * vocab_size + i],
                                         &src_vec);
                    Store<float, VecSize>(
                        src_vec, &draft_logits_now[j * vocab_size + i]);
                }
            }
        }
    }
}
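What the kernel assembles, sketched with NumPy on toy shapes (the sketch mirrors the kernel's row indexing only): a prefill batch contributes its first_token_logits row followed by the logits row at its first query position; a decode batch copies one row per draft token:

    import numpy as np

    vocab = 8
    seq_lens_this_time = [4, 2]   # batch 0 prefills 4 tokens, batch 1 decodes 2
    seq_lens_encoder = [4, 0]
    cu_seqlens_q = np.cumsum([0] + seq_lens_this_time)
    logits = np.random.rand(cu_seqlens_q[-1], vocab).astype(np.float32)
    first_token_logits = np.random.rand(len(seq_lens_this_time), vocab).astype(np.float32)

    rows = []
    for b, (this, enc) in enumerate(zip(seq_lens_this_time, seq_lens_encoder)):
        if enc > 0:
            rows += [first_token_logits[b], logits[cu_seqlens_q[b]]]
        else:
            rows += list(logits[cu_seqlens_q[b] : cu_seqlens_q[b] + this])
    draft_logits = np.stack(rows)  # (total_token_num, vocab), here (4, 8)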
std::vector<paddle::Tensor> SpeculateGetLogits(
    const paddle::Tensor& logits,
    const paddle::Tensor& first_token_logits,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_encoder) {
    auto cu_stream = seq_lens_this_time.stream();
    const int vocab_size = logits.shape()[1];
    const int real_bsz = seq_lens_this_time.shape()[0];

    auto total_token_num = paddle::full(
        {1}, 0, paddle::DataType::INT32, seq_lens_this_time.place());
    auto batch_token_num = paddle::full(
        {real_bsz}, 0, paddle::DataType::INT32, seq_lens_this_time.place());

    constexpr int THREADBLOCK_SIZE = 512;
    get_token_num_per_batch_kernel<THREADBLOCK_SIZE>
        <<<1, THREADBLOCK_SIZE, 0, cu_stream>>>(batch_token_num.data<int>(),
                                                total_token_num.data<int>(),
                                                seq_lens_this_time.data<int>(),
                                                seq_lens_encoder.data<int>(),
                                                real_bsz);

    auto total_token_num_cpu =
        total_token_num.copy_to(paddle::CPUPlace(), true);

    auto draft_logits =
        paddle::empty({total_token_num_cpu.data<int>()[0], vocab_size},
                      paddle::DataType::FLOAT32,
                      seq_lens_this_time.place());
    auto cu_batch_token_offset = paddle::full(
        {real_bsz + 1}, 0, paddle::DataType::INT32, seq_lens_this_time.place());

    // The first call only sizes the scratch buffer; the second performs the
    // scan. Summing into &offsets[1] leaves offsets[0] == 0, so the inclusive
    // sum doubles as an exclusive prefix over batch_token_num.
    void* temp_storage = nullptr;
    size_t temp_storage_bytes = 0;
    cub::DeviceScan::InclusiveSum(temp_storage,
                                  temp_storage_bytes,
                                  batch_token_num.data<int>(),
                                  &cu_batch_token_offset.data<int>()[1],
                                  real_bsz,
                                  cu_stream);
    cudaMalloc(&temp_storage, temp_storage_bytes);
    cub::DeviceScan::InclusiveSum(temp_storage,
                                  temp_storage_bytes,
                                  batch_token_num.data<int>(),
                                  &cu_batch_token_offset.data<int>()[1],
                                  real_bsz,
                                  cu_stream);

    constexpr int PackSize = VEC_16B / sizeof(float);
    dim3 grid_dim(real_bsz);
    dim3 block_dim(128);
    speculate_get_logits_kernel<PackSize>
        <<<grid_dim, block_dim, 0, cu_stream>>>(
            const_cast<float*>(draft_logits.data<float>()),
            logits.data<float>(),
            first_token_logits.data<float>(),
            cu_seqlens_q.data<int>(),
            cu_batch_token_offset.data<int>(),
            seq_lens_this_time.data<int>(),
            seq_lens_encoder.data<int>(),
            vocab_size,
            real_bsz);

    // Release the scan scratch space.
    cudaFree(temp_storage);

    return {draft_logits, batch_token_num, cu_batch_token_offset};
}
__global__ void speculate_insert_first_token_kernel(
    int64_t* token_ids,
    const int64_t* accept_tokens,
    const int64_t* next_tokens,
    const int* cu_seqlens_q,
    const int* cu_batch_token_offset,
    const int* seq_lens_this_time,
    const int* seq_lens_encoder,
    const int max_draft_tokens,
    const int real_bsz) {
    const int bid = threadIdx.x;

    auto* token_ids_now = token_ids + cu_batch_token_offset[bid];
    auto* accept_tokens_now = accept_tokens + bid * max_draft_tokens;
    auto* next_tokens_now = next_tokens + cu_seqlens_q[bid];
    if (seq_lens_encoder[bid] != 0) {
        token_ids_now[0] = accept_tokens_now[0];
        token_ids_now[1] = next_tokens_now[0];
    } else {
        for (int i = 0; i < seq_lens_this_time[bid]; i++) {
            token_ids_now[i] = next_tokens_now[i];
        }
    }
}

void SpeculateInsertFirstToken(const paddle::Tensor& token_ids,
                               const paddle::Tensor& accept_tokens,
                               const paddle::Tensor& next_tokens,
                               const paddle::Tensor& cu_seqlens_q,
                               const paddle::Tensor& cu_batch_token_offset,
                               const paddle::Tensor& seq_lens_this_time,
                               const paddle::Tensor& seq_lens_encoder) {
    auto cu_stream = seq_lens_this_time.stream();
    const int max_draft_tokens = accept_tokens.shape()[1];
    const int real_bsz = seq_lens_this_time.shape()[0];

    speculate_insert_first_token_kernel<<<1, real_bsz, 0, cu_stream>>>(
        const_cast<int64_t*>(token_ids.data<int64_t>()),
        accept_tokens.data<int64_t>(),
        next_tokens.data<int64_t>(),
        cu_seqlens_q.data<int>(),
        cu_batch_token_offset.data<int>(),
        seq_lens_this_time.data<int>(),
        seq_lens_encoder.data<int>(),
        max_draft_tokens,
        real_bsz);
}
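The same insertion rule in plain Python, with toy values: a prefill batch emits its accepted first token followed by the newly sampled token, while a decode batch just copies its sampled draft tokens:

    def insert_first_token(accept_tokens, next_tokens_per_batch, seq_lens_encoder):
        out = []
        for b, enc in enumerate(seq_lens_encoder):
            if enc != 0:
                out += [accept_tokens[b][0], next_tokens_per_batch[b][0]]
            else:
                out += list(next_tokens_per_batch[b])
        return out

    print(insert_first_token([[7, 0], [9, 8]], [[3], [5, 6]], [4, 0]))  # [7, 3, 5, 6]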
template <int VecSize>
__global__ void speculate_get_target_logits_kernel(
    float* target_logits,
    const float* logits,
    const int* cu_batch_token_offset,
    const int* ori_cu_batch_token_offset,
    const int* seq_lens_this_time,
    const int* seq_lens_encoder,
    const int* accept_num,
    const int vocab_size,
    const int real_bsz) {
    AlignedVector<float, VecSize> src_vec;
    const int bid = blockIdx.x;
    const int tid = threadIdx.x;
    if (bid < real_bsz) {
        auto* target_logits_now =
            target_logits + cu_batch_token_offset[bid] * vocab_size;
        auto* logits_now = logits + ori_cu_batch_token_offset[bid] * vocab_size;
        for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) {
            if (seq_lens_encoder[bid] > 0) {
                Load<float, VecSize>(&logits_now[i], &src_vec);
                Store<float, VecSize>(src_vec, &target_logits_now[i]);
            } else {
                for (int j = 0; j < accept_num[bid]; j++) {
                    Load<float, VecSize>(&logits_now[j * vocab_size + i],
                                         &src_vec);
                    Store<float, VecSize>(
                        src_vec, &target_logits_now[j * vocab_size + i]);
                }
            }
        }
    }
}

void SpeculateGetTargetLogits(const paddle::Tensor& target_logits,
                              const paddle::Tensor& logits,
                              const paddle::Tensor& cu_batch_token_offset,
                              const paddle::Tensor& ori_cu_batch_token_offset,
                              const paddle::Tensor& seq_lens_this_time,
                              const paddle::Tensor& seq_lens_encoder,
                              const paddle::Tensor& accept_num) {
    auto cu_stream = seq_lens_this_time.stream();
    const int vocab_size = logits.shape()[1];
    const int real_bsz = seq_lens_this_time.shape()[0];

    constexpr int PackSize = VEC_16B / sizeof(float);
    dim3 grid_dim(real_bsz);
    dim3 block_dim(128);
    speculate_get_target_logits_kernel<PackSize>
        <<<grid_dim, block_dim, 0, cu_stream>>>(
            const_cast<float*>(target_logits.data<float>()),
            logits.data<float>(),
            cu_batch_token_offset.data<int>(),
            ori_cu_batch_token_offset.data<int>(),
            seq_lens_this_time.data<int>(),
            seq_lens_encoder.data<int>(),
            accept_num.data<int>(),
            vocab_size,
            real_bsz);
}
PD_BUILD_STATIC_OP(speculate_get_logits)
    .Inputs({"logits",
             "first_token_logits",
             "cu_seqlens_q",
             "seq_lens_this_time",
             "seq_lens_encoder"})
    .Outputs({"draft_logits", "batch_token_num", "cu_batch_token_offset"})
    .SetKernelFn(PD_KERNEL(SpeculateGetLogits));

PD_BUILD_STATIC_OP(speculate_insert_first_token)
    .Inputs({"token_ids",
             "accept_tokens",
             "next_tokens",
             "cu_seqlens_q",
             "cu_batch_token_offset",
             "seq_lens_this_time",
             "seq_lens_encoder"})
    .Outputs({"token_ids_out"})
    .SetInplaceMap({{"token_ids", "token_ids_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateInsertFirstToken));

PD_BUILD_STATIC_OP(speculate_get_target_logits)
    .Inputs({"target_logits",
             "logits",
             "cu_batch_token_offset",
             "ori_cu_batch_token_offset",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "accept_num"})
    .Outputs({"target_logits_out"})
    .SetInplaceMap({{"target_logits", "target_logits_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateGetTargetLogits));
@@ -0,0 +1,209 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#define MAX_BSZ 512
#define K 20
#define MAX_DRAFT_TOKEN_NUM 6
#define SPECULATE_SAVE_WITH_OUTPUT_DEBUG

struct batch_msgdata {
    int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)];
    float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)];
    int ranks[MAX_DRAFT_TOKEN_NUM];
};

struct msgdata {
    int meta[3 + MAX_BSZ];  // stop_flag, mtype, bsz, batch_token_nums
    batch_msgdata mtext[MAX_BSZ];
};
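The byte count passed to msgsnd/msgrcv by both sides can be checked independently; computing it from the same constants in Python:

    MAX_BSZ, K, MAX_DRAFT_TOKEN_NUM = 512, 20, 6
    SIZEOF_INT = SIZEOF_FLOAT = 4
    per_batch = (MAX_DRAFT_TOKEN_NUM * (K + 1)) * SIZEOF_INT \
              + (MAX_DRAFT_TOKEN_NUM * (K + 1)) * SIZEOF_FLOAT \
              + MAX_DRAFT_TOKEN_NUM * SIZEOF_INT
    payload = (3 + MAX_BSZ) * SIZEOF_INT + MAX_BSZ * per_batch
    print(payload)  # 530444 bytes per message (excluding any mtype header)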
void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
                              const paddle::Tensor& logprob_token_ids,
                              const paddle::Tensor& logprob_scores,
                              const paddle::Tensor& logprob_ranks,
                              const paddle::Tensor& token_num_per_batch,
                              const paddle::Tensor& cu_batch_token_offset,
                              const paddle::Tensor& not_need_stop,
                              int mtype,  // Target: 3, Draft: 4
                              int64_t rank_id) {
    if (rank_id > 0) {
        return;
    }
    auto sampled_token_ids_cpu =
        sampled_token_ids.copy_to(paddle::CPUPlace(), false);
    auto logprob_token_ids_cpu =
        logprob_token_ids.copy_to(paddle::CPUPlace(), false);
    auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false);
    auto logprob_ranks_cpu = logprob_ranks.copy_to(paddle::CPUPlace(), false);
    auto token_num_per_batch_cpu =
        token_num_per_batch.copy_to(paddle::CPUPlace(), false);
    auto cu_batch_token_offset_cpu =
        cu_batch_token_offset.copy_to(paddle::CPUPlace(), false);
    int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data<int64_t>();
    int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
    float* logprob_scores_data = logprob_scores_cpu.data<float>();
    int64_t* logprob_ranks_data = logprob_ranks_cpu.data<int64_t>();
    int* token_num_per_batch_data = token_num_per_batch_cpu.data<int>();
    int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data<int>();

    static struct msgdata msg_sed;
    int msg_queue_id = 1;
    if (const char* inference_msg_queue_id_env_p =
            std::getenv("INFERENCE_MSG_QUEUE_ID")) {
        std::string inference_msg_queue_id_env_str(
            inference_msg_queue_id_env_p);
        int inference_msg_queue_id_from_env =
            std::stoi(inference_msg_queue_id_env_str);
        msg_queue_id = inference_msg_queue_id_from_env;
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
        std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
                  << inference_msg_queue_id_from_env << std::endl;
#endif
    } else {
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
        std::cout << "Failed to get INFERENCE_MSG_QUEUE_ID from env, using "
                     "the default."
                  << std::endl;
#endif
    }
    int inference_msg_id_from_env = 1;
    if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
        std::string inference_msg_id_env_str(inference_msg_id_env_p);
        inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
        if (inference_msg_id_from_env == 2) {
            // 2 and -2 are reserved for the no-output indication.
            throw std::runtime_error(
                " INFERENCE_MSG_ID cannot be 2, please use another number.");
        }
        if (inference_msg_id_from_env < 0) {
            throw std::runtime_error(
                " INFERENCE_MSG_ID cannot be negative, please use another "
                "number.");
        }
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
        std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
                  << std::endl;
#endif
    } else {
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
        std::cout
            << "Failed to get INFERENCE_MSG_ID from env, using (int)1 as "
               "the default."
            << std::endl;
#endif
    }
    static key_t key = ftok("/dev/shm", msg_queue_id);
    static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
    std::cout << "save_output_key: " << key << std::endl;
    std::cout << "save msgid: " << msgid << std::endl;
#endif

    msg_sed.meta[0] = not_need_stop.data<bool>()[0]
                          ? inference_msg_id_from_env
                          : -inference_msg_id_from_env;
    msg_sed.meta[1] = mtype;
    int bsz = token_num_per_batch.shape()[0];
    msg_sed.meta[2] = bsz;
    int max_num_logprobs = logprob_token_ids.shape()[1];
    for (int i = 0; i < bsz; i++) {
        int cur_token_num = token_num_per_batch_data[i];
        msg_sed.meta[3 + i] = cur_token_num;
        auto* cur_batch_msg_sed = &msg_sed.mtext[i];
        int token_offset = cu_batch_token_offset_data[i];
        for (int j = 0; j < cur_token_num; j++) {
            auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)];
            auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
            std::cout << "token_offset: " << token_offset << std::endl;
#endif
            for (int k = 0; k < K + 1; k++) {
                if (k == 0) {
                    cur_tokens[k] =
                        (int)sampled_token_ids_data[token_offset + j];
                    cur_scores[k] =
                        logprob_scores_data[(token_offset + j) * (K + 1) + k];
                } else if (k < max_num_logprobs) {
                    cur_tokens[k] = (int)
                        logprob_token_ids_data[(token_offset + j) * (K + 1) +
                                               k];
                    cur_scores[k] =
                        logprob_scores_data[(token_offset + j) * (K + 1) + k];
                } else {
                    cur_tokens[k] = -1;
                    cur_scores[k] = 0.0;
                }
            }
            cur_batch_msg_sed->ranks[j] =
                (int)logprob_ranks_data[token_offset + j];
        }
    }
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
    std::cout << "msg data: " << std::endl;
    std::cout << "stop_flag: " << msg_sed.meta[0]
              << ", mtype: " << msg_sed.meta[1] << ", bsz: " << msg_sed.meta[2]
              << std::endl;
    for (int i = 0; i < bsz; i++) {
        int cur_token_num = msg_sed.meta[3 + i];
        auto* cur_batch_msg_sed = &msg_sed.mtext[i];
        std::cout << "batch " << i << " token_num: " << cur_token_num
                  << std::endl;
        for (int j = 0; j < cur_token_num; j++) {
            auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)];
            auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
            std::cout << "tokens: ";
            for (int k = 0; k < K + 1; k++) {
                std::cout << cur_tokens[k] << " ";
            }
            std::cout << std::endl;
            std::cout << "scores: ";
            for (int k = 0; k < K + 1; k++) {
                std::cout << cur_scores[k] << " ";
            }
            std::cout << std::endl;
            std::cout << "ranks: " << cur_batch_msg_sed->ranks[j] << std::endl;
        }
    }
    std::cout << std::endl;
#endif
    if ((msgsnd(msgid,
                &msg_sed,
                (3 + MAX_BSZ) * sizeof(int) +
                    MAX_BSZ * ((MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(int) +
                               (MAX_DRAFT_TOKEN_NUM * (K + 1)) * sizeof(float) +
                               MAX_DRAFT_TOKEN_NUM * sizeof(int)),
                0)) == -1) {
        printf("full msg buffer\n");
    }
}

PD_BUILD_STATIC_OP(speculate_save_output_topk)
    .Inputs({
        "sampled_token_ids",
        "logprob_token_ids",
        "logprob_scores",
        "logprob_ranks",
        "token_num_per_batch",
        "cu_batch_token_offset",
        "not_need_stop",
    })
    .Attrs({"mtype: int", "rank_id: int64_t"})
    .SetKernelFn(PD_KERNEL(SpeculateSaveOutMmsgTopK));
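A pure-Python rendering of the packing rule in the k-loop above: slot 0 carries the sampled token id, slots 1..max_num_logprobs-1 copy the gathered top-k candidates, and unused slots are padded with (-1, 0.0):

    def pack_row(sampled_id, logprob_ids, logprob_scores, K=20):
        # logprob_ids / logprob_scores: one row of the gathered top-k tensors,
        # built by gather_logprobs with the sampled token in column 0
        max_num_logprobs = len(logprob_ids)
        tokens, scores = [], []
        for k in range(K + 1):
            if k == 0:
                tokens.append(int(sampled_id))
                scores.append(float(logprob_scores[0]))
            elif k < max_num_logprobs:
                tokens.append(int(logprob_ids[k]))
                scores.append(float(logprob_scores[k]))
            else:
                tokens.append(-1)
                scores.append(0.0)
        return tokens, scores

    print(pack_row(7, [7, 3, 9], [-0.1, -1.2, -2.3], K=5))
    # ([7, 3, 9, -1, -1, -1], [-0.1, -1.2, -2.3, 0.0, 0.0, 0.0])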
@@ -403,8 +403,8 @@ class EngineArgs:
         if self.dynamic_load_weight:
             self.enable_prefix_caching = False
         if self.enable_logprob:
-            if self.speculative_config is not None:
-                raise NotImplementedError("Logprob does not support speculation_config.")
+            # if self.speculative_config is not None:
+            #     raise NotImplementedError("Logprob does not support speculation_config.")
             if not current_platform.is_cuda():
                 raise NotImplementedError("Only CUDA platform supports logprob.")
             if self.splitwise_role != "mixed":
@@ -36,6 +36,10 @@ from fastdeploy.model_executor.layers.sample.ops import (
     min_p_sampling,
     top_k_top_p_sampling,
 )
+from fastdeploy.model_executor.ops.gpu import (
+    speculate_get_target_logits,
+    speculate_insert_first_token,
+)
 from fastdeploy.platforms import current_platform
 from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
@@ -286,8 +290,11 @@ class Sampler(nn.Layer):
         # Get the logprob of the prompt or sampled token.
         token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)

+        print(f"[Sampler] logprobs: {logprobs}")
+        print(f"[Sampler] token_logprobs: {token_logprobs}")
         # Compute the ranks of the actual token.
         token_ranks = (logprobs >= token_logprobs).sum(-1)
+        print(f"[Sampler] token_ranks: {token_ranks}")

         if num_logprobs >= 1:
             # Find the topK values.
@@ -356,6 +363,7 @@ class Sampler(nn.Layer):
             sampled_token_ids=next_tokens,
             logprobs_tensors=logprobs_tensors,
         )
+        print(f"[Sampler] sampler_output: {sampler_output}")

         return sampler_output
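The rank definition used above, isolated on a toy distribution: the rank of the sampled token is the number of vocabulary entries whose logprob is at least as large, so rank 1 means it was the argmax:

    import numpy as np

    logprobs = np.log(np.array([0.5, 0.3, 0.15, 0.05]))
    token_id = 1                                   # sampled token
    token_logprob = logprobs[token_id]
    rank = int((logprobs >= token_logprob).sum())  # 2
    topk = np.argsort(-logprobs)[:2]               # top-2 ids reported alongside
    print(rank, topk)                              # 2 [0 1]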
@@ -375,6 +383,7 @@ class SpeculativeSampler(nn.Layer):
         self.speculative_verify_window = fd_config.speculative_config.verify_window
         self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
         self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
+        self.speculative_tokens_num = fd_config.speculative_config.num_speculative_tokens

     def pre_process(self, skip_idx_list: List[int] = []):
         """pre process before running"""
@@ -389,6 +398,104 @@ class SpeculativeSampler(nn.Layer):
        """apply logits processor to sampler"""
        pass

    def compute_logprobs(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> paddle.Tensor:
        """compute logprobs"""
        share_inputs = sampling_metadata.share_inputs
        last_logits = logits
        real_bsz = share_inputs["seq_lens_this_time"].shape[0]
        print(f"[SpeculativeSampler][compute] seq_lens_this_time: {share_inputs['seq_lens_this_time']}")
        print(f"[SpeculativeSampler][compute] seq_lens_encoder: {share_inputs['seq_lens_encoder']}")
        batch_token_num = share_inputs["batch_token_num"]

        print(f"[SpeculativeSampler][compute] batch_token_num: {batch_token_num}")
        temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
        top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
        if temp_scaled_logprobs is not None:
            real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
            temperature = sampling_metadata.temperature[:real_bsz]
            real_bsz_temp_scaled = (
                real_bsz_temp_scaled.astype("int32").squeeze(1).repeat_interleave(batch_token_num).astype("bool")
            )
            temperature = temperature.squeeze(1).repeat_interleave(batch_token_num)
            temp_temperature = paddle.where(
                real_bsz_temp_scaled, temperature, paddle.ones_like(temperature)
            ).unsqueeze(1)
            last_logits = last_logits / temp_temperature

        last_logprobs = F.log_softmax(last_logits, axis=-1)
        top_p_logprob = None
        top_p_token_mask = None

        if top_p_normalized_logprobs is not None and share_inputs is not None:
            real_token_top_p = (
                sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1)
            )
            top_p_normalized_logprobs = (
                top_p_normalized_logprobs[:real_bsz]
                .astype("int32")
                .squeeze(1)
                .repeat_interleave(batch_token_num)
                .astype("bool")
                .unsqueeze(1)
            )
            top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0)
            if top_p_token_mask.any():
                probs = F.softmax(last_logits, axis=-1)
                probs = top_p_normalize_probs_paddle(probs, real_token_top_p)
                top_p_logprob = paddle.log(probs)
            if top_p_logprob is not None:
                last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs)
        return last_logprobs

    def gather_logprobs(
        self,
        logprobs: paddle.Tensor,
        num_logprobs: int,
        token_ids: paddle.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.
        Args:
            logprobs: (num tokens) x (vocab) tensor
            num_logprobs: minimum number of logprobs to
                          retain per token
            token_ids: prompt tokens (if prompt logprobs)
                       or sampled tokens (if sampled
                       logprobs); 1D token ID tensor
                       with (num tokens) elements
                       Must be int64.
        Returns:
            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
            Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == paddle.int64
        token_ids = token_ids.unsqueeze(1)
        logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
        # Get the logprob of the prompt or sampled token.
        token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)

        print(f"[SpeculativeSampler] logprobs: {logprobs}")
        print(f"[SpeculativeSampler] token_logprobs: {token_logprobs}")
        # Compute the ranks of the actual token.
        token_ranks = (logprobs >= token_logprobs).sum(-1)
        print(f"[SpeculativeSampler] token_ranks: {token_ranks}")

        if num_logprobs >= 1:
            # Find the topK values.
            topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
            indices = paddle.concat([token_ids, topk_indices], axis=1)
            top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
        else:
            indices = token_ids
            top_logprobs = token_logprobs

        return LogprobsTensors(indices, top_logprobs, token_ranks)

    def forward_cuda(
        self,
        logits: paddle.Tensor,
@@ -427,6 +534,9 @@ class SpeculativeSampler(nn.Layer):
             max_model_len,
         )

+        print(f"[SpeculativeSampler] verify_tokens: {verify_tokens}")
+        print(f"[SpeculativeSampler] actual_candidate_len: {actual_candidate_len}")
+
         speculate_verify(
             share_inputs["accept_tokens"],
             share_inputs["accept_num"],
@@ -452,8 +562,64 @@ class SpeculativeSampler(nn.Layer):
            True,  # enable_topp
            self.speculative_benchmark_mode,
        )
+       print(f"[SpeculativeSampler] accept_num: {share_inputs['accept_num']}")
+       print(f"[SpeculativeSampler] accept_tokens: {share_inputs['accept_tokens']}")

-       return None
        print(f"[SpeculativeSampler] logits: {logits}")
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            real_bsz = share_inputs["seq_lens_this_time"].shape[0]
            batch_token_num = paddle.where(
                share_inputs["seq_lens_encoder"][:real_bsz] != 0,
                paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
                share_inputs["accept_num"][:real_bsz].unsqueeze(1),
            ).squeeze(1)
            share_inputs["batch_token_num"] = batch_token_num
            print(f"[SpeculativeSampler] batch_token_num: {share_inputs['batch_token_num']}")
            ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
                "int32"
            )
            cu_batch_token_offset = paddle.concat(
                [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"])]
            ).astype("int32")
            print(f"[SpeculativeSampler] ori_cu_batch_token_offset: {ori_cu_batch_token_offset}")
            print(f"[SpeculativeSampler] cu_batch_token_offset: {cu_batch_token_offset}")
            target_logits = paddle.empty([share_inputs["accept_num"].sum(), logits.shape[1]], dtype=logits.dtype)
            speculate_get_target_logits(
                target_logits,
                logits,
                cu_batch_token_offset,
                ori_cu_batch_token_offset,
                share_inputs["seq_lens_this_time"],
                share_inputs["seq_lens_encoder"],
                share_inputs["accept_num"],
            )
            print(f"[SpeculativeSampler] target_logits: {target_logits}")
            raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata)
            print(f"[SpeculativeSampler] raw_logprobs: {raw_logprobs}")

        sampler_output = None
        if num_logprobs is not None:

            token_ids = paddle.concat(
                [
                    share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]]
                    for i in range(share_inputs["accept_num"].shape[0])
                ]
            )
            print(f"[SpeculativeSampler] token_ids: {token_ids}")
            logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)

            sampler_output = SamplerOutput(
                sampled_token_ids=token_ids,
                logprobs_tensors=logprobs_tensors,
                token_num_per_batch=batch_token_num,
            )

            print(f"[SpeculativeSampler] sampler_output: {sampler_output}")

        return sampler_output
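The ragged gather above, isolated with toy values: each row of accept_tokens is truncated to its accept_num before concatenation, so token_ids lines up with the cu_batch_token_offset prefix sums:

    import paddle

    accept_tokens = paddle.to_tensor([[11, 12, 0], [21, 0, 0]])
    accept_num = paddle.to_tensor([2, 1])
    token_ids = paddle.concat(
        [accept_tokens[i, : accept_num[i]] for i in range(accept_num.shape[0])]
    )
    print(token_ids.numpy())  # [11 12 21]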


class MTPSampler(nn.Layer):
@@ -466,6 +632,7 @@ class MTPSampler(nn.Layer):
             self.forward = self.forward_cuda
         else:
             raise NotImplementedError
+        self.speculative_tokens_num = fd_config.speculative_config.num_speculative_tokens

     def pre_process(self, skip_idx_list: List[int] = []):
         """pre process before running"""
@@ -480,6 +647,115 @@ class MTPSampler(nn.Layer):
        """apply logits processor to sampler"""
        pass

    def compute_logprobs(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> paddle.Tensor:
        """compute logprobs"""
        share_inputs = sampling_metadata.share_inputs
        real_bsz = share_inputs["seq_lens_this_time"].shape[0]
        last_logits = logits
        # print(f"[MTPSampler][compute] real_bsz: {real_bsz}")
        temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
        top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
        if temp_scaled_logprobs is not None:
            real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
            temperature = sampling_metadata.temperature[:real_bsz]
            real_bsz_temp_scaled = (
                real_bsz_temp_scaled.astype("int32")
                .squeeze(1)
                .repeat_interleave(share_inputs["batch_token_num"])
                .astype("bool")
            )
            temperature = temperature.squeeze(1).repeat_interleave(share_inputs["batch_token_num"])
            # print(f"[MTPSampler][compute] real_bsz_temp_scaled: {real_bsz_temp_scaled}")
            # print(f"[MTPSampler][compute] temperature: {temperature}")
            temp_temperature = paddle.where(
                real_bsz_temp_scaled, temperature, paddle.ones_like(temperature)
            ).unsqueeze(1)
            # print(f"[MTPSampler][compute] temp_temperature: {temp_temperature}")
            last_logits = last_logits / temp_temperature
            # print(f"[MTPSampler][compute] last_logits: {last_logits}")

        last_logprobs = F.log_softmax(last_logits, axis=-1)
        # print(f"[MTPSampler][compute] last_logits: {last_logits}")
        top_p_logprob = None
        top_p_token_mask = None

        if top_p_normalized_logprobs is not None and share_inputs is not None:
            real_token_top_p = (
                sampling_metadata.top_p[:real_bsz]
                .squeeze(1)
                .repeat_interleave(share_inputs["batch_token_num"])
                .unsqueeze(1)
            )
            # print(f"[MTPSampler][compute] real_token_top_p: {real_token_top_p}")
            top_p_normalized_logprobs = (
                top_p_normalized_logprobs[:real_bsz]
                .astype("int32")
                .squeeze(1)
                .repeat_interleave(share_inputs["batch_token_num"])
                .astype("bool")
                .unsqueeze(1)
            )
            # print(f"[MTPSampler][compute] top_p_normalized_logprobs: {top_p_normalized_logprobs}")
            top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0)
            # print(f"[MTPSampler][compute] top_p_token_mask: {top_p_token_mask}")

            if top_p_token_mask.any():
                probs = F.softmax(last_logits, axis=-1)
                # print(f"[MTPSampler][compute] probs: {probs}")
                probs = top_p_normalize_probs_paddle(probs, real_token_top_p)
                # print(f"[MTPSampler][compute] probs: {probs}")
                top_p_logprob = paddle.log(probs)
                # print(f"[MTPSampler][compute] top_p_logprob: {top_p_logprob}")
            if top_p_logprob is not None:
                last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs)
        return last_logprobs

    def gather_logprobs(
        self,
        logprobs: paddle.Tensor,
        num_logprobs: int,
        token_ids: paddle.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.
        Args:
            logprobs: (num tokens) x (vocab) tensor
            num_logprobs: minimum number of logprobs to
                          retain per token
            token_ids: prompt tokens (if prompt logprobs)
                       or sampled tokens (if sampled
                       logprobs); 1D token ID tensor
                       with (num tokens) elements
                       Must be int64.
        Returns:
            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
            Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == paddle.int64
        token_ids = token_ids.unsqueeze(1)
        logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
        # Get the logprob of the prompt or sampled token.
        token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)

        # Compute the ranks of the actual token.
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        if num_logprobs >= 1:
            # Find the topK values.
            topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
            indices = paddle.concat([token_ids, topk_indices], axis=1)
            top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
        else:
            indices = token_ids
            top_logprobs = token_logprobs

        return LogprobsTensors(indices, top_logprobs, token_ranks)

    def forward_cuda(
        self,
        logits: paddle.Tensor,
@@ -488,6 +764,11 @@ class MTPSampler(nn.Layer):
         share_inputs: List[paddle.Tensor],
     ) -> paddle.Tensor:
         """ """
+        num_logprobs = sampling_metadata.max_num_logprobs
+        if num_logprobs is not None and share_inputs["substep"] == 0:
+            raw_logprobs = self.compute_logprobs(share_inputs["draft_logits"], sampling_metadata)
+            print(f"[MTPSampler] raw_logprobs: {raw_logprobs}")
+
         logits = apply_speculative_penalty_multi_scores(
             sampling_metadata.pre_token_ids,
             logits,
@@ -509,4 +790,30 @@ class MTPSampler(nn.Layer):
        _, next_tokens = top_k_top_p_sampling(
            probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
        )
-       return next_tokens

        sampler_output = None
        if num_logprobs is not None and share_inputs["substep"] == 0:
            token_ids = paddle.empty(share_inputs["batch_token_num"].sum(), dtype="int64")
            speculate_insert_first_token(
                token_ids,
                share_inputs["accept_tokens"],
                next_tokens,
                share_inputs["cu_seqlens_q"],
                share_inputs["cu_batch_token_offset"],
                share_inputs["seq_lens_this_time"],
                share_inputs["seq_lens_encoder"],
            )
            print(f"[MTPSampler] token_ids: {token_ids}")
            print(f"[MTPSampler] total_token_num: {share_inputs['batch_token_num'].sum()}")

            logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)

            sampler_output = SamplerOutput(
                sampled_token_ids=token_ids,
                logprobs_tensors=logprobs_tensors,
                token_num_per_batch=share_inputs["batch_token_num"],
            )
            print(f"[MTPSampler] sampler_output: {sampler_output}")
        print(f"[MTPSampler] next_tokens: {next_tokens}")

+       return next_tokens, sampler_output
@@ -529,6 +529,8 @@ def rebuild_padding(
     seq_lens_encoder: paddle.Tensor,
     output_padding_offset: Optional[paddle.Tensor] = None,
     max_input_length: Optional[int] = None,
+    first_token_out: Optional[paddle.Tensor] = None,
+    enable_logprob: Optional[bool] = False,
 ):
     """
     Args:
@@ -544,7 +546,9 @@ def rebuild_padding(
             seq_lens_decoder,
             seq_lens_encoder,
             output_padding_offset,
+            first_token_out,
             max_input_length,
+            enable_logprob,
         )
     elif current_platform.is_dcu():
         from fastdeploy.model_executor.ops.gpu import rebuild_padding
@@ -60,6 +60,15 @@ class TokenProcessor:
         self.use_logprobs = self.cfg.model_config.enable_logprob

         if self.speculative_decoding:
+            if self.use_logprobs:
+                self.output_tokens = paddle.full(
+                    shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], fill_value=2, dtype="int64"
+                )
+                self.output_scores = paddle.full(
+                    shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], fill_value=0.0, dtype="float32"
+                )
+                self.output_ranks = paddle.full(shape=[MAX_BSZ * MAX_DRAFT_TOKENS], fill_value=0, dtype="int64")
+            else:
                 self.output_tokens = paddle.full(
                     shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2],
                     fill_value=2,
@@ -149,6 +158,7 @@ class TokenProcessor:
             get_output_ep,
             get_output_topk,
             speculate_get_output,
+            speculate_get_output_topk,
         )
         rank_id = self.cfg.parallel_config.local_data_parallel_id
@@ -156,6 +166,21 @@ class TokenProcessor:
         try:
             is_blocking = True
             if self.speculative_decoding:
+                if self.use_logprobs:
+                    speculate_get_output_topk(
+                        self.output_tokens,
+                        self.output_scores,
+                        self.output_ranks,
+                        K,
+                        rank_id,
+                        is_blocking,
+                    )
+                    print(f"[TokenProcessor] output_tokens: {self.output_tokens}")
+                    print(f"[TokenProcessor] output_scores: {self.output_scores}")
+                    print(f"[TokenProcessor] output_ranks: {self.output_ranks}")
+                    if self.output_tokens[0, 0] == -2:
+                        continue
+                else:
                     speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
                     if self.output_tokens[0] == -2:
                         continue
@@ -41,6 +41,8 @@ from fastdeploy.model_executor.ops.gpu import (
     mtp_save_first_token,
     mtp_step_paddle,
     share_external_data,
+    speculate_get_logits,
+    speculate_save_output_topk,
 )
 from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding

@@ -62,6 +64,7 @@ class MTPProposer(Proposer):
         self.target_model_inputs = target_model_inputs
         self.mtp_strategy = self.speculative_config.mtp_strategy
         self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
+        self.enable_logprob = self.model_config.enable_logprob

         # [mixed, prefill, decoder]
         self.role = "mixed"
@@ -354,6 +357,13 @@ class MTPProposer(Proposer):
             self.target_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
         )
         self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
+        self.model_inputs["temp_scaled_logprobs"] = self.target_model_inputs["temp_scaled_logprobs"]
+        self.model_inputs["top_p_normalized_logprobs"] = self.target_model_inputs["top_p_normalized_logprobs"]
+        self.model_inputs["accept_num"] = self.target_model_inputs["accept_num"]
+        self.model_inputs["accept_tokens"] = self.target_model_inputs["accept_tokens"]
+        self.model_inputs["draft_logits"] = self.target_model_inputs["draft_logits"]
+        max_num_seqs = self.model_inputs["seq_lens_encoder"].shape[0]
+        self.model_inputs["first_token_hidden_states"] = paddle.full([max_num_seqs, self.model_config.hidden_size], -1)

     def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
@@ -616,6 +626,7 @@ class MTPProposer(Proposer):
         """
         for substep in range(self.num_model_steps):
             if self.model_inputs["not_need_stop"]:
+                print(f"[MTPProposer] ******************** substep: {substep} ********************")
                 self.model_inputs["substep"] = substep
                 # Remove padding
                 (
@@ -657,6 +668,10 @@ class MTPProposer(Proposer):
             min_dec_lens=self.model_inputs["min_dec_len"],
             bad_words_token_ids=self.model_inputs["bad_tokens"],
             eos_token_ids=self.model_inputs["eos_token_id"],
+            max_num_logprobs=20 if self.enable_logprob else None,
+            temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"],
+            top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
+            share_inputs=self.model_inputs,
         )

         if self.num_model_steps > 1:
@@ -667,7 +682,19 @@ class MTPProposer(Proposer):
             previous_hidden_states=target_hidden_states,
             forward_meta=self.forward_meta,
         )
+        print(f"[MTPProposer] model_output: {model_output}")
+
+        if self.enable_logprob and substep == 0:
+            first_token_hidden_states = paddle.empty(
+                [self.max_num_seqs, self.model_config.hidden_size], dtype=model_output.dtype
+            )
+
+        print(f"[MTPProposer] cu_seqlens_q: {self.model_inputs['cu_seqlens_q']}")
+        print(f"[MTPProposer] seq_lens_this_time: {self.model_inputs['seq_lens_this_time']}")
+        print(f"[MTPProposer] seq_lens_encoder: {self.model_inputs['seq_lens_encoder']}")
+        print(f"[MTPProposer] seq_lens_decoder: {self.model_inputs['seq_lens_decoder']}")
+        print(f"[MTPProposer] output_cum_offsets: {self.model_inputs['output_cum_offsets']}")
+        print(f"[MTPProposer] output_padding_offset: {self.model_inputs['output_padding_offset']}")
         hidden_states = rebuild_padding(
             model_output,
             self.model_inputs["cu_seqlens_q"],
@@ -676,18 +703,54 @@ class MTPProposer(Proposer):
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["output_padding_offset"],
            self.parallel_config.max_model_len,
+           first_token_hidden_states if substep == 0 else None,
+           self.enable_logprob if substep == 0 else False,
        )
+       print(f"[MTPProposer] hidden_states: {hidden_states}")
+       print(f"[MTPProposer] first_token_hidden_states: {first_token_hidden_states}")

        # 4. Compute logits, Sample
        logits = self.model.compute_logits(hidden_states)
-       sampled_token_ids = self.sampler(
+       if self.enable_logprob and substep == 0:
+           first_token_logits = self.model.compute_logits(first_token_hidden_states)
+           print(f"[MTPProposer] logits: {logits}")
+           print(f"[MTPProposer] first_token_logits: {first_token_logits}")
+           print(f"[MTPProposer] output_padding_offset: {self.model_inputs['output_padding_offset']}")
+
+           draft_logits, batch_token_num, cu_batch_token_offset = speculate_get_logits(
+               logits,
+               first_token_logits,
+               self.model_inputs["cu_seqlens_q"],
+               self.model_inputs["seq_lens_this_time"],
+               self.model_inputs["seq_lens_encoder"],
+           )
+           self.model_inputs["draft_logits"] = draft_logits
+           self.model_inputs["batch_token_num"] = batch_token_num
+           self.model_inputs["cu_batch_token_offset"] = cu_batch_token_offset
+           print(f"[MTPProposer] draft_logits: {draft_logits}")
+           print(f"[MTPProposer] batch_token_num: {batch_token_num}")
+           print(f"[MTPProposer] cu_batch_token_offset: {cu_batch_token_offset}")
+
+       sampled_token_ids, sampler_output = self.sampler(
            logits,
            self.sampling_metadata,
            self.max_model_len,
            self.model_inputs,
        )
+
+       if substep == 0:
+           speculate_save_output_topk(
+               sampler_output.sampled_token_ids,
+               sampler_output.logprobs_tensors.logprob_token_ids,
+               sampler_output.logprobs_tensors.logprobs,
+               sampler_output.logprobs_tensors.selected_token_ranks,
+               batch_token_num,
+               cu_batch_token_offset,
+               self.model_inputs["not_need_stop"],
+               4,  # mtype
+               self.local_rank,
+           )

        if self.parallel_config.tensor_parallel_size > 1:
            paddle.distributed.broadcast(sampled_token_ids, 0)
@@ -106,6 +106,7 @@ class SamplerOutput:
     # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
     sampled_token_ids: paddle.Tensor
     logprobs_tensors: Optional[LogprobsTensors]
+    token_num_per_batch: Optional[paddle.Tensor]


 @dataclass