Files
FastDeploy/tests/entrypoints/cli/test_tokenizer_cli.py
2025-10-31 13:57:51 +08:00

337 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Test cases for tokenizer CLI
"""
import argparse
import json
import os
import tempfile
import unittest
from unittest.mock import MagicMock, PropertyMock, patch
from fastdeploy.entrypoints.cli.tokenizer import main
def model_path():
    """Return the location of the ERNIE-4.5-0.3B-Paddle checkpoint used by the CLI test.

    When the ``MODEL_PATH`` environment variable holds a non-empty value, the
    checkpoint directory is resolved beneath it. Otherwise the relative path
    ``./ERNIE-4.5-0.3B-Paddle`` (relative to the working directory) is returned.
    """
    models_root = os.getenv("MODEL_PATH")
    # Both an unset and an empty MODEL_PATH fall back to the working-directory path.
    if not models_root:
        return "./ERNIE-4.5-0.3B-Paddle"
    return os.path.join(models_root, "ERNIE-4.5-0.3B-Paddle")
class MockCLISubcommand:
    """Empty stand-in for the ``CLISubcommand`` base class; it defines no behaviour."""
class MockInputPreprocessor:
    """Stand-in for ``InputPreprocessor`` that hands out MagicMock-backed processors."""

    def __init__(self, model_name_or_path):
        # Stored to mirror the real constructor argument.
        self.model_name_or_path = model_name_or_path

    def create_processor(self):
        """Return a fresh MagicMock processor whose ``tokenizer`` is also a MagicMock."""
        return MagicMock(tokenizer=MagicMock())
# Import the code under test with the CLISubcommand base class and the
# InputPreprocessor dependency replaced by the local mocks.
# NOTE(review): ``fastdeploy.entrypoints.cli.tokenizer`` is already imported at the
# top of this file (for ``main``), so this import returns the cached module and the
# patches appear to have no effect on the names bound here. Confirm whether the
# isolation is actually needed, or drop the patches.
with patch("fastdeploy.entrypoints.cli.types.CLISubcommand", MockCLISubcommand), patch(
    "fastdeploy.input.preprocess.InputPreprocessor", MockInputPreprocessor
):
    from fastdeploy.entrypoints.cli.tokenizer import (
        TokenizerSubcommand,
        cmd_init,
        export_vocabulary,
        get_tokenizer_info,
        get_vocab_dict,
        get_vocab_size,
    )
class TestTokenizerSubcommand(unittest.TestCase):
    """Unit tests for the ``TokenizerSubcommand`` CLI entry."""

    def test_name_attribute(self):
        # The subcommand is registered under this literal name.
        self.assertEqual(TokenizerSubcommand.name, "tokenizer")

    def test_subparser_init(self):
        subparsers = MagicMock()
        created_parser = MagicMock()
        subparsers.add_parser.return_value = created_parser

        returned = TokenizerSubcommand().subparser_init(subparsers)

        # Exactly one parser is created, with the expected help/usage text.
        subparsers.add_parser.assert_called_once_with(
            name="tokenizer",
            help="Start the FastDeploy Tokenizer Server.",
            description="Start the FastDeploy Tokenizer Server.",
            usage="fastdeploy tokenizer [--encode/-e TEXT] [--decode/-d TEXT]",
        )
        self.assertEqual(returned, created_parser)
        # Only checks that some options were registered, not which ones.
        self.assertGreater(created_parser.add_argument.call_count, 0)

    def test_cmd_method(self):
        command = TokenizerSubcommand()
        namespace = argparse.Namespace()
        # cmd() is expected to forward its arguments unchanged to the module-level main().
        with patch("fastdeploy.entrypoints.cli.tokenizer.main") as patched_main:
            command.cmd(namespace)
            patched_main.assert_called_once_with(namespace)
class TestCmdInit(unittest.TestCase):
    """Tests for the subcommand list returned by ``cmd_init``."""

    def test_cmd_init_returns_list(self):
        subcommands = cmd_init()
        self.assertIsInstance(subcommands, list)
        # Exactly one entry is expected, and it must be the tokenizer subcommand.
        self.assertEqual(len(subcommands), 1)
        (tokenizer_cmd,) = subcommands
        self.assertIsInstance(tokenizer_cmd, TokenizerSubcommand)
class TestGetVocabSize(unittest.TestCase):
    """Tests for ``get_vocab_size``'s lookup order and fallbacks."""

    def test_with_vocab_size_attribute(self):
        tokenizer = MagicMock()
        # Attach vocab_size as a property on the mock's type so it behaves like a real attribute.
        type(tokenizer).vocab_size = PropertyMock(return_value=1000)
        self.assertEqual(get_vocab_size(tokenizer), 1000)

    def test_with_get_vocab_size_method(self):
        tokenizer = MagicMock()
        # Remove vocab_size so the lookup has to use get_vocab_size().
        del tokenizer.vocab_size
        tokenizer.get_vocab_size.return_value = 2000
        self.assertEqual(get_vocab_size(tokenizer), 2000)

    def test_with_no_methods_available(self):
        tokenizer = MagicMock()
        del tokenizer.vocab_size
        del tokenizer.get_vocab_size
        # With neither source present, the hard-coded fallback 100295 is expected.
        self.assertEqual(get_vocab_size(tokenizer), 100295)

    def test_exception_handling(self):
        tokenizer = MagicMock()
        # Both the vocab_size property and get_vocab_size() raise ...
        type(tokenizer).vocab_size = PropertyMock(side_effect=Exception("Error"))
        tokenizer.get_vocab_size.side_effect = Exception("Error")
        # ... in which case 0 is expected (not the 100295 fallback used when nothing exists).
        self.assertEqual(get_vocab_size(tokenizer), 0)
class TestGetTokenizerInfo(unittest.TestCase):
    """Tests for the summary dictionary built by ``get_tokenizer_info``."""

    def setUp(self):
        tokenizer = MagicMock()
        mock_type = type(tokenizer)
        scalar_attrs = {
            "vocab_size": 1000,
            "name_or_path": "test/model",
            "model_max_length": 512,
        }
        # Special token strings.
        special_tokens = {
            "bos_token": "<s>",
            "eos_token": "</s>",
            "unk_token": "<unk>",
            "sep_token": "<sep>",
            "pad_token": "<pad>",
            "cls_token": "<cls>",
            "mask_token": "<mask>",
        }
        # Matching special token ids.
        special_token_ids = {
            "bos_token_id": 1,
            "eos_token_id": 2,
            "unk_token_id": 3,
            "sep_token_id": 4,
            "pad_token_id": 0,
            "cls_token_id": 5,
            "mask_token_id": 6,
        }
        # Each value is exposed as a PropertyMock on the mock's type.
        for attr_name, value in {**scalar_attrs, **special_tokens, **special_token_ids}.items():
            setattr(mock_type, attr_name, PropertyMock(return_value=value))
        self.mock_tokenizer = tokenizer

    def test_normal_case(self):
        info = get_tokenizer_info(self.mock_tokenizer)
        expected_top_level = {
            "vocab_size": 1000,
            "model_name": "test/model",
            "tokenizer_type": "MagicMock",
            "model_max_length": 512,
        }
        for key, expected in expected_top_level.items():
            self.assertEqual(info[key], expected)
        # Spot-check one special token and its id.
        self.assertEqual(info["special_tokens"]["bos_token"], "<s>")
        self.assertEqual(info["special_token_ids"]["bos_token_id"], 1)

    def test_exception_handling(self):
        # Make get_vocab_size raise; the error text is expected under the "error" key.
        failing_lookup = patch(
            "fastdeploy.entrypoints.cli.tokenizer.get_vocab_size",
            side_effect=Exception("Test error"),
        )
        with failing_lookup:
            info = get_tokenizer_info(self.mock_tokenizer)
            self.assertIn("error", info)
            self.assertIn("Test error", info["error"])
class TestGetVocabDict(unittest.TestCase):
    """Tests for the chain of vocabulary sources consulted by ``get_vocab_dict``."""

    @staticmethod
    def _tokenizer_without(*attrs):
        """Build a MagicMock tokenizer with the named attributes deleted."""
        tokenizer = MagicMock()
        for attr in attrs:
            delattr(tokenizer, attr)
        return tokenizer

    def test_vocab_attribute(self):
        tokenizer = MagicMock(vocab={"hello": 1, "world": 2})
        self.assertEqual(get_vocab_dict(tokenizer), {"hello": 1, "world": 2})

    def test_get_vocab_method(self):
        # Without a vocab attribute, get_vocab() is used.
        tokenizer = self._tokenizer_without("vocab")
        tokenizer.get_vocab.return_value = {"a": 1, "b": 2}
        self.assertEqual(get_vocab_dict(tokenizer), {"a": 1, "b": 2})

    def test_tokenizer_vocab(self):
        # With vocab and get_vocab removed, the wrapped inner tokenizer's vocab is used.
        tokenizer = self._tokenizer_without("vocab", "get_vocab")
        tokenizer.tokenizer = MagicMock(vocab={"x": 1})
        self.assertEqual(get_vocab_dict(tokenizer), {"x": 1})

    def test_encoder_attribute(self):
        # With the previous sources removed, an ``encoder`` mapping is used.
        tokenizer = self._tokenizer_without("vocab", "get_vocab", "tokenizer")
        tokenizer.encoder = {"token": 0}
        self.assertEqual(get_vocab_dict(tokenizer), {"token": 0})

    def test_no_vocab_available(self):
        # With no source at all, an empty dict is expected.
        tokenizer = self._tokenizer_without("vocab", "get_vocab", "tokenizer", "encoder")
        self.assertEqual(get_vocab_dict(tokenizer), {})

    def test_exception_handling(self):
        # NOTE(review): despite the name, only get_vocab() raises, and a vocab attribute
        # is also present. This presumably pins that the vocab attribute wins, rather than
        # exercising recovery from an exception. Confirm the intended scenario.
        tokenizer = MagicMock()
        tokenizer.vocab = {"a": 1}
        tokenizer.get_vocab.side_effect = Exception("Error")
        self.assertEqual(get_vocab_dict(tokenizer), {"a": 1})
class TestExportVocabulary(unittest.TestCase):
    """Tests for the file output and console messages of ``export_vocabulary``."""

    def setUp(self):
        self.mock_tokenizer = MagicMock()
        self.mock_tokenizer.vocab = {"hello": 1, "world": 2, "test": 3}

    @staticmethod
    def _run_export(tokenizer, file_path):
        """Call export_vocabulary with ``print`` patched out and return the print mock."""
        with patch("builtins.print") as print_mock:
            export_vocabulary(tokenizer, file_path)
        return print_mock

    def test_export_json_format(self):
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "vocab.json")
            print_mock = self._run_export(self.mock_tokenizer, target)
            # A .json target is expected to hold the token -> id mapping.
            with open(target, "r", encoding="utf-8") as handle:
                self.assertEqual(json.load(handle), {"hello": 1, "world": 2, "test": 3})
            print_mock.assert_any_call(f"Vocabulary exported to: {target}")
            print_mock.assert_any_call("Total tokens: 3")

    def test_export_text_format(self):
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "vocab.txt")
            self._run_export(self.mock_tokenizer, target)
            with open(target, "r", encoding="utf-8") as handle:
                rows = handle.readlines()
            # One "<id>\t<repr(token)>" row per token, ordered by id; repr() adds the quotes.
            self.assertEqual(len(rows), 3)
            for row, expected in zip(rows, ("1\t'hello'", "2\t'world'", "3\t'test'")):
                self.assertIn(expected, row)

    def test_empty_vocabulary(self):
        empty_tokenizer = MagicMock()
        empty_tokenizer.vocab = {}
        with tempfile.TemporaryDirectory() as workdir:
            print_mock = self._run_export(empty_tokenizer, os.path.join(workdir, "vocab.json"))
            print_mock.assert_any_call("Warning: Could not retrieve vocabulary from tokenizer")

    def test_directory_creation(self):
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "newdir", "vocab.json")
            self._run_export(self.mock_tokenizer, target)
            # The missing parent directory is expected to be created.
            self.assertTrue(os.path.exists(os.path.dirname(target)))

    def test_exception_handling(self):
        # Simulate an unwritable destination by making directory creation fail.
        with patch("pathlib.Path.mkdir", side_effect=Exception("Permission denied")):
            print_mock = self._run_export(self.mock_tokenizer, "/invalid/path/vocab.json")
            print_mock.assert_any_call("Error exporting vocabulary: Permission denied")
class TestTokenizerCli(unittest.TestCase):
    """End-to-end run of the tokenizer CLI ``main`` against a real checkpoint.

    NOTE(review): the checkpoint comes from ``model_path()`` (``MODEL_PATH`` or
    ``./ERNIE-4.5-0.3B-Paddle``). The test presumably errors, rather than being
    skipped, when that checkpoint is absent. Confirm this is intended for CI.
    """

    def setUp(self):
        checkpoint = model_path()
        self.tmpdir = tempfile.TemporaryDirectory()
        # Populate every CLI option at once: encode, decode, vocab_size, info, vocab_export.
        self.test_args = argparse.Namespace(
            model_name_or_path=checkpoint,
            encode="Hello, world!",
            decode="[1, 2, 3]",
            vocab_size=True,
            info=True,
            vocab_export=os.path.join(self.tmpdir.name, "vocab.txt"),
        )

    def tearDown(self):
        self.tmpdir.cleanup()

    def test_main(self):
        main(self.test_args)
        # The only hard assertion: the vocabulary export file was produced.
        self.assertTrue(os.path.exists(self.test_args.vocab_export))
if __name__ == "__main__":
    # Running this file directly executes the suite with per-test (verbosity=2) output.
    unittest.main(verbosity=2)