# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import eagle_get_hidden_states
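
# The helpers below are a pure-Python reference for the eagle_get_hidden_states
# GPU op: ComputeOrderKernel maps every kept input token (prefill tokens plus
# accepted draft tokens) to a compacted output position, and
# rebuildHiddenStatesKernel gathers the matching hidden-state rows.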


def ComputeOrderKernel(
    seq_lens_this_time,
    seq_lens_encoder,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    accept_nums,
    position_map,
    output_token_num,
    bsz,
    actual_draft_token_num,
    input_token_num,
):
    in_offset = 0
    out_offset = 0
    for i in range(bsz):
        cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i]
        # cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i]
        cur_seq_lens_this_time = seq_lens_this_time[i]
        accept_num = accept_nums[i]
        cur_seq_lens_encoder = seq_lens_encoder[i]
        # 1. EAGLE encoder (prefill): every prefill token is kept. Base step = 1.
        if cur_seq_lens_encoder > 0:
            for j in range(cur_seq_lens_encoder):
                position_map[in_offset] = out_offset
                in_offset += 1
                out_offset += 1
        # 2. Base model stopped at the last verify step: skip its tokens.
        elif cur_base_model_seq_lens_this_time != 0 and cur_seq_lens_this_time == 0:
            in_offset += cur_base_model_seq_lens_this_time
        # 3. Request already stopped: nothing to map.
        elif cur_base_model_seq_lens_this_time == 0 and cur_seq_lens_this_time == 0:
            pass
        # 4. Decode step: keep the accepted tokens, skip the rejected ones.
        else:
            for j in range(accept_num):
                position_map[in_offset] = out_offset
                in_offset += 1
                out_offset += 1
            in_offset += cur_base_model_seq_lens_this_time - accept_num
    output_token_num[0] = out_offset
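

# Worked example (illustrative lengths, not the random values used in the test
# below): for a batch where request 0 is a 3-token prefill and request 1 is a
# decode step in which the base model proposed 4 tokens and 2 were accepted,
# the kernel produces
#     position_map     = [0, 1, 2, 3, 4, -1, -1]   (-1 = token dropped)
#     output_token_num = 5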


def rebuildHiddenStatesKernel(input, position_map, out, dim_embed, elem_cnt):
    # Copy every element whose token has an output slot; unmapped tokens
    # (negative entries in position_map) are dropped.
    for elem_idx in range(elem_cnt):
        ori_token_idx = elem_idx // dim_embed
        token_idx = position_map[ori_token_idx]
        if token_idx >= 0:
            offset = elem_idx % dim_embed
            out[token_idx][offset] = input[ori_token_idx][offset]
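

# Intuition (a sketch, not executed by the test): with position_map = [0, -1, 1]
# the loop above copies input row 0 to out row 0 and input row 2 to out row 1,
# i.e. it is a row gather of `input` into the compacted `out` buffer.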


def eagle_get_hidden_states_ref(
    input,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    stop_flags,
    accept_nums,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    actual_draft_token_num,
):
    input_token_num = input.shape[0]
    dim_embed = input.shape[1]
    bsz = seq_lens_this_time.shape[0]
    # 0xFFFFFFFF is the "unmapped" fill value; rebuildHiddenStatesKernel only
    # copies rows whose mapped index is non-negative.
    position_map = paddle.full([input_token_num], 0xFFFFFFFF, seq_lens_this_time.dtype)
    output_token_num = paddle.empty([1], seq_lens_this_time.dtype)
    ComputeOrderKernel(
        seq_lens_this_time,
        seq_lens_encoder,
        base_model_seq_lens_this_time,
        base_model_seq_lens_encoder,
        accept_nums,
        position_map,
        output_token_num,
        bsz,
        actual_draft_token_num,
        input_token_num,
    )

    output_token_num_cpu = output_token_num[0]
    out = paddle.empty([output_token_num_cpu, dim_embed], input.dtype)
    elem_cnt = input_token_num * dim_embed
    rebuildHiddenStatesKernel(input, position_map, out, dim_embed, elem_cnt)
    return out


class TestEagleGetHiddenStates(unittest.TestCase):
    def test_eagle_get_hidden_states(self):
        np.random.seed(2023)
        paddle.seed(2023)
        bs = 2
        input_token_num = 10
        dim_embed = 512
        actual_draft_token_num = np.random.randint(2, 6, dtype=np.int32)

        seq_lens_this_time = np.random.randint(0, 2, bs, dtype=np.int32)
        seq_lens_encoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
        accept_nums = np.random.randint(0, actual_draft_token_num + 1, bs, dtype=np.int32)
        base_model_seq_lens_this_time = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
        base_model_seq_lens_encoder = np.random.randint(0, 2, bs, dtype=np.int32)

        seq_lens_decoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
        stop_flags = np.random.randint(0, 2, bs, dtype=np.int32)

        seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32)
        seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder, dtype=paddle.int32)
        accept_nums_tensor = paddle.to_tensor(accept_nums, dtype=paddle.int32)
        base_model_seq_lens_this_time_tensor = paddle.to_tensor(base_model_seq_lens_this_time, dtype=paddle.int32)
        base_model_seq_lens_encoder_tensor = paddle.to_tensor(base_model_seq_lens_encoder, dtype=paddle.int32)

        seq_lens_decoder_tensor = paddle.to_tensor(seq_lens_decoder, dtype=paddle.int32)
        stop_flags_tensor = paddle.to_tensor(stop_flags, dtype=paddle.int32)

        input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32)
        input_tensor = paddle.to_tensor(input, dtype=paddle.float16)
        out = eagle_get_hidden_states(
            input_tensor,
            seq_lens_this_time_tensor,
            seq_lens_encoder_tensor,
            seq_lens_decoder_tensor,
            stop_flags_tensor,
            accept_nums_tensor,
            base_model_seq_lens_this_time_tensor,
            base_model_seq_lens_encoder_tensor,
            actual_draft_token_num,
        )
        out_ref = eagle_get_hidden_states_ref(
            input_tensor,
            seq_lens_this_time_tensor,
            seq_lens_encoder_tensor,
            seq_lens_decoder_tensor,
            stop_flags_tensor,
            accept_nums_tensor,
            base_model_seq_lens_this_time_tensor,
            base_model_seq_lens_encoder_tensor,
            actual_draft_token_num,
        )
        np.testing.assert_allclose(out.numpy(), out_ref.numpy())
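
    def test_compute_order_kernel_reference_toy_case(self):
        # Extra sanity check of the pure-Python reference only (a sketch with
        # hand-picked lengths, not values produced by the GPU op): request 0 is
        # a decode step with 3 proposed tokens of which 1 is accepted, request 1
        # stopped at the last verify step, and request 2 has already stopped.
        seq_lens_this_time = [1, 0, 0]
        seq_lens_encoder = [0, 0, 0]
        base_model_seq_lens_this_time = [3, 2, 0]
        base_model_seq_lens_encoder = [0, 0, 0]
        accept_nums = [1, 0, 0]
        input_token_num = sum(base_model_seq_lens_this_time)
        position_map = [-1] * input_token_num
        output_token_num = [0]

        ComputeOrderKernel(
            seq_lens_this_time,
            seq_lens_encoder,
            base_model_seq_lens_this_time,
            base_model_seq_lens_encoder,
            accept_nums,
            position_map,
            output_token_num,
            bsz=3,
            actual_draft_token_num=2,
            input_token_num=input_token_num,
        )

        # Only the accepted token of request 0 survives, so exactly one
        # hidden-state row would be gathered.
        self.assertEqual(position_map, [0, -1, -1, -1, -1])
        self.assertEqual(output_token_num[0], 1)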


if __name__ == "__main__":
    unittest.main()