mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] add speculate_get_logits (#5497)
* [XPU] add speculate_step_system_cache
* [XPU] add speculate_step_system_cache
* [XPU] add speculate_get_logits
* delete context
* add ptr check

---------

Co-authored-by: cmcamdy <1027740945@qq.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
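In brief, the new op packs per-request logits into the draft_logits buffer used by multi-token prediction (MTP): a request still in prefill (seq_lens_encoder > 0) contributes its first-token logits plus one row of regular logits, while a decoding request contributes seq_lens_this_time rows; the per-batch token counts and their prefix sums come out as a side effect. A minimal NumPy sketch of these semantics, not part of the commit (the function name is illustrative; it mirrors the cpu_wrapper reference further below):

import numpy as np

def speculate_get_logits_ref(draft_logits, next_token_num, batch_token_num,
                             cu_next_token_offset, cu_batch_token_offset,
                             logits, first_token_logits,
                             seq_lens_this_time, seq_lens_encoder):
    """Hypothetical NumPy mirror of the cpu_wrapper reference in this commit."""
    real_bsz = seq_lens_this_time.shape[0]
    batch_sum = next_sum = 0
    for bid in range(real_bsz):
        # Exclusive prefix sums over the per-batch token counts.
        cu_batch_token_offset[bid] = batch_sum
        cu_next_token_offset[bid] = next_sum
        prefill = seq_lens_encoder[bid] > 0
        batch_token_num[bid] = 2 if prefill else seq_lens_this_time[bid]
        next_token_num[bid] = 1 if prefill else seq_lens_this_time[bid]
        batch_sum += int(batch_token_num[bid])
        next_sum += int(next_token_num[bid])
        dst, src = cu_batch_token_offset[bid], cu_next_token_offset[bid]
        if prefill:
            # Prefill: draft row 0 = first-token logits, row 1 = regular logits.
            draft_logits[dst] = first_token_logits[bid]
            draft_logits[dst + 1] = logits[src]
        else:
            # Decode: copy seq_lens_this_time rows of logits straight through.
            n = int(seq_lens_this_time[bid])
            draft_logits[dst:dst + n] = logits[src:src + n]
    return draft_logits, batch_token_num, cu_batch_token_offset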
80 custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc Normal file
@@ -0,0 +1,80 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <paddle/phi/backends/xpu/xpu_context.h>
#include <stdio.h>
#include "paddle/common/flags.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "xpu/internal/infra_op.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

void SpeculateGetLogits(const paddle::Tensor& draft_logits,
                        const paddle::Tensor& next_token_num,
                        const paddle::Tensor& batch_token_num,
                        const paddle::Tensor& cu_next_token_offset,
                        const paddle::Tensor& cu_batch_token_offset,
                        const paddle::Tensor& logits,
                        const paddle::Tensor& first_token_logits,
                        const paddle::Tensor& seq_lens_this_time,
                        const paddle::Tensor& seq_lens_encoder) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  baidu::xpu::api::Context* ctx =
      static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
  if (draft_logits.is_cpu()) {
    ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
  }
  const int vocab_size = logits.shape()[1];
  const int real_bsz = seq_lens_this_time.shape()[0];

  baidu::xpu::api::plugin::speculate_get_logits(
      ctx,
      const_cast<float*>(draft_logits.data<float>()),
      const_cast<int*>(next_token_num.data<int>()),
      const_cast<int*>(batch_token_num.data<int>()),
      const_cast<int*>(cu_next_token_offset.data<int>()),
      const_cast<int*>(cu_batch_token_offset.data<int>()),
      logits.data<float>(),
      first_token_logits.data<float>(),
      seq_lens_this_time.data<int>(),
      seq_lens_encoder.data<int>(),
      real_bsz,
      vocab_size);
  if (draft_logits.is_cpu()) {
    delete ctx;
  }
}

PD_BUILD_STATIC_OP(speculate_get_logits)
    .Inputs({"draft_logits",
             "next_token_num",
             "batch_token_num",
             "cu_next_token_offset",
             "cu_batch_token_offset",
             "logits",
             "first_token_logits",
             "seq_lens_this_time",
             "seq_lens_encoder"})
    .Outputs({"draft_logits_out",
              "batch_token_num_out",
              "cu_batch_token_offset_out"})
    .SetInplaceMap({{"draft_logits", "draft_logits_out"},
                    {"batch_token_num", "batch_token_num_out"},
                    {"cu_batch_token_offset", "cu_batch_token_offset_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateGetLogits));
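Because of the SetInplaceMap above, draft_logits, batch_token_num, and cu_batch_token_offset double as the op's outputs. A hedged invocation sketch (shapes and values are made up for illustration; the import path follows the test at the bottom of this commit):

import numpy as np
import paddle

from fastdeploy.model_executor.ops.xpu import speculate_get_logits

real_bsz, vocab_size = 2, 8
# Batch 0 is still in prefill (encoder len > 0); batch 1 decodes 3 draft tokens.
seq_lens_encoder = paddle.to_tensor(np.array([16, 0], dtype=np.int32))
seq_lens_this_time = paddle.to_tensor(np.array([1, 3], dtype=np.int32))
# Prefill contributes 2 draft rows / 1 logits row; decode contributes 3 of each.
draft_logits = paddle.zeros([2 + 3, vocab_size], dtype="float32")
logits = paddle.rand([1 + 3, vocab_size], dtype="float32")
first_token_logits = paddle.rand([real_bsz, vocab_size], dtype="float32")
next_token_num = paddle.zeros([real_bsz], dtype="int32")
batch_token_num = paddle.zeros([real_bsz], dtype="int32")
cu_next_token_offset = paddle.zeros([real_bsz], dtype="int32")
cu_batch_token_offset = paddle.zeros([real_bsz], dtype="int32")

speculate_get_logits(
    draft_logits=draft_logits,
    next_token_num=next_token_num,
    batch_token_num=batch_token_num,
    cu_next_token_offset=cu_next_token_offset,
    cu_batch_token_offset=cu_batch_token_offset,
    logits=logits,
    first_token_logits=first_token_logits,
    seq_lens_this_time=seq_lens_this_time,
    seq_lens_encoder=seq_lens_encoder,
)
# draft_logits, batch_token_num, and cu_batch_token_offset now hold the results.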
@@ -470,6 +470,16 @@ void SpeculateStepPaddle(
    const int encoder_decoder_block_num,
    const int max_draft_tokens);

void SpeculateGetLogits(const paddle::Tensor& draft_logits,
                        const paddle::Tensor& next_token_num,
                        const paddle::Tensor& batch_token_num,
                        const paddle::Tensor& cu_next_token_offset,
                        const paddle::Tensor& cu_batch_token_offset,
                        const paddle::Tensor& logits,
                        const paddle::Tensor& first_token_logits,
                        const paddle::Tensor& seq_lens_this_time,
                        const paddle::Tensor& seq_lens_encoder);

void SaveOutMmsgStatic(const paddle::Tensor& x,
                       const paddle::Tensor& not_need_stop,
                       int64_t rank_id,
@@ -1174,6 +1184,19 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
        py::arg("max_draft_tokens"),
        "Step paddle function");

  m.def("speculate_get_logits",
        &SpeculateGetLogits,
        py::arg("draft_logits"),
        py::arg("next_token_num"),
        py::arg("batch_token_num"),
        py::arg("cu_next_token_offset"),
        py::arg("cu_batch_token_offset"),
        py::arg("logits"),
        py::arg("first_token_logits"),
        py::arg("seq_lens_this_time"),
        py::arg("seq_lens_encoder"),
        "speculate get logits function");

  m.def("text_image_gather_scatter",
        &TextImageGatherScatter,
        py::arg("input"),
@@ -600,6 +600,19 @@ DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx,
                                          T* output,
                                          int dim_embed,
                                          int elem_cnt);

DLL_EXPORT int speculate_get_logits(Context* ctx,
                                    float* draft_logits,
                                    int* next_token_num,
                                    int* batch_token_num,
                                    int* cu_next_token_offset,
                                    int* cu_batch_token_offset,
                                    const float* logits,
                                    const float* first_token_logits,
                                    const int* seq_lens_this_time,
                                    const int* seq_lens_encoder,
                                    const int real_bsz,
                                    const int vocab_size);
/*--------------------------------------- MTP end
 * --------------------------------------------*/

@@ -0,0 +1,131 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"

namespace xpu3 {
namespace plugin {

// Computes per-batch token counts and their exclusive prefix sums in shared
// memory on core 0 of each cluster; cluster 0 then writes batch_token_num and
// cu_batch_token_offset back to global memory.
__device__ void prefix_sum(__shared_ptr__ int* sm_seq_lens_encoder,
                           __shared_ptr__ int* sm_seq_lens_this_time,
                           __shared_ptr__ int* sm_batch_token_num,
                           __shared_ptr__ int* sm_cu_batch_token_offset,
                           __shared_ptr__ int* sm_cu_next_token_offset,
                           __global_ptr__ int* batch_token_num,
                           __global_ptr__ int* cu_batch_token_offset,
                           __global_ptr__ const int* seq_lens_this_time,
                           __global_ptr__ const int* seq_lens_encoder,
                           const int real_bsz) {
  int cid = core_id();
  int clus_id = cluster_id();

  if (clus_id < real_bsz && cid == 0) {
    GM2SM_ASYNC(seq_lens_encoder, sm_seq_lens_encoder, real_bsz * sizeof(int));
    GM2SM(seq_lens_this_time, sm_seq_lens_this_time, real_bsz * sizeof(int));
    int next_token_num_previous = 0;
    for (int bid = 0; bid < real_bsz; bid++) {
      // A prefill batch (encoder len > 0) contributes 2 draft rows; a decode
      // batch contributes seq_lens_this_time rows.
      sm_batch_token_num[bid] =
          sm_seq_lens_encoder[bid] > 0 ? 2 : sm_seq_lens_this_time[bid];
      if (bid == 0) {
        sm_cu_batch_token_offset[bid] = 0;
        sm_cu_next_token_offset[bid] = 0;
      } else {
        sm_cu_batch_token_offset[bid] =
            sm_cu_batch_token_offset[bid - 1] + sm_batch_token_num[bid - 1];
        sm_cu_next_token_offset[bid] =
            sm_cu_next_token_offset[bid - 1] + next_token_num_previous;
      }
      next_token_num_previous =
          sm_seq_lens_encoder[bid] > 0 ? 1 : sm_seq_lens_this_time[bid];
    }
    mfence_sm();
    if (clus_id == 0) {
      SM2GM_ASYNC(sm_batch_token_num, batch_token_num, real_bsz * sizeof(int));
      SM2GM_ASYNC(sm_cu_batch_token_offset,
                  cu_batch_token_offset,
                  real_bsz * sizeof(int));
    }
  }
  mfence_sm();
  sync_all();
}

__global__ void speculate_get_logits(float* draft_logits,
                                     int* next_token_num,
                                     int* batch_token_num,
                                     int* cu_next_token_offset,
                                     int* cu_batch_token_offset,
                                     const float* logits,
                                     const float* first_token_logits,
                                     const int* seq_lens_this_time,
                                     const int* seq_lens_encoder,
                                     const int real_bsz,
                                     const int vocab_size) {
  int cid = core_id();
  int ncores = core_num();
  int clus_id = cluster_id();
  int nclusters = cluster_num();

  // 2 KB of local memory per staging buffer (512 float32 values).
  int lm_size = 2 * 1024;
  int lm_buf_len = lm_size / sizeof(float);
  float first_token_logits_now_lm[lm_buf_len];
  float logits_now_lm[lm_buf_len];

  // Split 256 KB of shared memory into five equal int buffers.
  const int sm_size = 256 * 1024;
  __shared__ char sm[sm_size];
  int sm_max_buf_len = 256 * 1024 / sizeof(int);
  sm_max_buf_len /= 5;
  __shared_ptr__ int* sm_seq_lens_encoder = (__shared_ptr__ int*)sm;
  __shared_ptr__ int* sm_seq_lens_this_time =
      sm_seq_lens_encoder + sm_max_buf_len;
  __shared_ptr__ int* sm_batch_token_num =
      sm_seq_lens_this_time + sm_max_buf_len;
  __shared_ptr__ int* sm_cu_batch_token_offset =
      sm_batch_token_num + sm_max_buf_len;
  __shared_ptr__ int* sm_cu_next_token_offset =
      sm_cu_batch_token_offset + sm_max_buf_len;

  prefix_sum(sm_seq_lens_encoder,
             sm_seq_lens_this_time,
             sm_batch_token_num,
             sm_cu_batch_token_offset,
             sm_cu_next_token_offset,
             batch_token_num,
             cu_batch_token_offset,
             seq_lens_this_time,
             seq_lens_encoder,
             real_bsz);

  // Each cluster handles one batch row at a time; within a cluster the cores
  // tile the vocab dimension in lm_buf_len chunks.
  for (int bid = clus_id; bid < real_bsz; bid += nclusters) {
    auto* draft_logits_now =
        draft_logits + sm_cu_batch_token_offset[bid] * vocab_size;
    auto* logits_now = logits + sm_cu_next_token_offset[bid] * vocab_size;
    auto* first_token_logits_now = first_token_logits + bid * vocab_size;

    for (int i = cid * lm_buf_len; i < vocab_size; i += ncores * lm_buf_len) {
      int read_len = min(lm_buf_len, vocab_size - i);
      if (sm_seq_lens_encoder[bid] > 0) {
        // Prefill: draft row 0 takes the first-token logits, row 1 takes the
        // regular logits.
        GM2LM_ASYNC(first_token_logits_now + i,
                    first_token_logits_now_lm,
                    read_len * sizeof(float));
        GM2LM(logits_now + i, logits_now_lm, read_len * sizeof(float));
        LM2GM_ASYNC(first_token_logits_now_lm,
                    draft_logits_now + i,
                    read_len * sizeof(float));
        LM2GM(logits_now_lm,
              draft_logits_now + vocab_size + i,
              read_len * sizeof(float));
      } else {
        // Decode: copy seq_lens_this_time rows of logits straight through.
        for (int j = 0; j < sm_seq_lens_this_time[bid]; j++) {
          GM2LM(logits_now + j * vocab_size + i,
                logits_now_lm,
                read_len * sizeof(float));
          LM2GM(logits_now_lm,
                draft_logits_now + j * vocab_size + i,
                read_len * sizeof(float));
        }
      }
    }
  }
}

}  // namespace plugin
}  // namespace xpu3
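For reference, the copy loop strides the vocab dimension across the cluster's cores (64 per the launch config below) in chunks of lm_buf_len = 2 KB / sizeof(float) = 512 values. A small Python model of that partitioning (illustrative only; core_id/core_num are XPU intrinsics in the kernel):

LM_BUF_LEN = 2 * 1024 // 4  # the kernel's 2 KB local buffer holds 512 float32

def vocab_tiles(cid, ncores, vocab_size, lm_buf_len=LM_BUF_LEN):
    """Chunks copied by core `cid`, mirroring the kernel's strided loop:
    for (int i = cid * lm_buf_len; i < vocab_size; i += ncores * lm_buf_len)"""
    i = cid * lm_buf_len
    while i < vocab_size:
        yield i, min(lm_buf_len, vocab_size - i)
        i += ncores * lm_buf_len

# With 64 cores and a hypothetical vocab_size of 100352, core 0 copies the
# chunks starting at 0, 32768, 65536, and 98304, each 512 floats long.
print(list(vocab_tiles(cid=0, ncores=64, vocab_size=100352)))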
@@ -0,0 +1,184 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"

namespace xpu3 {
namespace plugin {

__attribute__((global)) void speculate_get_logits(
    float* draft_logits,
    int* next_token_num,
    int* batch_token_num,
    int* cu_next_token_offset,
    int* cu_batch_token_offset,
    const float* logits,
    const float* first_token_logits,
    const int* seq_lens_this_time,
    const int* seq_lens_encoder,
    const int real_bsz,
    const int vocab_size);

}  // namespace plugin
}  // namespace xpu3

namespace baidu {
namespace xpu {
namespace api {
namespace plugin {

static int cpu_wrapper(float* draft_logits,
                       int* next_token_num,
                       int* batch_token_num,
                       int* cu_next_token_offset,
                       int* cu_batch_token_offset,
                       const float* logits,
                       const float* first_token_logits,
                       const int* seq_lens_this_time,
                       const int* seq_lens_encoder,
                       const int real_bsz,
                       const int vocab_size) {
  int batch_token_num_sum = 0;
  int next_token_num_sum = 0;
  for (int bid = 0; bid < real_bsz; bid++) {
    // prefix sum
    cu_batch_token_offset[bid] = batch_token_num_sum;
    cu_next_token_offset[bid] = next_token_num_sum;

    batch_token_num[bid] =
        seq_lens_encoder[bid] > 0 ? 2 : seq_lens_this_time[bid];
    next_token_num[bid] =
        seq_lens_encoder[bid] > 0 ? 1 : seq_lens_this_time[bid];

    batch_token_num_sum += batch_token_num[bid];
    next_token_num_sum += next_token_num[bid];

    auto* draft_logits_now =
        draft_logits + cu_batch_token_offset[bid] * vocab_size;
    auto* logits_now = logits + cu_next_token_offset[bid] * vocab_size;
    auto* first_token_logits_now = first_token_logits + bid * vocab_size;
    for (int i = 0; i < vocab_size; i++) {
      if (seq_lens_encoder[bid] > 0) {
        draft_logits_now[i] = first_token_logits_now[i];
        draft_logits_now[vocab_size + i] = logits_now[i];
      } else {
        for (int j = 0; j < seq_lens_this_time[bid]; j++) {
          draft_logits_now[j * vocab_size + i] = logits_now[j * vocab_size + i];
        }
      }
    }
  }
  return api::SUCCESS;
}

static int xpu3_wrapper(Context* ctx,
                        float* draft_logits,
                        int* next_token_num,
                        int* batch_token_num,
                        int* cu_next_token_offset,
                        int* cu_batch_token_offset,
                        const float* logits,
                        const float* first_token_logits,
                        const int* seq_lens_this_time,
                        const int* seq_lens_encoder,
                        const int real_bsz,
                        const int vocab_size) {
  xpu3::plugin::speculate_get_logits<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
      draft_logits,
      next_token_num,
      batch_token_num,
      cu_next_token_offset,
      cu_batch_token_offset,
      logits,
      first_token_logits,
      seq_lens_this_time,
      seq_lens_encoder,
      real_bsz,
      vocab_size);
  return api::SUCCESS;
}

int speculate_get_logits(Context* ctx,
                         float* draft_logits,
                         int* next_token_num,
                         int* batch_token_num,
                         int* cu_next_token_offset,
                         int* cu_batch_token_offset,
                         const float* logits,
                         const float* first_token_logits,
                         const int* seq_lens_this_time,
                         const int* seq_lens_encoder,
                         const int real_bsz,
                         const int vocab_size) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_logits", float);
  WRAPPER_DUMP_PARAM6(ctx,
                      draft_logits,
                      next_token_num,
                      batch_token_num,
                      cu_next_token_offset,
                      cu_batch_token_offset,
                      logits);
  WRAPPER_DUMP_PARAM5(ctx,
                      first_token_logits,
                      seq_lens_this_time,
                      seq_lens_encoder,
                      real_bsz,
                      vocab_size);
  WRAPPER_DUMP(ctx);
  WRAPPER_CHECK_PTR(ctx, int, real_bsz, next_token_num);
  WRAPPER_CHECK_PTR(ctx, int, real_bsz, batch_token_num);
  WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_next_token_offset);
  WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_batch_token_offset);
  WRAPPER_CHECK_PTR(ctx, float, real_bsz * vocab_size, first_token_logits);
  WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
  WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_encoder);
  WRAPPER_ASSERT_LE(ctx, real_bsz, 256 * 1024 / sizeof(int) / 5);
  WRAPPER_ASSERT_GT(ctx, vocab_size, 0);
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper(draft_logits,
                       next_token_num,
                       batch_token_num,
                       cu_next_token_offset,
                       cu_batch_token_offset,
                       logits,
                       first_token_logits,
                       seq_lens_this_time,
                       seq_lens_encoder,
                       real_bsz,
                       vocab_size);
  }
  if (ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper(ctx,
                        draft_logits,
                        next_token_num,
                        batch_token_num,
                        cu_next_token_offset,
                        cu_batch_token_offset,
                        logits,
                        first_token_logits,
                        seq_lens_this_time,
                        seq_lens_encoder,
                        real_bsz,
                        vocab_size);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}

}  // namespace plugin
}  // namespace api
}  // namespace xpu
}  // namespace baidu
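One detail worth noting: the WRAPPER_ASSERT_LE bound on real_bsz follows from the kernel carving its 256 KB shared-memory block into five equal int buffers. A quick sanity check of that arithmetic:

SM_SIZE = 256 * 1024            # shared memory reserved in the kernel
NUM_INT_BUFFERS = 5             # seq_lens_encoder, seq_lens_this_time,
                                # batch_token_num, and the two offset arrays
MAX_REAL_BSZ = SM_SIZE // 4 // NUM_INT_BUFFERS   # sizeof(int) == 4
assert MAX_REAL_BSZ == 13107    # larger batches fail the wrapper check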
172 custom_ops/xpu_ops/test/test_speculate_get_logits.py Normal file
@@ -0,0 +1,172 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.xpu import speculate_get_logits

# Fix the random seeds so the test is reproducible.
np.random.seed(2023)
paddle.seed(2023)


def generate_test_data():
    """
    Helper that generates the test data.
    Converted from a pytest fixture into a plain function called by the test method.
    """
    real_bsz = 64
    vocab_size = 2 * 1024
    max_seq_len = 8 * 1024

    # Generate the raw test data (reusing the original logic unchanged).
    seq_lens_encoder = np.random.randint(0, 2, [real_bsz], dtype=np.int32)
    seq_lens_this_time = np.random.randint(1, max_seq_len, [real_bsz], dtype=np.int32)
    draft_logits_seqlen = 0
    logits_seqlen = 0
    for i in range(real_bsz):
        if seq_lens_encoder[i] > 0:
            draft_logits_seqlen += 2
            logits_seqlen += 1
        else:
            draft_logits_seqlen += seq_lens_this_time[i]
            logits_seqlen += seq_lens_this_time[i]

    draft_logits = np.zeros([draft_logits_seqlen, vocab_size], dtype=np.float32)
    next_token_num = np.zeros([real_bsz], dtype=np.int32)
    batch_token_num = np.zeros([real_bsz], dtype=np.int32)
    cu_next_token_offset = np.zeros([real_bsz], dtype=np.int32)
    cu_batch_token_offset = np.zeros([real_bsz], dtype=np.int32)
    logits = np.random.rand(logits_seqlen, vocab_size).astype(np.float32)
    first_token_logits = np.random.rand(real_bsz, vocab_size).astype(np.float32)

    paddle.set_device("cpu")
    # Convert to paddle tensors (keeping the original logic).
    data_cpu = {
        "draft_logits": paddle.to_tensor(draft_logits),
        "next_token_num": paddle.to_tensor(next_token_num),
        "batch_token_num": paddle.to_tensor(batch_token_num),
        "cu_next_token_offset": paddle.to_tensor(cu_next_token_offset),
        "cu_batch_token_offset": paddle.to_tensor(cu_batch_token_offset),
        "logits": paddle.to_tensor(logits),
        "first_token_logits": paddle.to_tensor(first_token_logits),
        "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
        "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
    }

    paddle.set_device("xpu:0")
    data_xpu = {
        "draft_logits": paddle.to_tensor(draft_logits),
        "next_token_num": paddle.to_tensor(next_token_num),
        "batch_token_num": paddle.to_tensor(batch_token_num),
        "cu_next_token_offset": paddle.to_tensor(cu_next_token_offset),
        "cu_batch_token_offset": paddle.to_tensor(cu_batch_token_offset),
        "logits": paddle.to_tensor(logits),
        "first_token_logits": paddle.to_tensor(first_token_logits),
        "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
        "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
    }

    # Restore the default device so other tests are unaffected.
    paddle.set_device("cpu")

    return data_cpu, data_xpu


def speculate_get_logits_execution(test_data):
    """Check that the function executes and its outputs are reasonable."""

    # Run the target function (the core step under test).
    speculate_get_logits(**test_data)

    return test_data


class TestSpeculateGetLogits(unittest.TestCase):
    """
    Test class derived from unittest.TestCase.
    Every method whose name starts with 'test_' is treated as a test case.
    """

    def assert_test_data_equal(self, test_data1, test_data2, rtol=1e-05, atol=1e-08, target_keys=None):
        """
        Custom assertion that compares the structure and data of two test_data dicts.
        In unittest, custom assertions conventionally start with 'assert'.
        """
        # 1. First check that both test_data dicts have exactly the same keys.
        keys1 = set(test_data1.keys())
        keys2 = set(test_data2.keys())
        self.assertEqual(
            keys1,
            keys2,
            msg=f"test_data keys differ!\nOnly in the first: {keys1 - keys2}\nOnly in the second: {keys2 - keys1}",
        )

        # 2. Compare the data field by field.
        if target_keys is not None and isinstance(target_keys, list):
            local_target_key = target_keys
        else:
            local_target_key = keys1
        for key in local_target_key:
            data1 = test_data1[key]
            data2 = test_data2[key]

            # Distinguish paddle Tensors (convert to numpy) from plain scalars/arrays (use directly).
            if isinstance(data1, paddle.Tensor):
                np1 = data1.detach().cpu().numpy()
            else:
                np1 = np.asarray(data1)

            if isinstance(data2, paddle.Tensor):
                np2 = data2.detach().cpu().numpy()
            else:
                np2 = np.asarray(data2)

            # 3. Check the data.
            if np1.dtype in (np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8):
                # Boolean/integer types must match exactly.
                np.testing.assert_array_equal(np1, np2, err_msg=f"Field {key} data mismatch!")
            else:
                # Floating-point types may differ within rtol/atol.
                np.testing.assert_allclose(np1, np2, rtol=rtol, atol=atol, err_msg=f"Field {key} float data mismatch!")

        print("✅ The two test_data dicts match in structure and data!")

    def test_speculate_get_logits(self):
        """
        Core test case.
        Fetches data via generate_test_data, runs the op on both CPU and XPU,
        and compares the results with the custom assertion.
        """
        print("\nRunning test: test_speculate_get_logits")

        # 1. Get the test data.
        data_cpu, data_xpu = generate_test_data()

        # 2. Run the function under test.
        result_xpu = speculate_get_logits_execution(data_xpu)
        result_cpu = speculate_get_logits_execution(data_cpu)

        # 3. Assert the results match.
        target_keys = ["draft_logits", "batch_token_num", "cu_batch_token_offset"]
        self.assert_test_data_equal(result_cpu, result_xpu, target_keys=target_keys)


if __name__ == "__main__":
    # Run all test cases via the unittest main program.
    unittest.main()