[Feature] support mtp logprob (#4464)

* support mtp logprob

* fix unittest
Author: GoldPancake
Date: 2025-10-20 15:18:12 +08:00 (committed by GitHub)
Parent: 1b9f351d21
Commit: 47595a2480
14 changed files with 1181 additions and 32 deletions

View File

@@ -348,7 +348,9 @@ paddle::Tensor RebuildPaddingFunc(
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length);
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob);
void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &stop_flags,
@@ -910,6 +912,32 @@ void SaveOutMmsgStatic(const paddle::Tensor& x,
int64_t rank_id,
bool save_each_rank);
void SpeculateGetLogits(const paddle::Tensor &draft_logits,
const paddle::Tensor &next_token_num,
const paddle::Tensor &batch_token_num,
const paddle::Tensor &cu_next_token_offset,
const paddle::Tensor &cu_batch_token_offset,
const paddle::Tensor &logits,
const paddle::Tensor &first_token_logits,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder);
void SpeculateInsertFirstToken(const paddle::Tensor &token_ids,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &next_tokens,
const paddle::Tensor &cu_next_token_offset,
const paddle::Tensor &cu_batch_token_offset,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder);
void SpeculateGetTargetLogits(const paddle::Tensor &target_logits,
const paddle::Tensor &logits,
const paddle::Tensor &cu_batch_token_offset,
const paddle::Tensor &ori_cu_batch_token_offset,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &accept_num);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -1291,4 +1319,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("min_p_sampling", &MinPSamplingFromProbs, "min_p_sampling function");
m.def("save_output", &SaveOutMmsgStatic, "save_output function");
m.def("speculate_get_logits", &SpeculateGetLogits, "speculate_get_logits function");
m.def("speculate_insert_first_token", &SpeculateInsertFirstToken, "speculate_insert_first_token function");
m.def("speculate_get_target_logits", &SpeculateGetTargetLogits, "speculate_get_target_logits function");
}
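
The three ops declared above and bound into fastdeploy_ops all address flattened per-token buffers through cumulative offsets. The NumPy sketch below is illustrative only (not the CUDA implementation; toy values assumed) and mirrors how get_token_num_per_batch_kernel and the cub prefix sums in the new .cu file derive those counts and offsets for a batch mixing one prefill request with two decode requests.

import numpy as np

seq_lens_this_time = np.array([9, 3, 2], dtype=np.int32)  # batch 0 is prefill, 1-2 are decode
seq_lens_encoder = np.array([9, 0, 0], dtype=np.int32)
is_prefill = seq_lens_encoder > 0

next_token_num = np.where(is_prefill, 1, seq_lens_this_time)   # rows read from `logits`
batch_token_num = np.where(is_prefill, 2, seq_lens_this_time)  # rows written to `draft_logits`

# cub::DeviceScan::InclusiveSum into offset[1:] amounts to an exclusive prefix sum with a leading 0
cu_next_token_offset = np.concatenate(([0], np.cumsum(next_token_num))).astype(np.int32)
cu_batch_token_offset = np.concatenate(([0], np.cumsum(batch_token_num))).astype(np.int32)
# batch i owns rows [cu_batch_token_offset[i], cu_batch_token_offset[i + 1]) of draft_logits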

View File

@@ -46,6 +46,7 @@ __global__ void RebuildPaddingKernel(T *output_data,
template <typename T, int VecSize>
__global__ void RebuildAppendPaddingKernel(T *output_data,
T *first_token_out,
const T *input_data,
const int *cu_seqlens_q,
const int *seq_len_this_time,
@@ -55,7 +56,8 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
const int max_input_length,
const int dim_embed,
const int64_t output_elem_nums,
const int bsz) {
const int bsz,
const bool enable_logprob) {
AlignedVector<T, VecSize> src_vec;
const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = global_idx * VecSize; i < output_elem_nums;
@@ -70,13 +72,20 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi];
const int input_token_id = ori_token_id - cum_offset_bi + seq_id;
const int bias_idx = i % dim_embed;
Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
&src_vec);
Store<T, VecSize>(src_vec, &output_data[i]);
if (enable_logprob && seq_len_encoder[bi] > 0) {
const int first_input_token_id = input_token_id - 1;
Load<T, VecSize>(&input_data[first_input_token_id * dim_embed + bias_idx],
&src_vec);
Store<T, VecSize>(src_vec, &first_token_out[bi * dim_embed + bias_idx]);
}
}
}
@@ -89,7 +98,9 @@ std::vector<paddle::Tensor> rebuild_padding(
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -134,6 +145,10 @@ std::vector<paddle::Tensor> rebuild_padding(
RebuildAppendPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()),
first_token_out.is_initialized()
? reinterpret_cast<DataType_ *>(const_cast<data_t *>(
first_token_out.get_ptr()->data<data_t>()))
: nullptr,
reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
cu_seqlens_q.data<int>(),
seq_len_this_time.data<int>(),
@@ -143,7 +158,8 @@ std::vector<paddle::Tensor> rebuild_padding(
max_input_length,
dim_embed,
elem_nums,
bsz);
bsz,
enable_logprob);
} else {
RebuildPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
@@ -168,7 +184,9 @@ paddle::Tensor RebuildPaddingFunc(
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
switch (tmp_out.type()) {
case paddle::DataType::BFLOAT16: {
return rebuild_padding<paddle::DataType::BFLOAT16>(
@@ -178,7 +196,9 @@ paddle::Tensor RebuildPaddingFunc(
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length)[0];
first_token_out,
max_input_length,
enable_logprob)[0];
}
case paddle::DataType::FLOAT16: {
return rebuild_padding<paddle::DataType::FLOAT16>(
@@ -188,7 +208,9 @@ paddle::Tensor RebuildPaddingFunc(
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length)[0];
first_token_out,
max_input_length,
enable_logprob)[0];
}
case paddle::DataType::FLOAT32: {
return rebuild_padding<paddle::DataType::FLOAT32>(
@@ -198,7 +220,9 @@ paddle::Tensor RebuildPaddingFunc(
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length)[0];
first_token_out,
max_input_length,
enable_logprob)[0];
}
default: {
PD_THROW(
@@ -216,14 +240,18 @@ std::vector<paddle::Tensor> RebuildPadding(
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
return {RebuildPaddingFunc(tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length)};
first_token_out,
max_input_length,
enable_logprob)};
}
std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
@@ -259,9 +287,10 @@ PD_BUILD_STATIC_OP(rebuild_padding)
"seq_len_this_time",
"seq_lens_decoder",
"seq_lens_encoder",
paddle::Optional("output_padding_offset")})
paddle::Optional("output_padding_offset"),
paddle::Optional("first_token_out")})
.Outputs({"out"})
.Attrs({"max_input_length: int"})
.Attrs({"max_input_length: int", "enable_logprob: bool"})
.SetKernelFn(PD_KERNEL(RebuildPadding))
.SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype));
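
For decode tokens the kernel path is unchanged; the new enable_logprob branch only affects prefill (encoder) sequences, where, besides the last prompt token's hidden state that rebuild_padding already returns, the row immediately before it is copied into the optional first_token_out buffer, one row per batch (the caller preallocates a [bsz, hidden_size] buffer, see MTPProposer's first_token_hidden_states below). A toy NumPy illustration of that extra gather, assuming a single prefill request:

import numpy as np

dim_embed = 4
prompt_hidden = np.arange(5 * dim_embed, dtype=np.float32).reshape(5, dim_embed)  # 5 prompt tokens

out_row = prompt_hidden[-1]              # already returned: hidden state of the last prompt token
first_token_out_row = prompt_hidden[-2]  # extra row kept only when enable_logprob is true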

View File

@@ -0,0 +1,161 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define MAX_BSZ 512
#define K 20
#define MAX_DRAFT_TOKEN_NUM 6
struct batch_msgdata {
int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)];
float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)];
int ranks[MAX_DRAFT_TOKEN_NUM];
};
struct msgdata {
long mtype;
int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums
batch_msgdata mtext[MAX_BSZ];
};
void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
const paddle::Tensor& output_scores,
const paddle::Tensor& output_ranks,
int real_k,
int64_t rank_id,
bool wait_flag) {
struct msgdata msg_rcv;
int msg_queue_id = 1;
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("/dev/shm", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG
std::cout << "get_output_key: " << key << std::endl;
std::cout << "get_output msgid: " << msgid << std::endl;
#endif
int64_t* output_tokens_data =
const_cast<int64_t*>(output_tokens.data<int64_t>());
float* output_scores_data = const_cast<float*>(output_scores.data<float>());
int64_t* output_ranks_data =
const_cast<int64_t*>(output_ranks.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(
msgid, &msg_rcv, sizeof(msg_rcv) - sizeof(long), 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, sizeof(msg_rcv) - sizeof(long), 0, 0);
}
if (ret == -1) {
// read none
output_tokens_data[0] = -2; // stop_flag
output_tokens_data[1] = 0; // message_flag, Target: 3, Draft: 4
output_tokens_data[2] = 0; // bsz
return;
}
int bsz = msg_rcv.meta[2];
output_tokens_data[0] = (int64_t)msg_rcv.meta[0];
output_tokens_data[1] = (int64_t)msg_rcv.meta[1];
output_tokens_data[2] = (int64_t)msg_rcv.meta[2];
int output_tokens_offset = 3 + MAX_BSZ;
for (int i = 0; i < bsz; i++) {
int cur_token_num = msg_rcv.meta[3 + i];
output_tokens_data[3 + i] = (int64_t)cur_token_num; // batch_token_nums
auto* cur_output_token = output_tokens_data + output_tokens_offset +
i * (MAX_DRAFT_TOKEN_NUM * (K + 1));
auto* cur_output_score =
output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (K + 1));
auto* cur_batch_msg_rcv = &msg_rcv.mtext[i];
for (int j = 0; j < cur_token_num; j++) {
for (int k = 0; k < real_k + 1; k++) {
cur_output_token[j * (K + 1) + k] =
(int64_t)cur_batch_msg_rcv->tokens[j * (K + 1) + k];
cur_output_score[j * (K + 1) + k] =
cur_batch_msg_rcv->scores[j * (K + 1) + k];
}
output_ranks_data[i * MAX_DRAFT_TOKEN_NUM + j] =
(int64_t)cur_batch_msg_rcv->ranks[j];
}
}
#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG
std::cout << "msg data: " << std::endl;
std::cout << "stop_flag: " << output_tokens_data[0]
<< ", message_flag: " << output_tokens_data[1]
<< ", bsz: " << output_tokens_data[2] << std::endl;
for (int i = 0; i < output_tokens_data[2]; i++) {
int cur_token_num = output_tokens_data[3 + i];
std::cout << "batch " << i << " token_num: " << cur_token_num
<< std::endl;
for (int j = 0; j < cur_token_num; j++) {
std::cout << "tokens: ";
for (int k = 0; k < K + 1; k++) {
std::cout
<< output_tokens_data[output_tokens_offset +
i * MAX_DRAFT_TOKEN_NUM * (K + 1) +
j * (K + 1) + k]
<< " ";
}
std::cout << std::endl;
std::cout << "scores: ";
for (int k = 0; k < K + 1; k++) {
std::cout
<< output_scores_data[i * MAX_DRAFT_TOKEN_NUM * (K + 1) +
j * (K + 1) + k]
<< " ";
}
std::cout << std::endl;
std::cout << "ranks: "
<< output_ranks_data[i * MAX_DRAFT_TOKEN_NUM + j]
<< std::endl;
}
}
std::cout << std::endl;
#endif
return;
}
PD_BUILD_STATIC_OP(speculate_get_output_topk)
.Inputs({"output_tokens", "output_scores", "output_ranks"})
.Attrs({"real_k: int", "rank_id: int64_t", "wait_flag: bool"})
.Outputs({"output_tokens_out", "output_scores_out", "output_ranks_out"})
.SetInplaceMap({{"output_tokens", "output_tokens_out"},
{"output_scores", "output_scores_out"},
{"output_ranks", "output_ranks_out"}})
.SetKernelFn(PD_KERNEL(SpeculateGetOutMmsgTopK));
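
The receiver flattens each message into three tensors whose layout follows the defines above: output_tokens holds three meta slots, MAX_BSZ per-batch token counts, and then one MAX_DRAFT_TOKEN_NUM * (K + 1) block per batch; output_scores holds the matching score blocks; output_ranks holds MAX_DRAFT_TOKEN_NUM ranks per batch. A hedged pure-Python sketch of walking that layout (decode_topk_output is an illustrative helper, not part of this commit):

MAX_BSZ, K, MAX_DRAFT_TOKEN_NUM = 512, 20, 6

def decode_topk_output(output_tokens, output_scores, output_ranks, real_k):
    stop_flag, message_flag, bsz = output_tokens[0], output_tokens[1], int(output_tokens[2])
    base = 3 + MAX_BSZ  # token payload starts after the meta block
    batches = []
    for i in range(bsz):
        n_tokens = int(output_tokens[3 + i])  # batch_token_nums
        tok_block = base + i * MAX_DRAFT_TOKEN_NUM * (K + 1)
        score_block = i * MAX_DRAFT_TOKEN_NUM * (K + 1)
        per_token = []
        for j in range(n_tokens):
            toks = output_tokens[tok_block + j * (K + 1):tok_block + j * (K + 1) + real_k + 1]
            scores = output_scores[score_block + j * (K + 1):score_block + j * (K + 1) + real_k + 1]
            rank = output_ranks[i * MAX_DRAFT_TOKEN_NUM + j]
            per_token.append((toks, scores, rank))
        batches.append(per_token)
    return stop_flag, message_flag, batches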

View File

@@ -0,0 +1,290 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
__global__ void get_token_num_per_batch_kernel(int* next_token_num,
int* batch_token_num,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int real_bsz) {
int bid = threadIdx.x;
if (bid < real_bsz) {
next_token_num[bid] =
seq_lens_encoder[bid] > 0 ? 1 : seq_lens_this_time[bid];
batch_token_num[bid] =
seq_lens_encoder[bid] > 0 ? 2 : seq_lens_this_time[bid];
}
}
template <int VecSize>
__global__ void speculate_get_logits_kernel(float* draft_logits,
const float* logits,
const float* first_token_logits,
const int* cu_next_token_offset,
const int* cu_batch_token_offset,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int vocab_size,
const int real_bsz) {
AlignedVector<float, VecSize> src_vec;
const int bid = blockIdx.x;
const int tid = threadIdx.x;
if (bid < real_bsz) {
auto* draft_logits_now =
draft_logits + cu_batch_token_offset[bid] * vocab_size;
auto* logits_now = logits + cu_next_token_offset[bid] * vocab_size;
for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) {
if (seq_lens_encoder[bid] > 0) {
Load<float, VecSize>(&first_token_logits[bid * vocab_size + i],
&src_vec);
Store<float, VecSize>(src_vec, &draft_logits_now[i]);
Load<float, VecSize>(&logits_now[i], &src_vec);
Store<float, VecSize>(src_vec,
&draft_logits_now[vocab_size + i]);
} else {
for (int j = 0; j < seq_lens_this_time[bid]; j++) {
Load<float, VecSize>(&logits_now[j * vocab_size + i],
&src_vec);
Store<float, VecSize>(
src_vec, &draft_logits_now[j * vocab_size + i]);
}
}
}
}
}
void SpeculateGetLogits(const paddle::Tensor& draft_logits,
const paddle::Tensor& next_token_num,
const paddle::Tensor& batch_token_num,
const paddle::Tensor& cu_next_token_offset,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& logits,
const paddle::Tensor& first_token_logits,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder) {
auto cu_stream = seq_lens_this_time.stream();
const int vocab_size = logits.shape()[1];
const int real_bsz = seq_lens_this_time.shape()[0];
get_token_num_per_batch_kernel<<<1, 512, 0, cu_stream>>>(
const_cast<int*>(next_token_num.data<int>()),
const_cast<int*>(batch_token_num.data<int>()),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
real_bsz);
void* temp_storage1 = nullptr;
size_t temp_storage_bytes1 = 0;
cub::DeviceScan::InclusiveSum(
temp_storage1,
temp_storage_bytes1,
batch_token_num.data<int>(),
const_cast<int*>(&cu_batch_token_offset.data<int>()[1]),
real_bsz,
cu_stream);
cudaMalloc(&temp_storage1, temp_storage_bytes1);
cub::DeviceScan::InclusiveSum(
temp_storage1,
temp_storage_bytes1,
batch_token_num.data<int>(),
const_cast<int*>(&cu_batch_token_offset.data<int>()[1]),
real_bsz,
cu_stream);
void* temp_storage2 = nullptr;
size_t temp_storage_bytes2 = 0;
cub::DeviceScan::InclusiveSum(
temp_storage2,
temp_storage_bytes2,
next_token_num.data<int>(),
const_cast<int*>(&cu_next_token_offset.data<int>()[1]),
real_bsz,
cu_stream);
cudaMalloc(&temp_storage2, temp_storage_bytes2);
cub::DeviceScan::InclusiveSum(
temp_storage2,
temp_storage_bytes2,
next_token_num.data<int>(),
const_cast<int*>(&cu_next_token_offset.data<int>()[1]),
real_bsz,
cu_stream);
constexpr int PackSize = VEC_16B / sizeof(float);
dim3 grid_dim(real_bsz);
dim3 block_dim(128);
speculate_get_logits_kernel<PackSize>
<<<grid_dim, block_dim, 0, cu_stream>>>(
const_cast<float*>(draft_logits.data<float>()),
logits.data<float>(),
first_token_logits.data<float>(),
cu_next_token_offset.data<int>(),
cu_batch_token_offset.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
vocab_size,
real_bsz);
}
__global__ void speculate_insert_first_token_kernel(
int64_t* token_ids,
const int64_t* accept_tokens,
const int64_t* next_tokens,
const int* cu_next_token_offset,
const int* cu_batch_token_offset,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int max_draft_tokens,
const int real_bsz) {
const int bid = threadIdx.x;
auto* token_ids_now = token_ids + cu_batch_token_offset[bid];
auto* accept_tokens_now = accept_tokens + bid * max_draft_tokens;
auto* next_tokens_now = next_tokens + cu_next_token_offset[bid];
if (seq_lens_encoder[bid] != 0) {
token_ids_now[0] = accept_tokens_now[0];
token_ids_now[1] = next_tokens_now[0];
} else {
for (int i = 0; i < seq_lens_this_time[bid]; i++) {
token_ids_now[i] = next_tokens_now[i];
}
}
}
void SpeculateInsertFirstToken(const paddle::Tensor& token_ids,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& next_tokens,
const paddle::Tensor& cu_next_token_offset,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder) {
auto cu_stream = seq_lens_this_time.stream();
const int max_draft_tokens = accept_tokens.shape()[1];
const int real_bsz = seq_lens_this_time.shape()[0];
speculate_insert_first_token_kernel<<<1, real_bsz, 0, cu_stream>>>(
const_cast<int64_t*>(token_ids.data<int64_t>()),
accept_tokens.data<int64_t>(),
next_tokens.data<int64_t>(),
cu_next_token_offset.data<int>(),
cu_batch_token_offset.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
max_draft_tokens,
real_bsz);
}
template <int VecSize>
__global__ void speculate_get_target_logits_kernel(
float* target_logits,
const float* logits,
const int* cu_batch_token_offset,
const int* ori_cu_batch_token_offset,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* accept_num,
const int vocab_size,
const int real_bsz) {
AlignedVector<float, VecSize> src_vec;
const int bid = blockIdx.x;
const int tid = threadIdx.x;
if (bid < real_bsz) {
auto* target_logits_now =
target_logits + cu_batch_token_offset[bid] * vocab_size;
auto* logits_now = logits + ori_cu_batch_token_offset[bid] * vocab_size;
for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) {
if (seq_lens_encoder[bid] > 0) {
Load<float, VecSize>(&logits_now[i], &src_vec);
Store<float, VecSize>(src_vec, &target_logits_now[i]);
} else {
for (int j = 0; j < accept_num[bid]; j++) {
Load<float, VecSize>(&logits_now[j * vocab_size + i],
&src_vec);
Store<float, VecSize>(
src_vec, &target_logits_now[j * vocab_size + i]);
}
}
}
}
}
void SpeculateGetTargetLogits(const paddle::Tensor& target_logits,
const paddle::Tensor& logits,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& ori_cu_batch_token_offset,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& accept_num) {
auto cu_stream = seq_lens_this_time.stream();
const int vocab_size = logits.shape()[1];
const int real_bsz = seq_lens_this_time.shape()[0];
constexpr int PackSize = VEC_16B / sizeof(float);
dim3 grid_dim(real_bsz);
dim3 block_dim(128);
speculate_get_target_logits_kernel<PackSize>
<<<grid_dim, block_dim, 0, cu_stream>>>(
const_cast<float*>(target_logits.data<float>()),
logits.data<float>(),
cu_batch_token_offset.data<int>(),
ori_cu_batch_token_offset.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
accept_num.data<int>(),
vocab_size,
real_bsz);
}
PD_BUILD_STATIC_OP(speculate_get_logits)
.Inputs({"draft_logits",
"next_token_num",
"batch_token_num",
"cu_next_token_offset",
"cu_batch_token_offset",
"logits",
"first_token_logits",
"seq_lens_this_time",
"seq_lens_encoder"})
.Outputs({"draft_logits_out",
"batch_token_num_out",
"cu_batch_token_offset_out"})
.SetInplaceMap({{"draft_logits", "draft_logits_out"},
{"batch_token_num", "batch_token_num_out"},
{"cu_batch_token_offset", "cu_batch_token_offset_out"}})
.SetKernelFn(PD_KERNEL(SpeculateGetLogits));
PD_BUILD_STATIC_OP(speculate_insert_first_token)
.Inputs({"token_ids",
"accept_tokens",
"next_tokens",
"cu_next_token_offset",
"cu_batch_token_offset",
"seq_lens_this_time",
"seq_lens_encoder"})
.Outputs({"token_ids_out"})
.SetInplaceMap({{"token_ids", "token_ids_out"}})
.SetKernelFn(PD_KERNEL(SpeculateInsertFirstToken));
PD_BUILD_STATIC_OP(speculate_get_target_logits)
.Inputs({"target_logits",
"logits",
"cu_batch_token_offset",
"ori_cu_batch_token_offset",
"seq_lens_this_time",
"seq_lens_encoder",
"accept_num"})
.Outputs({"target_logits_out"})
.SetInplaceMap({{"target_logits", "target_logits_out"}})
.SetKernelFn(PD_KERNEL(SpeculateGetTargetLogits));
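
speculate_get_target_logits repacks the target model's logits so that each request's rows start at cu_batch_token_offset[b]: prefill requests contribute a single row, decode requests their first accept_num[b] rows. A NumPy reference of the same gather (a sketch only; the real op is the vectorized CUDA kernel above):

import numpy as np

def speculate_get_target_logits_ref(logits, cu_batch_token_offset, ori_cu_batch_token_offset,
                                    seq_lens_encoder, accept_num):
    out = np.zeros((cu_batch_token_offset[-1], logits.shape[1]), dtype=logits.dtype)
    for b in range(len(accept_num)):
        # prefill requests contribute one row, decode requests accept_num[b] rows
        rows = 1 if seq_lens_encoder[b] > 0 else accept_num[b]
        src = ori_cu_batch_token_offset[b]
        dst = cu_batch_token_offset[b]
        out[dst:dst + rows] = logits[src:src + rows]
    return out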

View File

@@ -0,0 +1,202 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define MAX_BSZ 512
#define K 20
#define MAX_DRAFT_TOKEN_NUM 6
struct batch_msgdata {
int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)];
float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)];
int ranks[MAX_DRAFT_TOKEN_NUM];
};
struct msgdata {
long mtype;
int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums
batch_msgdata mtext[MAX_BSZ];
};
void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
const paddle::Tensor& logprob_token_ids,
const paddle::Tensor& logprob_scores,
const paddle::Tensor& logprob_ranks,
const paddle::Tensor& token_num_per_batch,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& not_need_stop,
int message_flag, // Target: 3, Draft: 4
int64_t rank_id) {
if (rank_id > 0) {
return;
}
auto sampled_token_ids_cpu =
sampled_token_ids.copy_to(paddle::CPUPlace(), false);
auto logprob_token_ids_cpu =
logprob_token_ids.copy_to(paddle::CPUPlace(), false);
auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false);
auto logprob_ranks_cpu = logprob_ranks.copy_to(paddle::CPUPlace(), false);
auto token_num_per_batch_cpu =
token_num_per_batch.copy_to(paddle::CPUPlace(), false);
auto cu_batch_token_offset_cpu =
cu_batch_token_offset.copy_to(paddle::CPUPlace(), false);
int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data<int64_t>();
int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
float* logprob_scores_data = logprob_scores_cpu.data<float>();
int64_t* logprob_ranks_data = logprob_ranks_cpu.data<int64_t>();
int* token_num_per_batch_data = token_num_per_batch_cpu.data<int>();
int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data<int>();
static struct msgdata msg_sed;
int msg_queue_id = 1;
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
msg_queue_id = inference_msg_queue_id_from_env;
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
} else {
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default."
<< std::endl;
#endif
}
int inference_msg_id_from_env = 1;
if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
// 2 and -2 are reserved for no-output indication.
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be 2, please use other number.");
}
if (inference_msg_id_from_env < 0) {
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be negative, please use other "
"number.");
}
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
<< std::endl;
#endif
} else {
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
std::cout
<< "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
<< std::endl;
#endif
}
static key_t key = ftok("/dev/shm", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
std::cout << "save_output_key: " << key << std::endl;
std::cout << "save msgid: " << msgid << std::endl;
#endif
msg_sed.mtype = 1;
msg_sed.meta[0] = not_need_stop.data<bool>()[0]
? inference_msg_id_from_env
: -inference_msg_id_from_env;
msg_sed.meta[1] = message_flag;
int bsz = token_num_per_batch.shape()[0];
msg_sed.meta[2] = bsz;
int max_num_logprobs = logprob_token_ids.shape()[1];
for (int i = 0; i < bsz; i++) {
int cur_token_num = token_num_per_batch_data[i];
msg_sed.meta[3 + i] = cur_token_num;
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
int token_offset = cu_batch_token_offset_data[i];
for (int j = 0; j < cur_token_num; j++) {
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)];
auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
for (int k = 0; k < K + 1; k++) {
if (k == 0) {
cur_tokens[k] =
(int)sampled_token_ids_data[token_offset + j];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (K + 1) + k];
} else if (k < max_num_logprobs) {
cur_tokens[k] = (int)
logprob_token_ids_data[(token_offset + j) * (K + 1) +
k];
cur_scores[k] =
logprob_scores_data[(token_offset + j) * (K + 1) + k];
} else {
cur_tokens[k] = -1;
cur_scores[k] = 0.0;
}
}
cur_batch_msg_sed->ranks[j] =
(int)logprob_ranks_data[token_offset + j];
}
}
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
std::cout << "msg data: " << std::endl;
std::cout << "stop_flag: " << msg_sed.meta[0]
<< ", message_flag: " << msg_sed.meta[1]
<< ", bsz: " << msg_sed.meta[2] << std::endl;
for (int i = 0; i < bsz; i++) {
int cur_token_num = msg_sed.meta[3 + i];
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
std::cout << "batch " << i << " token_num: " << cur_token_num
<< std::endl;
for (int j = 0; j < cur_token_num; j++) {
auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)];
auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
std::cout << "tokens: ";
for (int k = 0; k < K + 1; k++) {
std::cout << cur_tokens[k] << " ";
}
std::cout << std::endl;
std::cout << "scores: ";
for (int k = 0; k < K + 1; k++) {
std::cout << cur_scores[k] << " ";
}
std::cout << std::endl;
std::cout << "ranks: " << cur_batch_msg_sed->ranks[j] << std::endl;
}
}
std::cout << std::endl;
#endif
if (msgsnd(msgid, &msg_sed, sizeof(msg_sed) - sizeof(long), 0) == -1) {
printf("full msg buffer\n");
}
}
PD_BUILD_STATIC_OP(speculate_save_output_topk)
.Inputs({
"sampled_token_ids",
"logprob_token_ids",
"logprob_scores",
"logprob_ranks",
"token_num_per_batch",
"cu_batch_token_offset",
"not_need_stop",
})
.Attrs({"message_flag: int", "rank_id: int64_t"})
.SetKernelFn(PD_KERNEL(SpeculateSaveOutMmsgTopK));
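
Each reported token occupies one (K + 1)-wide slot in the outgoing message: slot 0 carries the sampled token id together with its own logprob, slots 1 .. max_num_logprobs - 1 carry the top-k candidates, and the remainder is padded with -1 / 0.0. A hedged Python sketch of that packing for a single token (pack_token_record is illustrative, not part of this commit; the row arguments are one row each of logprob_token_ids and logprob_scores):

K = 20  # mirrors the define above

def pack_token_record(sampled_id, logprob_token_ids_row, logprob_scores_row, max_num_logprobs):
    tokens, scores = [-1] * (K + 1), [0.0] * (K + 1)
    tokens[0], scores[0] = int(sampled_id), float(logprob_scores_row[0])
    for k in range(1, max_num_logprobs):
        tokens[k] = int(logprob_token_ids_row[k])
        scores[k] = float(logprob_scores_row[k])
    return tokens, scores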

View File

@@ -416,8 +416,6 @@ class EngineArgs:
# if self.dynamic_load_weight:
# self.enable_prefix_caching = False
if self.enable_logprob:
if self.speculative_config is not None:
raise NotImplementedError("Logprob does not support speculation_config.")
if not current_platform.is_cuda():
raise NotImplementedError("Only CUDA platform supports logprob.")
if self.speculative_config is not None:

View File

@@ -18,6 +18,10 @@ from .apply_penalty_multi_scores import (
apply_penalty_multi_scores,
apply_speculative_penalty_multi_scores,
)
from .speculate_logprob_utils import (
speculate_get_target_logits,
speculate_insert_first_token,
)
from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling
__all__ = [
@@ -25,4 +29,6 @@ __all__ = [
"apply_speculative_penalty_multi_scores",
"top_k_top_p_sampling",
"min_p_sampling",
"speculate_get_target_logits",
"speculate_insert_first_token",
]

View File

@@ -0,0 +1,72 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from fastdeploy.platforms import current_platform
def speculate_get_target_logits(
target_logits: paddle.Tensor,
logits: paddle.Tensor,
cu_batch_token_offset: paddle.Tensor,
ori_cu_batch_token_offset: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
seq_lens_encoder: paddle.Tensor,
accept_num: paddle.Tensor,
):
"""
speculate_get_target_logits
"""
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import speculate_get_target_logits
speculate_get_target_logits(
target_logits,
logits,
cu_batch_token_offset,
ori_cu_batch_token_offset,
seq_lens_this_time,
seq_lens_encoder,
accept_num,
)
else:
raise NotImplementedError
def speculate_insert_first_token(
token_ids: paddle.Tensor,
accept_tokens: paddle.Tensor,
next_tokens: paddle.Tensor,
cu_next_token_offset: paddle.Tensor,
cu_batch_token_offset: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
seq_lens_encoder: paddle.Tensor,
):
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import speculate_insert_first_token
speculate_insert_first_token(
token_ids,
accept_tokens,
next_tokens,
cu_next_token_offset,
cu_batch_token_offset,
seq_lens_this_time,
seq_lens_encoder,
)
else:
raise NotImplementedError
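
speculate_insert_first_token packs the token ids whose draft logprobs get reported: for a prefill request the accepted first token followed by the draft model's sampled token, for a decode request the seq_lens_this_time[b] freshly sampled draft tokens. A NumPy reference of that packing (a sketch mirroring the CUDA kernel added in this commit, not the op itself):

import numpy as np

def speculate_insert_first_token_ref(accept_tokens, next_tokens, cu_next_token_offset,
                                     cu_batch_token_offset, seq_lens_this_time, seq_lens_encoder):
    token_ids = np.zeros(cu_batch_token_offset[-1], dtype=np.int64)
    for b in range(len(seq_lens_this_time)):
        dst = cu_batch_token_offset[b]
        src = cu_next_token_offset[b]
        if seq_lens_encoder[b] != 0:  # prefill: accepted first token + one draft token
            token_ids[dst] = accept_tokens[b, 0]
            token_ids[dst + 1] = next_tokens[src]
        else:  # decode: copy the sampled draft tokens
            n = seq_lens_this_time[b]
            token_ids[dst:dst + n] = next_tokens[src:src + n]
    return token_ids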

View File

@@ -32,6 +32,8 @@ from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores,
apply_speculative_penalty_multi_scores,
min_p_sampling,
speculate_get_target_logits,
speculate_insert_first_token,
top_k_top_p_sampling,
)
from fastdeploy.platforms import current_platform
@@ -455,6 +457,98 @@ class SpeculativeSampler(nn.Layer):
"""apply logits processor to sampler"""
pass
def compute_logprobs(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
) -> paddle.Tensor:
"""compute logprobs"""
share_inputs = sampling_metadata.share_inputs
last_logits = logits
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
batch_token_num = share_inputs["batch_token_num"][:real_bsz]
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
if temp_scaled_logprobs is not None:
real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
temperature = sampling_metadata.temperature[:real_bsz]
real_bsz_temp_scaled = (
real_bsz_temp_scaled.astype("int32").squeeze(1).repeat_interleave(batch_token_num).astype("bool")
)
temperature = temperature.squeeze(1).repeat_interleave(batch_token_num)
temp_temperature = paddle.where(
real_bsz_temp_scaled, temperature, paddle.ones_like(temperature)
).unsqueeze(1)
last_logits = last_logits / temp_temperature
last_logprobs = F.log_softmax(last_logits, axis=-1)
top_p_logprob = None
top_p_token_mask = None
if top_p_normalized_logprobs is not None and share_inputs is not None:
real_token_top_p = (
sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1)
)
top_p_normalized_logprobs = (
top_p_normalized_logprobs[:real_bsz]
.astype("int32")
.squeeze(1)
.repeat_interleave(batch_token_num)
.astype("bool")
.unsqueeze(1)
)
top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0)
if top_p_token_mask.any():
probs = F.softmax(last_logits, axis=-1)
probs = top_p_normalize_probs_paddle(probs, real_token_top_p)
top_p_logprob = paddle.log(probs)
if top_p_logprob is not None:
last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs)
return last_logprobs
def gather_logprobs(
self,
logprobs: paddle.Tensor,
num_logprobs: int,
token_ids: paddle.Tensor,
) -> LogprobsTensors:
"""
Gather logprobs for topk and sampled/prompt token.
Args:
logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Must be int64.
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == paddle.int64
token_ids = token_ids.unsqueeze(1)
logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
# Get with the logprob of the prompt or sampled token.
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
if num_logprobs >= 1:
# Find the topK values.
topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
indices = paddle.concat([token_ids, topk_indices], axis=1)
top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
else:
indices = token_ids
top_logprobs = token_logprobs
return LogprobsTensors(indices, top_logprobs, token_ranks)
def forward_cuda(
self,
logits: paddle.Tensor,
@@ -521,7 +615,56 @@ class SpeculativeSampler(nn.Layer):
accept_all_drafts,
)
return None
num_logprobs = sampling_metadata.max_num_logprobs
batch_token_num = None
if num_logprobs is not None:
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
batch_token_num = paddle.where(
share_inputs["seq_lens_encoder"][:real_bsz] != 0,
paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
share_inputs["accept_num"][:real_bsz].unsqueeze(1),
).squeeze(1)
share_inputs["batch_token_num"] = batch_token_num
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
"int32"
)
cu_batch_token_offset = paddle.concat(
[paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
).astype("int32")
share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
target_logits = paddle.empty(
[share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
)
speculate_get_target_logits(
target_logits,
logits,
cu_batch_token_offset,
ori_cu_batch_token_offset,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
share_inputs["accept_num"],
)
raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata)
logprobs_tensors = None
token_ids = share_inputs["accept_tokens"]
if num_logprobs is not None:
token_ids = paddle.concat(
[
share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]]
for i in range(share_inputs["accept_num"][:real_bsz].shape[0])
]
)
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=logprobs_tensors,
token_num_per_batch=batch_token_num,
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
)
return sampler_output
class MTPSampler(nn.Layer):
@@ -556,6 +699,103 @@ class MTPSampler(nn.Layer):
"""post process after running"""
pass
def compute_logprobs(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
) -> paddle.Tensor:
"""compute logprobs"""
share_inputs = sampling_metadata.share_inputs
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
last_logits = logits
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
if temp_scaled_logprobs is not None:
real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
temperature = sampling_metadata.temperature[:real_bsz]
real_bsz_temp_scaled = (
real_bsz_temp_scaled.astype("int32")
.squeeze(1)
.repeat_interleave(share_inputs["batch_token_num"][:real_bsz])
.astype("bool")
)
temperature = temperature.squeeze(1).repeat_interleave(share_inputs["batch_token_num"][:real_bsz])
temp_temperature = paddle.where(
real_bsz_temp_scaled, temperature, paddle.ones_like(temperature)
).unsqueeze(1)
last_logits = last_logits / temp_temperature
last_logprobs = F.log_softmax(last_logits, axis=-1)
top_p_logprob = None
top_p_token_mask = None
if top_p_normalized_logprobs is not None and share_inputs is not None:
real_token_top_p = (
sampling_metadata.top_p[:real_bsz]
.squeeze(1)
.repeat_interleave(share_inputs["batch_token_num"][:real_bsz])
.unsqueeze(1)
)
top_p_normalized_logprobs = (
top_p_normalized_logprobs[:real_bsz]
.astype("int32")
.squeeze(1)
.repeat_interleave(share_inputs["batch_token_num"][:real_bsz])
.astype("bool")
.unsqueeze(1)
)
top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0)
if top_p_token_mask.any():
probs = F.softmax(last_logits, axis=-1)
probs = top_p_normalize_probs_paddle(probs, real_token_top_p)
top_p_logprob = paddle.log(probs)
if top_p_logprob is not None:
last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs)
return last_logprobs
def gather_logprobs(
self,
logprobs: paddle.Tensor,
num_logprobs: int,
token_ids: paddle.Tensor,
) -> LogprobsTensors:
"""
Gather logprobs for topk and sampled/prompt token.
Args:
logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Must be int64.
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == paddle.int64
token_ids = token_ids.unsqueeze(1)
logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
# Get with the logprob of the prompt or sampled token.
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
if num_logprobs >= 1:
# Find the topK values.
topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
indices = paddle.concat([token_ids, topk_indices], axis=1)
top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
else:
indices = token_ids
top_logprobs = token_logprobs
return LogprobsTensors(indices, top_logprobs, token_ranks)
def forward_cuda(
self,
logits: paddle.Tensor,
@@ -564,6 +804,12 @@ class MTPSampler(nn.Layer):
share_inputs: List[paddle.Tensor],
) -> paddle.Tensor:
""" """
num_logprobs = sampling_metadata.max_num_logprobs
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
if num_logprobs is not None and share_inputs["substep"] == 0:
real_token_num = share_inputs["batch_token_num"][:real_bsz].sum()
raw_logprobs = self.compute_logprobs(share_inputs["draft_logits"][:real_token_num, :], sampling_metadata)
logits = apply_speculative_penalty_multi_scores(
sampling_metadata.pre_token_ids,
logits,
@@ -585,4 +831,27 @@ class MTPSampler(nn.Layer):
_, next_tokens = top_k_top_p_sampling(
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
)
return next_tokens
token_ids = None
logprobs_tensors = None
if num_logprobs is not None and share_inputs["substep"] == 0:
token_ids = paddle.empty(real_token_num, dtype="int64")
speculate_insert_first_token(
token_ids,
share_inputs["accept_tokens"],
next_tokens,
share_inputs["cu_next_token_offset"],
share_inputs["cu_batch_token_offset"],
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
)
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=logprobs_tensors,
token_num_per_batch=share_inputs["batch_token_num"][:real_bsz],
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
)
return next_tokens, sampler_output
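
gather_logprobs (duplicated in both samplers above) ranks the sampled token against the full vocabulary and prepends it to the top-k candidates. A minimal runnable paddle sketch of the same formulas on toy data:

import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[2.0, 0.5, 1.0, -1.0]])          # 1 token, vocabulary of 4
logprobs = F.log_softmax(logits, axis=-1)
token_ids = paddle.to_tensor([[2]], dtype="int64")           # the sampled token id

token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
token_ranks = (logprobs >= token_logprobs).sum(-1)            # rank is 2: exactly one logit is larger
topk_logprobs, topk_indices = paddle.topk(logprobs, k=2, axis=-1)
indices = paddle.concat([token_ids, topk_indices], axis=1)    # sampled id first, then top-k ids
top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)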

View File

@@ -68,6 +68,7 @@ else:
speculate_get_padding_offset,
speculate_get_seq_lens_output,
speculate_save_output,
speculate_save_output_topk,
speculate_set_value_by_flags_and_idx,
speculate_step_paddle,
speculate_step_system_cache,
@@ -334,7 +335,10 @@ def post_process_normal(
def post_process_specualate(
model_output: ModelOutputData, save_each_rank: bool = False, skip_save_output: bool = False
sampler_output: SamplerOutput,
model_output: ModelOutputData,
save_each_rank: bool = False,
skip_save_output: bool = False,
):
""""""
speculate_update(
@@ -352,16 +356,29 @@ def post_process_specualate(
)
if not skip_save_output:
speculate_save_output(
model_output.accept_tokens,
model_output.accept_num,
model_output.not_need_stop,
model_output.seq_lens_decoder,
model_output.prompt_lens,
model_output.mp_rank,
save_each_rank,
envs.ENABLE_V1_KVCACHE_SCHEDULER,
)
if sampler_output.logprobs_tensors is None:
speculate_save_output(
model_output.accept_tokens,
model_output.accept_num,
model_output.not_need_stop,
model_output.seq_lens_decoder,
model_output.prompt_lens,
model_output.mp_rank,
save_each_rank,
envs.ENABLE_V1_KVCACHE_SCHEDULER,
)
else:
speculate_save_output_topk(
sampler_output.sampled_token_ids,
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
sampler_output.token_num_per_batch,
sampler_output.cu_batch_token_offset,
model_output.not_need_stop,
3, # mtype
model_output.mp_rank,
)
# Update pre_ids through accept tokens
@@ -389,7 +406,7 @@ def post_process(
) -> None:
"""Post-processing steps after completing a single token generation."""
if speculative_decoding:
post_process_specualate(model_output, save_each_rank, skip_save_output)
post_process_specualate(sampler_output, model_output, save_each_rank, skip_save_output)
else:
post_process_normal(
sampler_output,
@@ -597,6 +614,8 @@ def rebuild_padding(
seq_lens_encoder: paddle.Tensor,
output_padding_offset: Optional[paddle.Tensor] = None,
max_input_length: Optional[int] = None,
first_token_out: Optional[paddle.Tensor] = None,
enable_logprob: Optional[bool] = False,
):
"""
Args:
@@ -612,7 +631,9 @@ def rebuild_padding(
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
first_token_out,
max_input_length,
enable_logprob,
)
elif current_platform.is_dcu():
from fastdeploy.model_executor.ops.gpu import rebuild_padding

View File

@@ -44,6 +44,8 @@ from fastdeploy.model_executor.ops.gpu import (
mtp_save_first_token,
mtp_step_paddle,
share_external_data,
speculate_get_logits,
speculate_save_output_topk,
)
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
@@ -72,6 +74,7 @@ class MTPProposer(Proposer):
self.target_model_inputs = target_model_inputs
self.mtp_strategy = self.speculative_config.mtp_strategy
self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
self.enable_logprob = self.model_config.enable_logprob
# [mixed, prefill, decoder]
self.role = "mixed"
@@ -405,6 +408,22 @@ class MTPProposer(Proposer):
self.target_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
)
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
self.model_inputs["temp_scaled_logprobs"] = self.target_model_inputs["temp_scaled_logprobs"]
self.model_inputs["top_p_normalized_logprobs"] = self.target_model_inputs["top_p_normalized_logprobs"]
self.model_inputs["accept_num"] = self.target_model_inputs["accept_num"]
self.model_inputs["accept_tokens"] = self.target_model_inputs["accept_tokens"]
self.model_inputs["draft_logits"] = self.target_model_inputs["draft_logits"]
self.model_inputs["first_token_hidden_states"] = paddle.full(
[self.max_num_seqs, self.model_config.hidden_size], -1
)
self.model_inputs["batch_token_num"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
self.model_inputs["next_token_num"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
self.model_inputs["cu_batch_token_offset"] = paddle.full_like(
self.target_model_inputs["cu_batch_token_offset"], fill_value=0, dtype="int32"
)
self.model_inputs["cu_next_token_offset"] = paddle.full(
shape=[self.max_num_seqs + 1], fill_value=0, dtype="int32"
)
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
@@ -734,6 +753,10 @@ class MTPProposer(Proposer):
min_dec_lens=self.model_inputs["min_dec_len"],
bad_words_token_ids=self.model_inputs["bad_tokens"],
eos_token_ids=self.model_inputs["eos_token_id"],
max_num_logprobs=20 if self.enable_logprob else None,
temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"],
top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
share_inputs=self.model_inputs,
)
if self.num_model_steps > 1:
@@ -754,18 +777,48 @@ class MTPProposer(Proposer):
self.model_inputs["seq_lens_encoder"],
self.model_inputs["output_padding_offset"],
self.model_config.max_model_len,
self.model_inputs["first_token_hidden_states"],
self.enable_logprob if substep == 0 else False,
)
# 4. Compute logits, Sample
logits = self.model.compute_logits(hidden_states)
if self.enable_logprob and substep == 0:
first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"])
sampled_token_ids = self.sampler(
speculate_get_logits(
self.model_inputs["draft_logits"],
self.model_inputs["next_token_num"],
self.model_inputs["batch_token_num"],
self.model_inputs["cu_next_token_offset"],
self.model_inputs["cu_batch_token_offset"],
logits,
first_token_logits,
self.model_inputs["seq_lens_this_time"],
self.model_inputs["seq_lens_encoder"],
)
sampled_token_ids, sampler_output = self.sampler(
logits,
self.sampling_metadata,
self.max_model_len,
self.model_inputs,
)
if substep == 0 and sampler_output.logprobs_tensors is not None:
real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
speculate_save_output_topk(
sampler_output.sampled_token_ids,
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
self.model_inputs["batch_token_num"][:real_bsz],
self.model_inputs["cu_batch_token_offset"][:real_bsz],
self.model_inputs["not_need_stop"],
4, # mtype
self.local_rank,
)
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(sampled_token_ids, 0)

View File

@@ -1007,6 +1007,15 @@ class GPUModelRunner(ModelRunnerBase):
dtype="int64",
)
self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
# For MTP Logprob
self.share_inputs["draft_logits"] = paddle.full(
[max_num_seqs * (self.speculative_config.num_speculative_tokens + 1), self.model_config.vocab_size],
-1,
dtype="float32",
)
self.share_inputs["cu_batch_token_offset"] = paddle.full(
shape=[max_num_seqs + 1], fill_value=0, dtype="int32"
)
if self.enable_mm:
head_dim = self.model_config.head_dim
@@ -1869,13 +1878,12 @@ class GPUModelRunner(ModelRunnerBase):
)
else:
self.sampler(
sampler_output = self.sampler(
logits,
self.sampling_metadata,
self.model_config.max_model_len,
self.share_inputs,
)
sampler_output = None
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(
self.share_inputs["accept_tokens"],

View File

@@ -106,6 +106,8 @@ class SamplerOutput:
# PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
sampled_token_ids: paddle.Tensor
logprobs_tensors: Optional[LogprobsTensors]
token_num_per_batch: Optional[paddle.Tensor] = None
cu_batch_token_offset: Optional[paddle.Tensor] = None
@dataclass

View File

@@ -143,7 +143,9 @@ class TestRebuildPadding(unittest.TestCase):
seq_lens_decoder,
seq_lens_encoder,
None,
None,
max_input_length,
False,
)
np.testing.assert_allclose(out_no_offset.numpy(), out_no_offset_ref)
@@ -191,7 +193,9 @@ class TestRebuildPadding(unittest.TestCase):
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
None,
max_input_length,
False,
)
np.testing.assert_allclose(out_with_offset.numpy(), out_with_offset_ref)