# Files
# FastDeploy/custom_ops/xpu_ops/test/test_draft_model_update.py
#
# 123 lines
# 4.4 KiB
# Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import draft_model_update
def run_paddle_test(device="cpu"):
    """Call the ``draft_model_update`` custom op once on ``device`` with seeded inputs.

    The NumPy and Paddle RNGs are both seeded with 42 before anything else
    happens, so the random input tensors are meant to be reproducible from
    one call to the next. The op's return value is discarded; the tensors
    that were passed in are returned instead, so the op is presumably
    expected to update them in place (confirm against the op implementation).

    Args:
        device: Paddle device string. Only ``"cpu"`` and ``"xpu"`` are accepted.

    Returns:
        A 9-tuple ``(draft_tokens, pre_ids, seq_lens_this_time,
        seq_lens_encoder, seq_lens_decoder, step_idx, stop_flags,
        not_need_stop, base_model_draft_tokens)`` of Paddle tensors as they
        are after the op call.

    Raises:
        ValueError: If ``device`` is neither ``"cpu"`` nor ``"xpu"``. The RNGs
            have already been seeded by the time this is raised.
    """
    np.random.seed(42)
    paddle.seed(42)

    if device not in ("cpu", "xpu"):
        raise ValueError(f"Invalid device: {device}")
    paddle.set_device(device)

    # Test dimensions.
    max_bsz = 128
    max_draft_token = 3
    pre_id_length = 3
    max_seq_len = 100
    max_base_model_draft_token = 4
    substep = 2

    per_request = (max_bsz,)

    # Random inputs. The creation order is kept fixed because every randint
    # call advances the seeded generator.
    inter_next_tokens = paddle.randint(low=1, high=100, shape=(max_bsz, max_seq_len), dtype="int64")
    draft_tokens = paddle.randint(low=1, high=100, shape=(max_bsz, max_draft_token), dtype="int64")
    pre_ids = paddle.randint(low=1, high=100, shape=(max_bsz, pre_id_length), dtype="int64")
    # The upper bound of randint is exclusive, so every entry here is 1.
    seq_lens_this_time = paddle.randint(low=1, high=2, shape=per_request, dtype="int32")
    seq_lens_encoder = paddle.randint(low=1, high=10, shape=per_request, dtype="int32")
    seq_lens_decoder = paddle.randint(low=1, high=10, shape=per_request, dtype="int32")
    step_idx = paddle.randint(low=1, high=10, shape=per_request, dtype="int64")
    output_cum_offsets = paddle.randint(low=0, high=2, shape=per_request, dtype="int32")
    output_cum_offsets[0] = 0  # make sure the first offset is 0
    stop_flags = paddle.zeros([max_bsz], dtype="bool")
    not_need_stop = paddle.zeros([1], dtype="bool")
    max_dec_len = paddle.randint(low=100, high=102, shape=per_request, dtype="int64")
    end_ids = paddle.to_tensor([2], dtype="int64")
    base_model_draft_tokens = paddle.randint(
        low=1, high=10, shape=(max_bsz, max_base_model_draft_token), dtype="int64"
    )

    # All arguments are positional; their order has to match the op signature.
    draft_model_update(
        inter_next_tokens,
        draft_tokens,
        pre_ids,
        seq_lens_this_time,
        seq_lens_encoder,
        seq_lens_decoder,
        step_idx,
        output_cum_offsets,
        stop_flags,
        not_need_stop,
        max_dec_len,
        end_ids,
        base_model_draft_tokens,
        max_seq_len,
        substep,
    )

    outputs = (
        draft_tokens,
        pre_ids,
        seq_lens_this_time,
        seq_lens_encoder,
        seq_lens_decoder,
        step_idx,
        stop_flags,
        not_need_stop,
        base_model_draft_tokens,
    )
    return outputs
if __name__ == "__main__":
    # Run the same seeded scenario on XPU first, then on CPU (the default
    # device), and require every returned tensor to agree.
    res_xpu = run_paddle_test("xpu")
    res_cpu = run_paddle_test()
    for idx, cpu_tensor in enumerate(res_cpu):
        # Bring both results to NumPy for the comparison.
        cpu_arr = cpu_tensor.numpy()
        xpu_arr = res_xpu[idx].numpy()
        if cpu_arr.dtype != bool:
            # Numeric outputs are compared with a small tolerance.
            assert np.allclose(
                cpu_arr, xpu_arr, rtol=1e-4, atol=1e-5
            ), f"数值结果在索引 {idx} 处不匹配,最大差异: {np.max(np.abs(cpu_arr - xpu_arr))}"
        else:
            # Boolean outputs must match exactly.
            assert np.array_equal(cpu_arr, xpu_arr), f"布尔结果在索引 {idx} 处不匹配"
        print(f"结果 {idx} 验证通过")