Files
FastDeploy/tests/benchmarks/test_datasets_benchmarks.py
Echo-Nie 1b1bfab341 [CI] Add unittest (#5328)
* add test_worker_eplb

* remove tesnsor_wise_fp8

* add copyright
2025-12-09 19:19:42 +08:00

166 lines
5.3 KiB
Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import json
from argparse import ArgumentParser, Namespace
import pytest
from PIL import Image
import fastdeploy.benchmarks.datasets as bd
class DummyTokenizer:
    """Minimal tokenizer stand-in returning fixed, predictable values."""

    # Vocabulary size exposed to code that inspects the tokenizer.
    vocab_size = 100

    def num_special_tokens_to_add(self):
        """Return the constant number of special tokens (always 1)."""
        return 1

    def decode(self, ids):
        """Return the fixed string ``"dummy_text"``; ``ids`` is ignored."""
        return "dummy_text"

    def encode(self, text, add_special_tokens=False):
        """Return ``[0, 1, ..., len(text) - 1]``; ``add_special_tokens`` is ignored."""
        token_count = len(text)
        return [*range(token_count)]
def make_temp_json(tmp_path, content, filename="data.json"):
    """Write ``content`` as JSON Lines into ``tmp_path`` and return the file path.

    Each item of ``content`` is serialized with ``json.dumps`` onto its own
    line, so the file is line-delimited JSON despite the default ``.json``
    file name.

    Args:
        tmp_path: Directory to write into; must support the ``/`` operator
            (e.g. the ``pathlib.Path`` provided by pytest's ``tmp_path``).
        content: Iterable of JSON-serializable records, one per output line.
        filename: Name of the file created inside ``tmp_path``. Defaults to
            ``"data.json"``; pass a different name to keep several fixture
            files side by side instead of overwriting one another.

    Returns:
        The path of the written file as a ``str``.
    """
    fpath = tmp_path / filename
    with open(fpath, "w", encoding="utf-8") as f:
        # One batched call instead of a separate write per record.
        f.writelines(json.dumps(record) + "\n" for record in content)
    return str(fpath)
def test_is_valid_sequence_variants():
    """Check ``bd.is_valid_sequence`` against accepted and rejected length pairs."""
    # Each entry: (positional lengths, keyword options, expected truthiness).
    scenarios = (
        ((10, 10), {}, True),
        ((1, 10), {}, False),  # prompt too short
        ((10, 1), {}, False),  # output too short
        ((2000, 10), {"max_prompt_len": 100}, False),
        ((2000, 100), {"max_total_len": 200}, False),
        # skip min output len
        ((10, 1), {"skip_min_output_len_check": True}, True),
    )
    for lengths, options, expected in scenarios:
        verdict = bd.is_valid_sequence(*lengths, **options)
        assert bool(verdict) is expected, (lengths, options)
def test_process_image_with_pil_and_str(tmp_path):
    """Feed ``bd.process_image`` a bytes dict, a PIL image, strings and a bad type."""
    red_square = Image.new("RGB", (10, 10), color="red")

    # Dict input: the image serialized to PNG and passed under the "bytes" key.
    png_buffer = io.BytesIO()
    red_square.save(png_buffer, format="PNG")
    from_bytes = bd.process_image({"bytes": png_buffer.getvalue()})
    assert "image_url" in from_bytes

    # PIL image input: expected to come back as a base64 JPEG data URL.
    from_pil = bd.process_image(red_square)
    assert from_pil["type"] == "image_url"
    assert from_pil["image_url"]["url"].startswith("data:image/jpeg;base64,")

    # String input: a bare path and an http URL, each with its expected URL prefix.
    string_cases = (
        ("path/to/file", "file://"),
        ("http://abc.com/img.png", "http://"),
    )
    for raw_value, url_prefix in string_cases:
        from_str = bd.process_image(raw_value)
        assert from_str["image_url"]["url"].startswith(url_prefix)

    # Any other input type (here an int) must be rejected.
    with pytest.raises(ValueError):
        bd.process_image(123)
def test_maybe_oversample_requests(caplog):
    """Verify ``maybe_oversample_requests`` grows a short request list in place.

    The ``caplog`` fixture is requested but not inspected by this test.
    """
    wanted = 3
    seed_request = bd.SampleRequest(1, "a", [], None, 10, 20)
    pool = [seed_request]

    # The return value is ignored on purpose: only the mutated list is checked.
    bd.RandomDataset().maybe_oversample_requests(pool, wanted)

    assert len(pool) >= wanted
def test_EBDataset_and_EBChatDataset(tmp_path):
    """Sample from ``bd.EBDataset`` and ``bd.EBChatDataset`` built from temp files."""

    def _assert_requests_with_payload(batch):
        # Every sample must be a SampleRequest and must carry json_data.
        assert all(isinstance(item, bd.SampleRequest) for item in batch)
        assert all(item.json_data is not None for item in batch)

    # --- EBDataset: a single text record with its sampling parameters ---
    eb_record = {
        "text": "hello",
        "temperature": 0.7,
        "penalty_score": 1.0,
        "frequency_score": 1.0,
        "presence_score": 1.0,
        "topp": 0.9,
        "input_token_num": 5,
        "max_dec_len": 10,
    }
    eb_path = make_temp_json(tmp_path, [eb_record])
    eb_dataset = bd.EBDataset(dataset_path=eb_path, shuffle=True)
    _assert_requests_with_payload(eb_dataset.sample(2))

    # --- EBChatDataset: a single one-turn conversation ---
    # The helper is called with the same tmp_path as above; the EB samples
    # have already been drawn by the time this file is written.
    chat_record = {"messages": [{"role": "user", "content": "hi"}], "max_tokens": 20}
    chat_path = make_temp_json(tmp_path, [chat_record])
    chat_dataset = bd.EBChatDataset(dataset_path=chat_path, shuffle=True)
    _assert_requests_with_payload(chat_dataset.sample(2, enable_multimodal_chat=False))
def test_RandomDataset_sample():
    """Sample from ``bd.RandomDataset`` with a stub tokenizer and check validation."""
    tokenizer = DummyTokenizer()
    random_ds = bd.RandomDataset(random_seed=123)

    drawn = random_ds.sample(tokenizer, 2, prefix_len=2, range_ratio=0.1)
    assert len(drawn) == 2
    assert all(isinstance(item, bd.SampleRequest) for item in drawn)

    # range_ratio >= 1 should raise
    with pytest.raises(AssertionError):
        random_ds.sample(tokenizer, 1, range_ratio=1.0)
def test__ValidateDatasetArgs_and_get_samples(tmp_path):
    """Cover ``bd._ValidateDatasetArgs`` rejection and ``bd.get_samples`` dispatch."""
    # Invalid combination: the default "random" dataset plus an explicit
    # --dataset-path is expected to make argparse exit.
    cli = ArgumentParser()
    cli.add_argument("--dataset-name", default="random")
    cli.add_argument("--dataset-path", action=bd._ValidateDatasetArgs)
    with pytest.raises(SystemExit):
        cli.parse_args(["--dataset-path", "abc.json"])

    # get_samples with the EBChat dataset, fed a three-message conversation.
    conversation = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hi there"},
        {"role": "user", "content": "how are you?"},
    ]
    chat_path = make_temp_json(tmp_path, [{"messages": conversation, "max_tokens": 10}])
    settings = {
        "dataset_name": "EBChat",
        "dataset_path": chat_path,
        "seed": 0,
        "shuffle": False,
        "num_prompts": 1,
        "sharegpt_output_len": 10,
    }
    args = Namespace(**settings)
    result = bd.get_samples(args)
    assert isinstance(result, list)

    # An unrecognized dataset name must be rejected.
    args.dataset_name = "unknown"
    with pytest.raises(ValueError):
        bd.get_samples(args)
def test_add_dataset_parser():
    """Check that ``bd.add_dataset_parser`` registers the expected options."""
    dataset_parser = bd.FlexibleArgumentParser()
    bd.add_dataset_parser(dataset_parser)

    # Parsing an empty command line must still yield these attributes.
    parsed = dataset_parser.parse_args([])
    for option in ("seed", "num_prompts"):
        assert hasattr(parsed, option)