[XPU] add speculate_get_logits (#5497)

* [XPU] add speculate_step_system_cache

* [XPU] add speculate_step_system_cache

* [XPU] add speculate_get_logits

* delete context

* add ptr check

---------

Co-authored-by: cmcamdy <1027740945@qq.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
RuohengMa
2025-12-12 15:38:30 +08:00
committed by GitHub
parent 888c4b992d
commit 12c76f8137
6 changed files with 603 additions and 0 deletions


@@ -0,0 +1,80 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <stdio.h>
#include "paddle/common/flags.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "xpu/internal/infra_op.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
void SpeculateGetLogits(const paddle::Tensor& draft_logits,
const paddle::Tensor& next_token_num,
const paddle::Tensor& batch_token_num,
const paddle::Tensor& cu_next_token_offset,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& logits,
const paddle::Tensor& first_token_logits,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
baidu::xpu::api::Context* ctx =
static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
if (draft_logits.is_cpu()) {
ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
}
const int vocab_size = logits.shape()[1];
const int real_bsz = seq_lens_this_time.shape()[0];
baidu::xpu::api::plugin::speculate_get_logits(
ctx,
const_cast<float*>(draft_logits.data<float>()),
const_cast<int*>(next_token_num.data<int>()),
const_cast<int*>(batch_token_num.data<int>()),
const_cast<int*>(cu_next_token_offset.data<int>()),
const_cast<int*>(cu_batch_token_offset.data<int>()),
logits.data<float>(),
first_token_logits.data<float>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
real_bsz,
vocab_size);
if (draft_logits.is_cpu()) {
delete ctx;
}
}
PD_BUILD_STATIC_OP(speculate_get_logits)
.Inputs({"draft_logits",
"next_token_num",
"batch_token_num",
"cu_next_token_offset",
"cu_batch_token_offset",
"logits",
"first_token_logits",
"seq_lens_this_time",
"seq_lens_encoder"})
.Outputs({"draft_logits_out",
"batch_token_num_out",
"cu_batch_token_offset_out"})
.SetInplaceMap({{"draft_logits", "draft_logits_out"},
{"batch_token_num", "batch_token_num_out"},
{"cu_batch_token_offset", "cu_batch_token_offset_out"}})
.SetKernelFn(PD_KERNEL(SpeculateGetLogits));
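For reference, a minimal Python usage sketch of the new op (mirroring the unit test added in this PR; the shapes and values below are illustrative, not part of the change). Per the inplace map above, draft_logits, batch_token_num and cu_batch_token_offset are updated in place.

import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import speculate_get_logits

paddle.set_device("xpu:0")
real_bsz, vocab_size = 2, 8
seq_lens_encoder = paddle.to_tensor(np.array([1, 0], dtype=np.int32))  # batch 0 is prefill
seq_lens_this_time = paddle.to_tensor(np.array([5, 3], dtype=np.int32))
# prefill batch -> 2 draft rows / 1 logits row; decode batch -> seq_lens_this_time rows each
draft_logits = paddle.zeros([2 + 3, vocab_size], dtype="float32")
logits = paddle.rand([1 + 3, vocab_size], dtype="float32")
first_token_logits = paddle.rand([real_bsz, vocab_size], dtype="float32")
speculate_get_logits(
    draft_logits=draft_logits,
    next_token_num=paddle.zeros([real_bsz], dtype="int32"),
    batch_token_num=paddle.zeros([real_bsz], dtype="int32"),        # filled by the op
    cu_next_token_offset=paddle.zeros([real_bsz], dtype="int32"),
    cu_batch_token_offset=paddle.zeros([real_bsz], dtype="int32"),  # filled by the op
    logits=logits,
    first_token_logits=first_token_logits,
    seq_lens_this_time=seq_lens_this_time,
    seq_lens_encoder=seq_lens_encoder,
)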


@@ -470,6 +470,16 @@ void SpeculateStepPaddle(
const int encoder_decoder_block_num,
const int max_draft_tokens);
void SpeculateGetLogits(const paddle::Tensor& draft_logits,
const paddle::Tensor& next_token_num,
const paddle::Tensor& batch_token_num,
const paddle::Tensor& cu_next_token_offset,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& logits,
const paddle::Tensor& first_token_logits,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder);
void SaveOutMmsgStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
@@ -1174,6 +1184,19 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("max_draft_tokens"),
"Step paddle function");
m.def("speculate_get_logits",
&SpeculateGetLogits,
py::arg("draft_logits"),
py::arg("next_token_num"),
py::arg("batch_token_num"),
py::arg("cu_next_token_offset"),
py::arg("cu_batch_token_offset"),
py::arg("logits"),
py::arg("first_token_logits"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
"speculate get logits function");
m.def("text_image_gather_scatter",
&TextImageGatherScatter,
py::arg("input"),


@@ -600,6 +600,19 @@ DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx,
T* output,
int dim_embed,
int elem_cnt);
DLL_EXPORT int speculate_get_logits(Context* ctx,
float* draft_logits,
int* next_token_num,
int* batch_token_num,
int* cu_next_token_offset,
int* cu_batch_token_offset,
const float* logits,
const float* first_token_logits,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int real_bsz,
const int vocab_size);
/*--------------------------------------- MTP end
* --------------------------------------------*/


@@ -0,0 +1,131 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
__device__ void prefix_sum(__shared_ptr__ int* sm_seq_lens_encoder,
__shared_ptr__ int* sm_seq_lens_this_time,
__shared_ptr__ int* sm_batch_token_num,
__shared_ptr__ int* sm_cu_batch_token_offset,
__shared_ptr__ int* sm_cu_next_token_offset,
__global_ptr__ int* batch_token_num,
__global_ptr__ int* cu_batch_token_offset,
__global_ptr__ const int* seq_lens_this_time,
__global_ptr__ const int* seq_lens_encoder,
const int real_bsz) {
int cid = core_id();
int clus_id = cluster_id();
if (clus_id < real_bsz && cid == 0) {
GM2SM_ASYNC(seq_lens_encoder, sm_seq_lens_encoder, real_bsz * sizeof(int));
GM2SM(seq_lens_this_time, sm_seq_lens_this_time, real_bsz * sizeof(int));
int next_token_num_previous = 0;
for (int bid = 0; bid < real_bsz; bid++) {
sm_batch_token_num[bid] =
sm_seq_lens_encoder[bid] > 0 ? 2 : sm_seq_lens_this_time[bid];
if (bid == 0) {
sm_cu_batch_token_offset[bid] = 0;
sm_cu_next_token_offset[bid] = 0;
} else {
sm_cu_batch_token_offset[bid] =
sm_cu_batch_token_offset[bid - 1] + sm_batch_token_num[bid - 1];
sm_cu_next_token_offset[bid] =
sm_cu_next_token_offset[bid - 1] + next_token_num_previous;
}
next_token_num_previous =
sm_seq_lens_encoder[bid] > 0 ? 1 : sm_seq_lens_this_time[bid];
}
mfence_sm();
if (clus_id == 0) {
SM2GM_ASYNC(sm_batch_token_num, batch_token_num, real_bsz * sizeof(int));
SM2GM_ASYNC(sm_cu_batch_token_offset,
cu_batch_token_offset,
real_bsz * sizeof(int));
}
}
mfence_sm();
sync_all();
}
__global__ void speculate_get_logits(float* draft_logits,
int* next_token_num,
int* batch_token_num,
int* cu_next_token_offset,
int* cu_batch_token_offset,
const float* logits,
const float* first_token_logits,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int real_bsz,
const int vocab_size) {
int cid = core_id();
int ncores = core_num();
int clus_id = cluster_id();
int nclusters = cluster_num();
int lm_size = 2 * 1024;
int lm_buf_len = lm_size / sizeof(float);
float first_token_logits_now_lm[lm_buf_len];
float logits_now_lm[lm_buf_len];
const int sm_size = 256 * 1024;
__shared__ char sm[sm_size];
int sm_max_buf_len = 256 * 1024 / sizeof(int);
sm_max_buf_len /= 5;
__shared_ptr__ int* sm_seq_lens_encoder = (__shared_ptr__ int*)sm;
__shared_ptr__ int* sm_seq_lens_this_time =
sm_seq_lens_encoder + sm_max_buf_len;
__shared_ptr__ int* sm_batch_token_num =
sm_seq_lens_this_time + sm_max_buf_len;
__shared_ptr__ int* sm_cu_batch_token_offset =
sm_batch_token_num + sm_max_buf_len;
__shared_ptr__ int* sm_cu_next_token_offset =
sm_cu_batch_token_offset + sm_max_buf_len;
prefix_sum(sm_seq_lens_encoder,
sm_seq_lens_this_time,
sm_batch_token_num,
sm_cu_batch_token_offset,
sm_cu_next_token_offset,
batch_token_num,
cu_batch_token_offset,
seq_lens_this_time,
seq_lens_encoder,
real_bsz);
for (int bid = clus_id; bid < real_bsz; bid += nclusters) {
auto* draft_logits_now =
draft_logits + sm_cu_batch_token_offset[bid] * vocab_size;
auto* logits_now = logits + sm_cu_next_token_offset[bid] * vocab_size;
auto* first_token_logits_now = first_token_logits + bid * vocab_size;
for (int i = cid * lm_buf_len; i < vocab_size; i += ncores * lm_buf_len) {
int read_len = min(lm_buf_len, vocab_size - i);
if (sm_seq_lens_encoder[bid] > 0) {
GM2LM_ASYNC(first_token_logits_now + i,
first_token_logits_now_lm,
read_len * sizeof(float));
GM2LM(logits_now + i, logits_now_lm, read_len * sizeof(float));
LM2GM_ASYNC(first_token_logits_now_lm,
draft_logits_now + i,
read_len * sizeof(float));
LM2GM(logits_now_lm,
draft_logits_now + vocab_size + i,
read_len * sizeof(float));
} else {
for (int j = 0; j < sm_seq_lens_this_time[bid]; j++) {
GM2LM(logits_now + j * vocab_size + i,
logits_now_lm,
read_len * sizeof(float));
LM2GM(logits_now_lm,
draft_logits_now + j * vocab_size + i,
read_len * sizeof(float));
}
}
}
}
}
} // namespace plugin
} // namespace xpu3
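A NumPy sketch (an illustrative helper, not part of this change) of the bookkeeping that prefix_sum performs above: prefill batches (seq_lens_encoder > 0) contribute 2 draft-logits rows and consume 1 row of logits, decode batches contribute seq_lens_this_time rows to both, and the cu_* arrays are the exclusive prefix sums of those counts.

import numpy as np

def token_offsets(seq_lens_this_time, seq_lens_encoder):
    is_prefill = seq_lens_encoder > 0
    # prefill batches produce 2 draft rows (first token + next token),
    # decode batches produce seq_lens_this_time draft rows
    batch_token_num = np.where(is_prefill, 2, seq_lens_this_time)
    # prefill batches consume 1 row of logits, decode batches consume seq_lens_this_time rows
    next_token_num = np.where(is_prefill, 1, seq_lens_this_time)
    # exclusive prefix sums give each batch its starting row offset
    cu_batch_token_offset = np.concatenate(([0], np.cumsum(batch_token_num)[:-1]))
    cu_next_token_offset = np.concatenate(([0], np.cumsum(next_token_num)[:-1]))
    return batch_token_num, next_token_num, cu_batch_token_offset, cu_next_token_offset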


@@ -0,0 +1,184 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void speculate_get_logits(
float* draft_logits,
int* next_token_num,
int* batch_token_num,
int* cu_next_token_offset,
int* cu_batch_token_offset,
const float* logits,
const float* first_token_logits,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int real_bsz,
const int vocab_size);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(float* draft_logits,
int* next_token_num,
int* batch_token_num,
int* cu_next_token_offset,
int* cu_batch_token_offset,
const float* logits,
const float* first_token_logits,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int real_bsz,
const int vocab_size) {
int batch_token_num_sum = 0;
int next_token_num_sum = 0;
for (int bid = 0; bid < real_bsz; bid++) {
// prefix sum
cu_batch_token_offset[bid] = batch_token_num_sum;
cu_next_token_offset[bid] = next_token_num_sum;
batch_token_num[bid] =
seq_lens_encoder[bid] > 0 ? 2 : seq_lens_this_time[bid];
next_token_num[bid] =
seq_lens_encoder[bid] > 0 ? 1 : seq_lens_this_time[bid];
batch_token_num_sum += batch_token_num[bid];
next_token_num_sum += next_token_num[bid];
auto* draft_logits_now =
draft_logits + cu_batch_token_offset[bid] * vocab_size;
auto* logits_now = logits + cu_next_token_offset[bid] * vocab_size;
auto* first_token_logits_now = first_token_logits + bid * vocab_size;
for (int i = 0; i < vocab_size; i++) {
if (seq_lens_encoder[bid] > 0) {
draft_logits_now[i] = first_token_logits_now[i];
draft_logits_now[vocab_size + i] = logits_now[i];
} else {
for (int j = 0; j < seq_lens_this_time[bid]; j++) {
draft_logits_now[j * vocab_size + i] = logits_now[j * vocab_size + i];
}
}
}
}
return api::SUCCESS;
}
static int xpu3_wrapper(Context* ctx,
float* draft_logits,
int* next_token_num,
int* batch_token_num,
int* cu_next_token_offset,
int* cu_batch_token_offset,
const float* logits,
const float* first_token_logits,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int real_bsz,
const int vocab_size) {
xpu3::plugin::speculate_get_logits<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
draft_logits,
next_token_num,
batch_token_num,
cu_next_token_offset,
cu_batch_token_offset,
logits,
first_token_logits,
seq_lens_this_time,
seq_lens_encoder,
real_bsz,
vocab_size);
return api::SUCCESS;
}
int speculate_get_logits(Context* ctx,
float* draft_logits,
int* next_token_num,
int* batch_token_num,
int* cu_next_token_offset,
int* cu_batch_token_offset,
const float* logits,
const float* first_token_logits,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int real_bsz,
const int vocab_size) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_logits", float);
WRAPPER_DUMP_PARAM6(ctx,
draft_logits,
next_token_num,
batch_token_num,
cu_next_token_offset,
cu_batch_token_offset,
logits);
WRAPPER_DUMP_PARAM5(ctx,
first_token_logits,
seq_lens_this_time,
seq_lens_encoder,
real_bsz,
vocab_size);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, next_token_num);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, batch_token_num);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_next_token_offset);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_batch_token_offset);
WRAPPER_CHECK_PTR(ctx, float, real_bsz* vocab_size, first_token_logits);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_encoder);
WRAPPER_ASSERT_LE(ctx, real_bsz, 256 * 1024 / sizeof(int) / 5);
WRAPPER_ASSERT_GT(ctx, vocab_size, 0);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(draft_logits,
next_token_num,
batch_token_num,
cu_next_token_offset,
cu_batch_token_offset,
logits,
first_token_logits,
seq_lens_this_time,
seq_lens_encoder,
real_bsz,
vocab_size);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
draft_logits,
next_token_num,
batch_token_num,
cu_next_token_offset,
cu_batch_token_offset,
logits,
first_token_logits,
seq_lens_this_time,
seq_lens_encoder,
real_bsz,
vocab_size);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
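A NumPy restatement (again illustrative only, reusing the offsets from the sketch above) of the copy pattern in cpu_wrapper: for a prefill batch, the first draft row comes from first_token_logits and the second from logits; a decode batch copies its seq_lens_this_time rows of logits straight through.

import numpy as np

def fill_draft_logits(draft_logits, logits, first_token_logits,
                      seq_lens_this_time, seq_lens_encoder,
                      cu_batch_token_offset, cu_next_token_offset):
    for bid in range(seq_lens_this_time.shape[0]):
        dst = cu_batch_token_offset[bid]
        src = cu_next_token_offset[bid]
        if seq_lens_encoder[bid] > 0:
            # prefill: row 0 comes from first_token_logits, row 1 from logits
            draft_logits[dst] = first_token_logits[bid]
            draft_logits[dst + 1] = logits[src]
        else:
            # decode: copy seq_lens_this_time rows straight through
            n = seq_lens_this_time[bid]
            draft_logits[dst:dst + n] = logits[src:src + n]
    return draft_logits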


@@ -0,0 +1,172 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import speculate_get_logits
# Fix the random seeds so the test is reproducible
np.random.seed(2023)
paddle.seed(2023)
def generate_test_data():
"""
    Helper that generates the test data.
    Converted from a pytest fixture into a plain function that the test method calls.
"""
real_bsz = 64
vocab_size = 2 * 1024
max_seq_len = 8 * 1024
    # Generate the raw test data (reusing the original logic unchanged)
seq_lens_encoder = np.random.randint(0, 2, [real_bsz], dtype=np.int32)
seq_lens_this_time = np.random.randint(1, max_seq_len, [real_bsz], dtype=np.int32)
draft_logits_seqlen = 0
logits_seqlen = 0
for i in range(real_bsz):
if seq_lens_encoder[i] > 0:
draft_logits_seqlen += 2
logits_seqlen += 1
else:
draft_logits_seqlen += seq_lens_this_time[i]
logits_seqlen += seq_lens_this_time[i]
draft_logits = np.zeros([draft_logits_seqlen, vocab_size], dtype=np.float32)
next_token_num = np.zeros([real_bsz], dtype=np.int32)
batch_token_num = np.zeros([real_bsz], dtype=np.int32)
cu_next_token_offset = np.zeros([real_bsz], dtype=np.int32)
cu_batch_token_offset = np.zeros([real_bsz], dtype=np.int32)
logits = np.random.rand(logits_seqlen, vocab_size).astype(np.float32)
first_token_logits = np.random.rand(real_bsz, vocab_size).astype(np.float32)
paddle.set_device("cpu")
    # Convert to paddle tensors (keeping the original logic)
data_cpu = {
"draft_logits": paddle.to_tensor(draft_logits),
"next_token_num": paddle.to_tensor(next_token_num),
"batch_token_num": paddle.to_tensor(batch_token_num),
"cu_next_token_offset": paddle.to_tensor(cu_next_token_offset),
"cu_batch_token_offset": paddle.to_tensor(cu_batch_token_offset),
"logits": paddle.to_tensor(logits),
"first_token_logits": paddle.to_tensor(first_token_logits),
"seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
"seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
}
paddle.set_device("xpu:0")
data_xpu = {
"draft_logits": paddle.to_tensor(draft_logits),
"next_token_num": paddle.to_tensor(next_token_num),
"batch_token_num": paddle.to_tensor(batch_token_num),
"cu_next_token_offset": paddle.to_tensor(cu_next_token_offset),
"cu_batch_token_offset": paddle.to_tensor(cu_batch_token_offset),
"logits": paddle.to_tensor(logits),
"first_token_logits": paddle.to_tensor(first_token_logits),
"seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
"seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
}
    # Restore the default device so other tests are not affected
paddle.set_device("cpu")
return data_cpu, data_xpu
def speculate_get_logits_execution(test_data):
"""测试函数的执行性和输出合理性"""
    # Run the target op (the core test step)
speculate_get_logits(**test_data)
return test_data
class TestSpeculateGetLogits(unittest.TestCase):
"""
    Test class based on unittest.TestCase.
    Every method whose name starts with 'test_' is treated as a test case.
"""
def assert_test_data_equal(self, test_data1, test_data2, rtol=1e-05, atol=1e-08, target_keys=None):
"""
        Custom assertion that compares the structure and data of two test_data dicts.
        In unittest, custom assertion helpers conventionally start with 'assert'.
"""
        # 1. First check that both test_data dicts have exactly the same keys
keys1 = set(test_data1.keys())
keys2 = set(test_data2.keys())
self.assertEqual(
keys1,
keys2,
msg=f"两个 test_data 字段不一致!\n仅在第一个中存在:{keys1 - keys2}\n仅在第二个中存在:{keys2 - keys1}",
)
        # 2. Compare the data field by field
if target_keys is not None and isinstance(target_keys, list):
local_target_key = target_keys
else:
local_target_key = keys1
for key in local_target_key:
data1 = test_data1[key]
data2 = test_data2[key]
            # Distinguish paddle Tensors (convert to numpy) from plain scalars/arrays (use directly)
if isinstance(data1, paddle.Tensor):
np1 = data1.detach().cpu().numpy()
else:
np1 = np.asarray(data1)
if isinstance(data2, paddle.Tensor):
np2 = data2.detach().cpu().numpy()
else:
np2 = np.asarray(data2)
            # 3. Check the data
            if np1.dtype in (np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8):
                # Bool/integer types: must match exactly
                np.testing.assert_array_equal(np1, np2, err_msg=f"Field {key} data mismatch!")
            else:
                # Floating-point types: allow error within rtol/atol
                np.testing.assert_allclose(np1, np2, rtol=rtol, atol=atol, err_msg=f"Field {key} floating-point data mismatch!")
print("✅ 两个 test_data 结构和数据完全一致!")
def test_speculate_get_logits(self):
"""
        The core test case.
        It calls generate_test_data to build the inputs,
        runs the op under test on CPU and on XPU,
        and compares the results with the custom assertion above.
"""
print("\nRunning test: test_speculate_get_logits")
        # 1. Build the test data
data_cpu, data_xpu = generate_test_data()
        # 2. Run the op under test
result_xpu = speculate_get_logits_execution(data_xpu)
result_cpu = speculate_get_logits_execution(data_cpu)
        # 3. Assert that the results match
target_keys = ["draft_logits", "batch_token_num", "cu_batch_token_offset"]
self.assert_test_data_equal(result_cpu, result_xpu, target_keys=target_keys)
if __name__ == "__main__":
    # Run all test cases via the unittest main program
unittest.main()