mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* support mm prefill batch * update code * update code * update code * update code * fix encoder cache bug * update code * update code * fix bug * fix paddle ocr bug * fix xpu bug * update code
54 lines
2.2 KiB
Python
54 lines
2.2 KiB
Python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import unittest
|
|
|
|
from fastdeploy.worker.output import LogprobsTensors
|
|
|
|
|
|
class TestLogprobsOutput(unittest.TestCase):
    """Unit tests for the shape contracts and slicing helpers of LogprobsTensors."""

    def test_logprobs_output(self):
        """Verify empty()/empty_cpu() construction shapes, row/column slicing, and tolists().

        Uses unittest assertion methods rather than bare ``assert`` statements so
        the checks survive ``python -O`` and produce informative failure messages.
        """
        num_positions = 3
        num_tokens_per_position = 4
        shape = [num_positions, num_tokens_per_position]

        # empty(): per-position token-id and logprob tensors share the full
        # [num_positions, num_tokens_per_position] shape; rank tensor is 1-D.
        logprobs_tensors = LogprobsTensors.empty(num_positions, num_tokens_per_position)
        self.assertEqual(logprobs_tensors.logprob_token_ids.shape, shape)
        self.assertEqual(logprobs_tensors.logprobs.shape, shape)
        self.assertEqual(logprobs_tensors.selected_token_ranks.shape, [num_positions])

        # slice_rows(1, 2) keeps a single position row; the token dimension
        # is unchanged.
        sliced_logprobs_tensors = logprobs_tensors.slice_rows(1, 2)
        self.assertEqual(sliced_logprobs_tensors.logprob_token_ids.shape, [1, num_tokens_per_position])
        self.assertEqual(sliced_logprobs_tensors.logprobs.shape, [1, num_tokens_per_position])
        self.assertEqual(sliced_logprobs_tensors.selected_token_ranks.shape, [1])

        # empty_cpu(): CPU-side construction must mirror the shapes of empty().
        logprobs_tensors_cpu = LogprobsTensors.empty_cpu(num_positions, num_tokens_per_position)
        self.assertEqual(logprobs_tensors_cpu.logprob_token_ids.shape, shape)
        self.assertEqual(logprobs_tensors_cpu.logprobs.shape, shape)
        self.assertEqual(logprobs_tensors_cpu.selected_token_ranks.shape, [num_positions])

        # tolists(): converts tensors to nested Python lists, one entry per
        # position.
        logprobs_list = logprobs_tensors_cpu.tolists()
        self.assertIsInstance(logprobs_list.logprobs, list)
        self.assertEqual(len(logprobs_list.logprobs), num_positions)

        # Row slicing on the list form keeps one position.
        row_sliced_logprobs_list = logprobs_list.slice_rows(1, 2)
        self.assertEqual(len(row_sliced_logprobs_list.logprobs), 1)

        # Column slicing keeps every position (it narrows the per-position
        # token dimension, not the outer list).
        col_sliced_logprobs_list = logprobs_list.slice_columns(1, 2)
        self.assertEqual(len(col_sliced_logprobs_list.logprobs), num_positions)
|
|
# Allow this test module to be executed directly (e.g. `python this_file.py`),
# in which case unittest discovers and runs the TestCase above.
if __name__ == "__main__":
    unittest.main()