FastDeploy/tests/input/test_paddleocr_vl_processor.py
Haonan Luo 2c281e617c
Update Unit Test for PaddleOCR-VL (#4802)
* fix paddleocr prefix cache bug

* add test for paddleocr_vl

* disable prefix-caching in ocr

* Fix top_p for rejection sampling

* add test for ocr processor; fix top_p for rejection sampling

---------

Co-authored-by: ming1753 <ideaminghp@163.com>
Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com>
2025-11-04 22:40:15 +08:00

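"""Unit tests for the PaddleOCR-VL input pipeline in FastDeploy.

The suites below run against mocked tokenizer, image-processor, and zmq
dependencies and cover:

* process_video.sample_frames: frame-count/fps sampling and its error paths.
* process.DataProcessor: text/image/video tokenization, position ids,
  multimodal hashing, and the optional processor cache.
* image_processor.ImageProcessor and smart_resize.
* paddleocr_vl_processor.PaddleOCRVLProcessor: request packing, per-prompt
  limits, truncation, stop sequences, and top_p clamping.
"""
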
import pickle
import unittest
from unittest.mock import ANY, MagicMock, patch

import numpy as np
import zmq
from PIL import Image

from fastdeploy.input.paddleocr_vl_processor.image_processor import (
    ImageProcessor,
    smart_resize,
)
from fastdeploy.input.paddleocr_vl_processor.paddleocr_vl_processor import (
    PaddleOCRVLProcessor,
)
from fastdeploy.input.paddleocr_vl_processor.process import DataProcessor
from fastdeploy.input.paddleocr_vl_processor.process_video import sample_frames

MODULE_PATH = "fastdeploy.input.paddleocr_vl_processor.process"


class TestProcessVideo(unittest.TestCase):
    def setUp(self):
        self.metadata = {"num_of_frame": 100, "fps": 25}
        self.frame_factor = 4
        self.min_frames = 8
        self.max_frames = 32

    def test_sample_with_num_frames(self):
        """Test sampling with the num_frames argument (from the original user test)."""
        num_frames = 16
        indices = sample_frames(
            frame_factor=self.frame_factor,
            min_frames=self.min_frames,
            max_frames=self.max_frames,
            num_frames=num_frames,
            fps=0,  # make sure fps is not > 0
            metadata=self.metadata,
        )
        self.assertEqual(len(indices), 16)
        self.assertEqual(indices[0], 0)
        self.assertEqual(indices[-1], 93)
        np.testing.assert_array_equal(indices, np.arange(0, 100, 100 / 16).astype(np.int32))

    def test_error_num_frames_exceeds_total(self):
        """Test that num_frames exceeding the total frame count raises (from the original user test)."""
        with self.assertRaises(ValueError) as context:
            sample_frames(
                frame_factor=self.frame_factor,
                min_frames=self.min_frames,
                max_frames=self.max_frames,
                num_frames=200,  # exceeds the total of 100 frames
                fps=0,
                metadata=self.metadata,
            )
        self.assertIn("exceeds", str(context.exception))

    def test_error_mutual_exclusion(self):
        """New: test that num_frames and fps are mutually exclusive."""
        with self.assertRaises(ValueError) as context:
            sample_frames(
                frame_factor=self.frame_factor,
                min_frames=self.min_frames,
                max_frames=self.max_frames,
                num_frames=16,  # > 0
                fps=10,  # > 0
                metadata=self.metadata,
            )
        self.assertIn("mutually exclusive", str(context.exception))

    def test_error_fps_without_metadata(self):
        """New: test fps > 0 with metadata=None."""
        with self.assertRaises(TypeError) as context:
            sample_frames(
                frame_factor=self.frame_factor,
                min_frames=self.min_frames,
                max_frames=self.max_frames,
                num_frames=0,
                fps=10,
                metadata=None,  # missing
            )
        # Verify it is the expected TypeError
        self.assertIn("'NoneType' object is not subscriptable", str(context.exception))

    def test_num_frames_rounding(self):
        """New: test that num_frames is rounded to a multiple of frame_factor."""
        num_frames = 17  # not a multiple of 4
        # Logic: round(17 / 4) * 4 = round(4.25) * 4 = 4 * 4 = 16
        indices = sample_frames(
            frame_factor=self.frame_factor,
            min_frames=self.min_frames,
            max_frames=self.max_frames,
            num_frames=num_frames,
            fps=0,
            metadata=self.metadata,
        )
        # Should be rounded to 16
        self.assertEqual(len(indices), 16)

    def test_sample_with_fps_basic(self):
        """New: test sampling by fps (basic path, capped by max_frames)."""
        # Logic: num_frames_calc = 100 / 25 * 10 = 40
        # num_frames_clamped = min(max(40, 8), 32) = 32
        # num_frames_factored = floor(32 / 4) * 4 = 32
        indices = sample_frames(
            frame_factor=self.frame_factor,
            min_frames=self.min_frames,
            max_frames=self.max_frames,
            num_frames=0,
            fps=10,
            metadata=self.metadata,
        )
        # Should be capped by max_frames=32
        self.assertEqual(len(indices), 32)
        self.assertEqual(indices[-1], 96)

    def test_sample_with_fps_hits_min_frames(self):
        """New: test sampling by fps (floored by min_frames)."""
        # Logic: num_frames_calc = 100 / 25 * 1 = 4
        # num_frames_clamped = min(max(4, 8), 32) = 8
        # num_frames_factored = floor(8 / 4) * 4 = 8
        indices = sample_frames(
            frame_factor=self.frame_factor,
            min_frames=self.min_frames,
            max_frames=self.max_frames,
            num_frames=0,
            fps=1,
            metadata=self.metadata,
        )
        # Should be floored by min_frames=8
        self.assertEqual(len(indices), 8)
        self.assertEqual(indices[-1], 87)

    def test_sample_with_fps_hits_total_frames(self):
        """New: test sampling by fps (capped by total_num_frames)."""
        local_max_frames = 200
        # Logic: num_frames_calc = 100 / 25 * 50 = 200
        # num_frames_clamped = min(min(max(200, 8), 200), 100) = 100
        # num_frames_factored = floor(100 / 4) * 4 = 100
        indices = sample_frames(
            frame_factor=self.frame_factor,
            min_frames=self.min_frames,
            max_frames=local_max_frames,
            num_frames=0,
            fps=50,
            metadata=self.metadata,
        )
        # Should be capped by total_num_frames=100
        self.assertEqual(len(indices), 100)
        self.assertEqual(indices[-1], 99)  # every frame is sampled

    def test_no_sampling(self):
        """New: test no sampling (fps=0, num_frames=0)."""
        indices = sample_frames(
            frame_factor=self.frame_factor,
            min_frames=self.min_frames,
            max_frames=self.max_frames,
            num_frames=0,
            fps=0,
            metadata=self.metadata,
        )
        # All frames should be returned
        self.assertEqual(len(indices), self.metadata["num_of_frame"])
        self.assertEqual(len(indices), 100)
        self.assertEqual(indices[-1], 99)
        np.testing.assert_array_equal(indices, np.arange(0, 100).astype(np.int32))


class Test_DataProcessor(unittest.TestCase):
    """
    Unit tests for the DataProcessor class in process.py.
    """

    def setUp(self):
        # 1. Start the patchers manually
        patcher1 = patch(f"{MODULE_PATH}.AutoTokenizer.from_pretrained")
        patcher2 = patch(f"{MODULE_PATH}.ImageProcessor.from_pretrained")
        patcher_zmq_context = patch(f"{MODULE_PATH}.zmq.Context")
        self.mock_auto_tokenizer_constructor = patcher1.start()
        self.mock_image_processor_constructor = patcher2.start()
        self.mock_zmq_context_constructor = patcher_zmq_context.start()
        self.addCleanup(patcher1.stop)
        self.addCleanup(patcher2.stop)
        self.addCleanup(patcher_zmq_context.stop)
        # 2. Create the mock objects
        self.mock_tokenizer = MagicMock()
        self.mock_image_processor = MagicMock()
        self.mock_zmq_context = MagicMock()
        self.mock_zmq_socket = MagicMock()
        # 3. Configure from_pretrained and zmq
        self.mock_auto_tokenizer_constructor.return_value = self.mock_tokenizer
        self.mock_image_processor_constructor.return_value = self.mock_image_processor
        self.mock_zmq_context_constructor.return_value = self.mock_zmq_context
        self.mock_zmq_context.socket.return_value = self.mock_zmq_socket
        # 4. Configure the mock attributes and methods
        self._configure_mocks()
        # 5. Instantiate DataProcessor (processor cache disabled by default)
        self.processor = DataProcessor(model_path="dummy_model_path")
        self._configure_processor_ids()
        # 6. Prepare dummy test data
        self.dummy_image = Image.fromarray(np.uint8(np.random.rand(224, 224, 3) * 255))
        self.dummy_video_frames = np.uint8(np.random.rand(16, 224, 224, 3) * 255)
        self.dummy_video_data = "path/to/dummy_video.mp4"
        self.dummy_processed_image_cache = (
            np.random.rand(64, 3, 14, 14).astype(np.float32),
            {"thw": (1, 8, 8), "fps": 0},
        )
        self.dummy_processed_video_cache = (
            np.random.rand(256, 3, 14, 14).astype(np.float32),
            {"thw": (4, 8, 8), "fps": 30},
        )

    def _configure_mocks(self):
        def mock_convert_tokens_to_ids(tokens):
            if tokens == "<|IMAGE_PLACEHOLDER|>":
                return 100
            if tokens == "<|video_pad|>":
                return 101
            if tokens == "<|IMAGE_START|>":
                return 102
            if isinstance(tokens, list):
                if tokens == ["Hello", "world"]:
                    return [983, 984]
                if tokens == ["Prompt", "text"]:
                    return [606, 511]
                if tokens == ["Prompt", "", "text"]:
                    return [606, 511]  # simulates "Prompt text".split()
                return [hash(t) % 1000 for t in tokens]
            return hash(tokens) % 1000

        self.mock_tokenizer.convert_tokens_to_ids.side_effect = mock_convert_tokens_to_ids
        self.mock_tokenizer.tokenize.side_effect = lambda s: s.split()
        self.mock_tokenizer.ignored_index = -100
        self.mock_tokenizer.chat_template = "dummy_template_string"
        self.mock_image_processor.merge_size = 2
        self.mock_image_processor.temporal_patch_size = 1

    def _configure_processor_ids(self):
        self.processor.image_token_id = 100
        self.processor.video_token_id = 101
        self.processor.image_patch_id = 100
        self.processor.vision_start_id = 102

    def _get_init_outputs(self):
        return {
            "input_ids": [],
            "token_type_ids": [],
            "position_ids": [],
            "images": [],
            "grid_thw": [],
            "image_type_ids": [],
            "labels": [],
            "cur_position": 0,
            "video_cnt": 0,
            "num_input_image_tokens": 0,
            "num_input_video_tokens": 0,
            "fps": [],
            "mm_positions": [],
            "mm_hashes": [],
            "vit_seqlen": [],
            "vit_position_ids": [],
        }

    def test_init(self):
        """Test DataProcessor initialization."""
        self.mock_auto_tokenizer_constructor.assert_called_with("dummy_model_path", padding_side="left", use_fast=True)
        self.mock_image_processor_constructor.assert_called_with("dummy_model_path")
        self.assertEqual(self.processor.image_token, "<|IMAGE_PLACEHOLDER|>")
        self.assertEqual(self.processor.video_token_id, 101)

    def test_compute_text_positions(self):
        """Test the _compute_text_positions pure function."""
        pos_ids = self.processor._compute_text_positions(start_pos=5, num_tokens=3)
        expected = np.array([[5, 6, 7], [5, 6, 7], [5, 6, 7]])
        np.testing.assert_array_equal(pos_ids, expected)

    def test_compute_vision_positions(self):
        """Test the _compute_vision_positions pure function."""
        pos_ids = self.processor._compute_vision_positions(start_pos=10, t=2, h=4, w=4, second_per_grid_t=1.0)
        self.assertEqual(pos_ids.shape, (3, 8))
        expected_t = np.array([0, 0, 0, 0, 2, 2, 2, 2])
        expected_h = np.array([0, 0, 1, 1, 0, 0, 1, 1])
        expected_w = np.array([0, 1, 0, 1, 0, 1, 0, 1])
        expected = np.stack([expected_t, expected_h, expected_w]) + 10
        np.testing.assert_array_equal(pos_ids, expected)

    @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0, "image": 1, "video": 2})
    def test_add_text(self):
        """Test the _add_text helper."""
        outputs = self._get_init_outputs()
        self.mock_tokenizer.tokenize.return_value = ["Hello", "world"]
        self.mock_tokenizer.convert_tokens_to_ids.side_effect = None
        self.mock_tokenizer.convert_tokens_to_ids.return_value = [10, 11]
        self.processor._add_text("Hello world", outputs)
        self.assertEqual(outputs["input_ids"], [10, 11])
        self.assertEqual(outputs["token_type_ids"], [0, 0])
        self.assertEqual(outputs["cur_position"], 2)

    @patch(f"{MODULE_PATH}.MultimodalHasher.hash_features", return_value="dummy_hash_123")
    @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0, "image": 1, "video": 2})
    def test_add_image_autohash(self, mock_hasher):
        """Test the _add_image helper (automatic hashing)."""
        outputs = self._get_init_outputs()
        outputs["cur_position"] = 5
        num_patches_hw = 8 * 8
        num_tokens = 16
        mock_preprocess_return = {
            "pixel_values": np.random.rand(num_patches_hw, 3, 14, 14),
            "grid_thw": np.array([1, 8, 8]),
        }
        self.mock_image_processor.preprocess.return_value = mock_preprocess_return
        self.processor._add_image(self.dummy_image, outputs, uuid=None)
        self.assertEqual(len(outputs["input_ids"]), num_tokens)
        self.assertEqual(outputs["num_input_image_tokens"], num_tokens)
        mock_hasher.assert_called_once_with(mock_preprocess_return["pixel_values"])
        self.assertEqual(outputs["mm_hashes"][0], "dummy_hash_123")
        self.assertEqual(outputs["cur_position"], 9)

    @patch(f"{MODULE_PATH}.MultimodalHasher.hash_features")
    @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0, "image": 1, "video": 2})
    def test_add_video_with_uuid(self, mock_hasher):
        """Test the _add_video helper (with an explicit uuid)."""
        outputs = self._get_init_outputs()
        outputs["cur_position"] = 10
        meta = {"fps": 30}
        num_patches_total = 256
        num_tokens = 64
        mock_preprocess_return = {
            "pixel_values": np.random.rand(num_patches_total, 3, 14, 14),
            "image_grid_thw": np.array([4, 8, 8]),
        }
        self.mock_image_processor.preprocess.return_value = mock_preprocess_return
        self.processor._add_video(self.dummy_video_frames, meta, outputs, uuid="custom_vid_uuid")
        self.assertEqual(len(outputs["input_ids"]), num_tokens)
        self.assertEqual(outputs["token_type_ids"], [2] * num_tokens)
        mock_hasher.assert_not_called()
        self.assertEqual(outputs["mm_hashes"][0], "custom_vid_uuid")
        self.assertEqual(outputs["image_type_ids"], [1, 1, 1, 1])

    @patch.object(DataProcessor, "_add_text", MagicMock())
    @patch.object(DataProcessor, "_add_image", MagicMock())
    @patch.object(DataProcessor, "_add_video", MagicMock())
    @patch.object(DataProcessor, "_load_and_process_video")
    def test_text2ids_parsing(self, mock_load_video):
        """Test the parsing and branching logic of text2ids."""
        mock_load_video.return_value = (self.dummy_video_frames, {"fps": 30})
        text = "Text1 <|IMAGE_PLACEHOLDER|> Text2 <|video_pad|> Text3"
        images = [self.dummy_image]
        videos = [self.dummy_video_data]
        image_uuid = ["img_uuid_1"]
        video_uuid = ["vid_uuid_1"]
        outputs = self.processor.text2ids(text, images, videos, image_uuid, video_uuid)
        self.processor._add_text.assert_any_call("Text1 ", outputs)
        self.processor._add_image.assert_called_once_with(self.dummy_image, outputs, "img_uuid_1")
        self.processor._add_video.assert_called_once_with(self.dummy_video_frames, {"fps": 30}, outputs, "vid_uuid_1")

    @patch(f"{MODULE_PATH}.parse_chat_messages")
    @patch.object(DataProcessor, "text2ids", return_value="final_output")
    def test_request2ids(self, mock_text2ids, mock_parse_chat):
        """Test the chat-template logic of request2ids."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Hello"},
                    {"type": "image", "data": self.dummy_image, "uuid": "img1"},
                ],
            }
        ]
        request = {"messages": messages, "add_generation_prompt": True}
        mock_parse_chat.return_value = messages
        parsed_prompt = "User: Hello <|IMAGE_PLACEHOLDER|> Assistant:"
        self.mock_tokenizer.apply_chat_template.return_value = parsed_prompt
        result = self.processor.request2ids(request)
        self.mock_tokenizer.apply_chat_template.assert_called_once()
        mock_text2ids.assert_called_once_with(parsed_prompt, [self.dummy_image], [], ["img1"], [])
        self.assertEqual(result, "final_output")

    @patch(f"{MODULE_PATH}.sample_frames")
    @patch(f"{MODULE_PATH}.read_video_decord")
    def test_load_and_process_video(self, mock_read_video, mock_sample_frames):
        """Test the frame-sampling logic of _load_and_process_video."""
        mock_reader = MagicMock()
        mock_reader.__getitem__.return_value.asnumpy.return_value = np.random.randint(
            0, 255, (100, 100, 3), dtype=np.uint8
        )
        mock_meta = {"num_of_frame": 100, "duration": 10.0, "fps": 10.0}
        mock_read_video.return_value = (mock_reader, mock_meta, None)
        mock_sample_frames.return_value = [0, 10, 20, 30, 40]
        self.processor.fps = 1
        frames, meta = self.processor._load_and_process_video("dummy_url", {"min_frames": 2, "max_frames": 10})
        mock_sample_frames.assert_called_once_with(
            frame_factor=ANY,
            min_frames=2,
            max_frames=10,
            metadata=mock_meta,
            fps=self.processor.fps,
            num_frames=self.processor.target_frames,
        )
        self.assertEqual(frames.shape, (5, 100, 100, 3))
        self.assertEqual(meta["fps"], 1)

    def test_init_with_external_tokenizer(self):
        """New: test initialization with an externally supplied tokenizer."""
        self.mock_auto_tokenizer_constructor.reset_mock()
        external_tokenizer = MagicMock()
        processor = DataProcessor(model_path="dummy", tokenizer=external_tokenizer)
        self.mock_auto_tokenizer_constructor.assert_not_called()
        self.assertIs(processor.tokenizer, external_tokenizer)

    def test_add_text_empty(self):
        """New: test _add_text with an empty string."""
        outputs = self._get_init_outputs()
        self.processor._add_text("", outputs)
        self.assertEqual(outputs["input_ids"], [])
        self.assertEqual(outputs["cur_position"], 0)

    @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0})
    def test_add_text_pre_tokenized(self):
        """New: test _add_text with pre-tokenized IDs."""
        outputs = self._get_init_outputs()
        token_ids = [10, 11, 12]
        self.processor._add_text(token_ids, outputs)
        self.mock_tokenizer.tokenize.assert_not_called()
        self.assertEqual(outputs["input_ids"], [10, 11, 12])
        self.assertEqual(outputs["token_type_ids"], [0, 0, 0])
        self.assertEqual(outputs["cur_position"], 3)

    @patch(f"{MODULE_PATH}.MultimodalHasher.hash_features", return_value="dummy_hash_456")
    @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0, "image": 1, "video": 2})
    def test_add_video_no_uuid(self, mock_hasher):
        """New: test that _add_video hashes automatically when uuid is None."""
        outputs = self._get_init_outputs()
        meta = {"fps": 30}
        mock_preprocess_return = {
            "pixel_values": np.random.rand(256, 3, 14, 14),
            "image_grid_thw": np.array([4, 8, 8]),
        }
        self.mock_image_processor.preprocess.return_value = mock_preprocess_return
        self.processor._add_video(self.dummy_video_frames, meta, outputs, uuid=None)
        mock_hasher.assert_called_once_with(mock_preprocess_return["pixel_values"])
        self.assertEqual(outputs["mm_hashes"][0], "dummy_hash_456")

    @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0, "image": 1, "video": 2})
    def test_add_processed_image(self):
        """New: test _add_processed_image with cached data."""
        outputs = self._get_init_outputs()
        outputs["cur_position"] = 3
        self.processor._add_processed_image(self.dummy_processed_image_cache, outputs, "cached_img_uuid")
        num_tokens = 16
        self.assertEqual(len(outputs["input_ids"]), num_tokens)
        self.assertEqual(outputs["input_ids"][0], self.processor.image_patch_id)
        np.testing.assert_array_equal(outputs["images"][0], self.dummy_processed_image_cache[0])
        self.assertEqual(outputs["mm_hashes"][0], "cached_img_uuid")
        self.assertEqual(outputs["cur_position"], 7)

    @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0, "image": 1, "video": 2})
    def test_add_processed_video(self):
        """New: test _add_processed_video with cached data."""
        outputs = self._get_init_outputs()
        outputs["cur_position"] = 5
        self.processor._add_processed_video(self.dummy_processed_video_cache, outputs, "cached_vid_uuid")
        num_tokens = 64
        t, h, w = self.dummy_processed_video_cache[1]["thw"]
        self.assertEqual(len(outputs["input_ids"]), num_tokens)
        self.assertEqual(outputs["token_type_ids"], [2] * num_tokens)
        np.testing.assert_array_equal(outputs["images"][0], self.dummy_processed_video_cache[0])
        self.assertEqual(outputs["mm_hashes"][0], "cached_vid_uuid")
        self.assertEqual(outputs["image_type_ids"], [1] * t)
        self.assertGreater(outputs["cur_position"], 5)

    def test_text2ids_with_processed_data(self):
        """New: test that text2ids calls _add_processed_image and _add_processed_video."""
        with (
            patch.object(self.processor, "_add_processed_image") as mock_add_proc_img,
            patch.object(self.processor, "_add_processed_video") as mock_add_proc_vid,
        ):
            text = "<|IMAGE_PLACEHOLDER|><|video_pad|>"
            images = [self.dummy_processed_image_cache]
            videos = [self.dummy_processed_video_cache]
            image_uuid = ["img1"]
            video_uuid = ["vid1"]
            self.processor.text2ids(text, images, videos, image_uuid, video_uuid)
            mock_add_proc_img.assert_called_once_with(self.dummy_processed_image_cache, ANY, "img1")
            mock_add_proc_vid.assert_called_once_with(self.dummy_processed_video_cache, ANY, "vid1")

    @patch(f"{MODULE_PATH}.sample_frames")
    @patch(f"{MODULE_PATH}.read_video_decord")
    def test_load_and_process_video_no_sampling(self, mock_read_video, mock_sample_frames):
        """New: test _load_and_process_video without sampling (fps=-1)."""
        mock_reader = MagicMock()
        mock_reader.__getitem__.return_value.asnumpy.return_value = np.random.randint(
            0, 255, (100, 100, 3), dtype=np.uint8
        )
        mock_meta = {"num_of_frame": 10, "duration": 1.0, "fps": 10.0}
        mock_read_video.return_value = (mock_reader, mock_meta, None)
        self.processor.fps = -1
        self.processor.target_frames = -1
        frames, meta = self.processor._load_and_process_video("dummy_url", {})
        mock_sample_frames.assert_not_called()
        self.assertEqual(frames.shape, (10, 100, 100, 3))
        self.assertEqual(meta["num_of_frame"], 10)

    def test_get_processor_cache(self):
        """New: test get_processor_cache (zmq)."""
        hashes = ["hash1", "hash2"]
        expected_items = ["item1", "item2"]
        mock_resp = pickle.dumps(expected_items)
        self.mock_zmq_socket.recv_multipart.return_value = (b"", mock_resp)
        items = self.processor.get_processor_cache(self.mock_zmq_socket, hashes)
        self.mock_zmq_socket.send_multipart.assert_called_once_with([b"", pickle.dumps(hashes)])
        self.assertEqual(items, expected_items)

    def test_update_processor_cache(self):
        """New: test update_processor_cache (zmq)."""
        hashes = ["hash1"]
        items = ["item1"]
        self.processor.update_processor_cache(self.mock_zmq_socket, hashes, items)
        expected_req = pickle.dumps((hashes, items))
        self.mock_zmq_socket.send_multipart.assert_called_once_with([b"", expected_req])

    def test_apply_chat_template(self):
        """New: test the core logic of apply_chat_template."""
        request = {"messages": ["msg1"], "add_generation_prompt": True, "request_id": "req123"}
        self.mock_tokenizer.apply_chat_template.return_value = "Prompt <|IMAGE_PLACEHOLDER|> text"
        self.mock_tokenizer.tokenize.return_value = ["Prompt", "text"]
        self.mock_tokenizer.convert_tokens_to_ids.side_effect = None
        self.mock_tokenizer.convert_tokens_to_ids.return_value = [10, 11]
        token_ids = self.processor.apply_chat_template(request)
        self.assertEqual(token_ids, [10, 11])
        self.assertEqual(request["text_after_process"], "Prompt <|IMAGE_PLACEHOLDER|> text")
        self.mock_tokenizer.tokenize.assert_called_with("Prompt text")

    def test_apply_chat_template_raises_error(self):
        """New: test that apply_chat_template raises ValueError when no chat template exists."""
        self.mock_tokenizer.chat_template = None
        with self.assertRaises(ValueError) as context:
            self.processor.apply_chat_template({"messages": []})
        self.assertIn("does not support chat_template", str(context.exception))

    @patch(f"{MODULE_PATH}.parse_chat_messages")
    def test_request2ids_cache_miss_raises_error(self, mock_parse_chat):
        """New: test that request2ids raises ValueError for missing data when the processor cache is disabled."""
        messages = [{"role": "user", "content": [{"type": "image", "uuid": "img1"}]}]
        request = {"messages": messages}
        mock_parse_chat.return_value = messages
        with self.assertRaises(ValueError) as context:
            self.processor.request2ids(request)
        self.assertIn("Missing items cannot be retrieved without processor cache.", str(context.exception))

    @patch(f"{MODULE_PATH}.DataProcessor.get_processor_cache")
    @patch(f"{MODULE_PATH}.DataProcessor.update_processor_cache")
    @patch(f"{MODULE_PATH}.DataProcessor.text2ids")
    @patch(f"{MODULE_PATH}.parse_chat_messages")
    def test_request2ids_cache_hit_and_update(self, mock_parse_chat, mock_text2ids, mock_update_cache, mock_get_cache):
        """New: test request2ids cache hit and cache update."""
        self.processor = DataProcessor(model_path="dummy_model_path", enable_processor_cache=True)
        self._configure_processor_ids()
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "uuid": "img_cache_hit"},
                    {"type": "image", "data": self.dummy_image, "uuid": "img_to_update"},
                ],
            }
        ]
        request = {"messages": messages}
        mock_parse_chat.return_value = messages
        mock_get_cache.return_value = [self.dummy_processed_image_cache]
        mock_text2ids_output = {
            "grid_thw": [(1, 8, 8), (1, 8, 8)],
            "fps": [0, 0],
            "mm_hashes": ["img_cache_hit", "img_to_update"],
            "images": [self.dummy_processed_image_cache[0], self.dummy_processed_image_cache[0]],
        }
        mock_text2ids.return_value = mock_text2ids_output
        self.mock_tokenizer.apply_chat_template.return_value = "<|IMAGE_PLACEHOLDER|><|IMAGE_PLACEHOLDER|>"
        self.processor.request2ids(request)
        self.mock_zmq_context.socket.assert_called_with(zmq.DEALER)
        mock_get_cache.assert_called_once_with(self.mock_zmq_socket, ["img_cache_hit"])
        parsed_images = mock_text2ids.call_args[0][1]
        self.assertIs(parsed_images[0], self.dummy_processed_image_cache)
        self.assertIs(parsed_images[1], self.dummy_image)
        expected_hash_to_cache = ["img_to_update"]
        expected_item_to_cache = (self.dummy_processed_image_cache[0], {"thw": (1, 8, 8), "fps": 0})
        mock_update_cache.assert_called_once()
        self.assertEqual(mock_update_cache.call_args[0][1], expected_hash_to_cache)
        self.assertEqual(mock_update_cache.call_args[0][2][0][1], expected_item_to_cache[1])
        np.testing.assert_array_equal(mock_update_cache.call_args[0][2][0][0], expected_item_to_cache[0])

    @patch(f"{MODULE_PATH}.DataProcessor.text2ids")
    @patch(f"{MODULE_PATH}.parse_chat_messages")
    def test_request2ids_unsupported_type(self, mock_parse_chat, mock_text2ids):
        """New: test that request2ids silently ignores unsupported content types."""
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": "Hello"}, {"type": "audio", "data": "...", "uuid": "audio1"}],
            }
        ]
        request = {"messages": messages}
        mock_parse_chat.return_value = messages
        self.mock_tokenizer.apply_chat_template.return_value = "User: Hello "
        self.processor.request2ids(request)
        mock_text2ids.assert_called_once()
        call_args = mock_text2ids.call_args[0]
        self.assertEqual(call_args[1], [])  # images
        self.assertEqual(call_args[2], [])  # videos
        self.assertEqual(call_args[3], [])  # image_uuid
        self.assertEqual(call_args[4], [])  # video_uuid


class TestPaddleOCR_VL_ImageProcessor(unittest.TestCase):
    def setUp(self):
        # Default initialization parameters
        self.default_params = {
            "do_resize": True,
            "resample": 3,
            "do_rescale": True,
            "rescale_factor": 1 / 255,
            "do_normalize": True,
            "image_mean": [0.48145466, 0.4578275, 0.40821073],
            "image_std": [0.26862954, 0.26130258, 0.27577711],
            "do_convert_rgb": True,
            "min_pixels": 28 * 28 * 130,
            "max_pixels": 28 * 28 * 1280,
            "patch_size": 14,
            "temporal_patch_size": 1,
            "merge_size": 2,
        }
        # Create a test image
        self.test_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))

    def test_initialization(self):
        """Test that the initialization parameters are set correctly."""
        processor = ImageProcessor(**self.default_params)
        for param, value in self.default_params.items():
            self.assertEqual(getattr(processor, param), value)

    def test_smart_resize(self):
        """Test the smart image-resizing function."""
        # Normal resize
        h, w = smart_resize(224, 224, factor=28)
        self.assertEqual(h % 28, 0)
        self.assertEqual(w % 28, 0)
        # Small input
        h, w = smart_resize(20, 20, factor=28)
        self.assertGreaterEqual(h, 28)
        self.assertGreaterEqual(w, 28)
        # Oversized input
        h, w = smart_resize(2000, 2000, factor=28)
        self.assertLess(h * w, 28 * 28 * 1280)

    def test_preprocess_single_image(self):
        """Test the single-image preprocessing pipeline."""
        processor = ImageProcessor(**self.default_params)
        # Normal preprocessing
        result = processor.preprocess(self.test_image)
        self.assertIn("pixel_values", result)
        self.assertIn("grid_thw", result)
        self.assertEqual(result["pixel_values"].ndim, 4)  # [N, C, H, W]
        # With some preprocessing steps disabled
        result = processor.preprocess(self.test_image, do_resize=False, do_normalize=False)
        self.assertIn("pixel_values", result)

    def test_preprocess_batch_images(self):
        """Test batch image preprocessing."""
        processor = ImageProcessor(**self.default_params)
        batch_images = [self.test_image, self.test_image]
        result = processor.preprocess(batch_images)
        expected_shape = 1152
        self.assertEqual(result["pixel_values"].shape[0], expected_shape)

    def test_invalid_input(self):
        """Test handling of invalid inputs."""
        processor = ImageProcessor(**self.default_params)
        # Invalid image
        with self.assertRaises(ValueError):
            processor.preprocess("invalid_image")
        # Video input (not supported yet)
        with self.assertRaises(NotImplementedError):
            processor.preprocess(self.test_image, videos=["video"])

    def test_from_pretrained(self):
        """Test loading the configuration from a pretrained path."""
        with patch("builtins.open", unittest.mock.mock_open(read_data='{"do_resize": false}')) as mock_file:
            processor = ImageProcessor.from_pretrained("dummy_path")
            self.assertFalse(processor.do_resize)
            mock_file.assert_called_once()


class TestPaddleOCRVLProcessor(unittest.TestCase):
    def setUp(self):
        # Create a PaddleOCRVLProcessor instance with __init__ bypassed
        with patch.object(PaddleOCRVLProcessor, "__init__", return_value=None):
            self.processor = PaddleOCRVLProcessor("model_path")
        # Set the required attributes
        self.processor.tokenizer = MagicMock()
        self.processor.tokenizer.eos_token_id = 1
        self.processor.processor = MagicMock()
        self.processor.limit_mm_per_prompt = {"image": 1, "video": 1, "audio": 1}
        self.processor.eos_token_ids = [1]

        # Mock _apply_default_parameters
        def mock_apply_default_parameters(request_or_dict):
            if isinstance(request_or_dict, dict):
                if "top_p" not in request_or_dict:
                    request_or_dict["top_p"] = 0.9
                return request_or_dict
            if not hasattr(request_or_dict, "top_p"):
                request_or_dict.top_p = 0.9
            return request_or_dict

        self.processor._apply_default_parameters = mock_apply_default_parameters

        # Mock pack_outputs
        def mock_pack_outputs(outputs):
            # Simplified position_ids handling
            position_ids_list = outputs["position_ids"]
            if not position_ids_list:
                position_ids = np.array([], dtype=np.int64)
            elif isinstance(position_ids_list[0], list):
                position_ids = np.array(position_ids_list, dtype=np.int64)
            else:
                position_ids = np.concatenate(position_ids_list, axis=1, dtype=np.int64)
            if position_ids.ndim == 1:
                position_ids = position_ids.reshape(1, -1)
            # The real pack_outputs transposes as well
            position_ids = position_ids.transpose(1, 0)
            return {
                "input_ids": np.array(outputs["input_ids"], dtype=np.int64),
                "token_type_ids": np.array(outputs["token_type_ids"], dtype=np.int64),
                "position_ids": position_ids,
                "images": np.vstack(outputs["images"]) if outputs.get("images") else None,
                "grid_thw": np.vstack(outputs["grid_thw"]) if outputs.get("grid_thw") else None,
                "image_type_ids": np.array(outputs["image_type_ids"]) if outputs.get("image_type_ids") else None,
            }

        self.processor.pack_outputs = mock_pack_outputs
        self.processor.np = np
        # Mock the _SAMPLING_EPS constant
        self.processor._SAMPLING_EPS = 1e-5
        # Mock processor return values (position_ids must be a list of 2D arrays)
        self.processor.processor.text2ids.return_value = {
            "input_ids": [1, 2, 3],
            "token_type_ids": [0, 0, 0],
            "position_ids": [np.array([[0, 1, 2]], dtype=np.int64)],  # corrected: 2D array
            "images": ["image_feature"],
            "grid_thw": ["grid_feature"],
            "image_type_ids": [0],
            "cur_position": 3,
        }
        self.processor.processor.request2ids.return_value = {
            "input_ids": [1, 2, 3],
            "token_type_ids": [0, 0, 0],
            "position_ids": [np.array([[0, 1, 2]], dtype=np.int64)],  # corrected: 2D array
            "images": ["image_feature"],
            "grid_thw": ["grid_feature"],
            "image_type_ids": [0],
            "cur_position": 3,
        }
        # Mock the _compute_text_positions method (returns a 2D array)
        self.processor.processor._compute_text_positions = lambda pos, num: np.array(
            [list(range(pos, pos + num))], dtype=np.int64
        )
        # Mock update_stop_seq
        self.processor.update_stop_seq = MagicMock(return_value=([[99, 98]], [2]))
        # Attributes needed by pack_outputs
        self.processor.processor.image_token_id = 100
        self.processor.processor.video_token_id = 101

    def test_process_request_dict_basic(self):
        """Test basic request processing."""
        request = {
            "prompt": "test prompt",
            "multimodal_data": {"image": ["image1"]},
            "metadata": {"generated_token_ids": []},
        }
        result = self.processor.process_request_dict(request, max_model_len=512)
        self.assertEqual(result["prompt_token_ids"], [1, 2, 3])
        self.assertEqual(result["prompt_token_ids_len"], 3)
        self.assertTrue("multimodal_inputs" in result)

    def test_process_request_dict_with_messages(self):
        """Test request processing with the messages format."""
        request = {
            "messages": [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": "Hello"}, {"type": "image_url", "url": "image1"}],
                }
            ],
            "metadata": {"generated_token_ids": []},
        }
        result = self.processor.process_request_dict(request, max_model_len=512)
        self.assertEqual(result["prompt_token_ids"], [1, 2, 3])
        self.assertTrue("multimodal_inputs" in result)

    def test_process_request_dict_with_max_len(self):
        """Test the maximum-length limit."""
        request = {
            "prompt": "test prompt",
            "multimodal_data": {"image": ["image1"]},
            "metadata": {"generated_token_ids": []},
        }
        # Mock a long sequence returned by the processor
        self.processor.processor.text2ids.return_value = {
            "input_ids": list(range(100)),
            "token_type_ids": [0] * 100,
            "position_ids": [np.array([list(range(100))], dtype=np.int64)],
            "images": ["image_feature"],
            "grid_thw": ["grid_feature"],
            "image_type_ids": [0],
            "cur_position": 100,
        }
        max_model_len = 50
        result = self.processor.process_request_dict(request, max_model_len)
        # The prompt should be truncated to max_model_len - 1
        self.assertEqual(len(result["prompt_token_ids"]), max_model_len - 1)
        self.assertEqual(result["prompt_token_ids"], list(range(49)))
        # The original input was indeed longer than the limit
        self.assertGreater(len(self.processor.processor.text2ids.return_value["input_ids"]), max_model_len)

    def test_parse_processor_kwargs(self):
        """Test processor kwargs parsing."""
        valid_kwargs = {"video_max_frames": 10, "video_min_frames": 1}
        result = self.processor._parse_processor_kwargs(valid_kwargs)
        self.assertEqual(result, valid_kwargs)
        # Invalid kwargs
        invalid_kwargs = {"video_max_frames": "invalid"}
        with patch(
            "fastdeploy.input.paddleocr_vl_processor.paddleocr_vl_processor.data_processor_logger"
        ) as mock_logger:
            result = self.processor._parse_processor_kwargs(invalid_kwargs)
            self.assertEqual(result, {})
            # A warning should have been logged
            mock_logger.warning.assert_called()

    def test_parse_limits(self):
        """Test input-limit parsing."""
        custom_limits = {"image": 2, "video": 3}
        result = self.processor._parse_limits(custom_limits)
        self.assertEqual(result["image"], 2)
        self.assertEqual(result["video"], 3)
        self.assertEqual(result["audio"], 1)  # default value

    def test_check_mm_limits(self):
        """Test the multimodal input limit check (dict path)."""
        # Within the limits
        item = {"image": ["image1"], "video": ["video1"]}
        self.processor._check_mm_limits(item)
        # Over the limits
        item_exceeded = {"image": ["image1", "image2"], "video": ["video1"]}
        with self.assertRaises(ValueError):
            self.processor._check_mm_limits(item_exceeded)

    def test_process_request_wrapper(self):
        """Test the process_request wrapper method."""
        # 1. Mock the incoming Request object
        request_obj = MagicMock()
        request_dict = {
            "prompt": "test prompt",
            "multimodal_data": {"image": ["image1"]},
            "metadata": {"generated_token_ids": []},
        }
        request_obj.to_dict.return_value = request_dict
        # 2. Patch 'Request'
        patch_target = "fastdeploy.input.paddleocr_vl_processor.paddleocr_vl_processor.Request"
        with patch(patch_target) as MockRequestCls:
            # 3. Make Request.from_dict return a mock object
            final_mock_request = MagicMock()
            MockRequestCls.from_dict.return_value = final_mock_request
            # 4. Call the function
            result_request = self.processor.process_request(request_obj, max_model_len=512)
            # 5. Inspect the dict *passed to* Request.from_dict
            self.assertTrue(MockRequestCls.from_dict.called)
            # First positional argument passed to from_dict
            processed_task_dict = MockRequestCls.from_dict.call_args[0][0]
            # This assertion should now pass
            self.assertEqual(processed_task_dict["prompt_token_ids"], [1, 2, 3])
            # 6. The returned value should be the final Request object
            self.assertIs(result_request, final_mock_request)

    def test_parse_processor_kwargs_invalid_type(self):
        """Test _parse_processor_kwargs with a non-dict input."""
        invalid_input = ["video_max_frames", 10]
        with patch(
            "fastdeploy.input.paddleocr_vl_processor.paddleocr_vl_processor.data_processor_logger"
        ) as mock_logger:
            result = self.processor._parse_processor_kwargs(invalid_input)
            self.assertEqual(result, {})  # the fallback branch is triggered
            mock_logger.warning.assert_called()

    def test_parse_limits_invalid_type(self):
        """Test _parse_limits with a non-dict input."""
        invalid_input = ["image", 2]
        with patch(
            "fastdeploy.input.paddleocr_vl_processor.paddleocr_vl_processor.data_processor_logger"
        ) as mock_logger:
            result = self.processor._parse_limits(invalid_input)
            # The defaults should be returned
            self.assertEqual(result, {"image": 1, "video": 1, "audio": 1})
            mock_logger.warning.assert_called()

    def test_check_mm_limits_messages_path(self):
        """Test _check_mm_limits (messages path)."""
        messages = [
            {"role": "user", "content": [{"type": "text", "text": "Hello"}, {"type": "image_url", "url": "image1"}]}
        ]
        self.processor._check_mm_limits(messages)  # should not raise

    def test_check_mm_limits_messages_exceeded(self):
        """Test _check_mm_limits (messages path) over the limit."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Hello"},
                    {"type": "image_url", "url": "image1"},
                    {"type": "image_url", "url": "image2"},  # exceeds the limit of 1
                ],
            }
        ]
        with self.assertRaises(ValueError):
            self.processor._check_mm_limits(messages)

    def test_process_request_dict_no_prompt_or_messages(self):
        """Test that a request with neither prompt nor messages raises an error."""
        request = {"metadata": {"generated_token_ids": []}}
        with self.assertRaises(ValueError):
            self.processor.process_request_dict(request, max_model_len=512)

    def test_process_request_dict_with_continuation(self):
        """Test the continuation logic (metadata contains generated_token_ids)."""
        request = {
            "prompt": "test prompt",
            "multimodal_data": {"image": ["image1"]},
            "metadata": {"generated_token_ids": [10, 11, 12]},  # already generated tokens
        }
        result = self.processor.process_request_dict(request, max_model_len=512)
        self.assertEqual(result["prompt_token_ids"], [1, 2, 3, 10, 11, 12])
        self.assertEqual(result["prompt_token_ids_len"], 6)

    def test_process_request_dict_with_stop_sequences(self):
        """Test stop_sequences handling."""
        request = {"prompt": "test prompt", "stop": ["stop1", "stop2"], "metadata": {"generated_token_ids": []}}
        result = self.processor.process_request_dict(request, max_model_len=512)
        # update_stop_seq should be called
        self.processor.update_stop_seq.assert_called_with(["stop1", "stop2"])
        # The results should be written back into the request
        self.assertEqual(result["stop_token_ids"], [[99, 98]])
        self.assertEqual(result["stop_seqs_len"], [2])

    def test_process_request_dict_default_max_tokens(self):
        """Test the default max_tokens computation."""
        request = {"prompt": "test prompt", "metadata": {"generated_token_ids": []}}  # prompt of length 3
        max_model_len = 10
        result = self.processor.process_request_dict(request, max_model_len)
        self.assertEqual(result["max_tokens"], 7)

    def test_process_request_dict_top_p_clamping(self):
        """Test that top_p values are clamped."""
        request = {
            "prompt": "test prompt",
            "top_p": 0.0,  # below _SAMPLING_EPS
            "metadata": {"generated_token_ids": []},
        }
        result = self.processor.process_request_dict(request, max_model_len=512)
        self.assertEqual(result["top_p"], self.processor._SAMPLING_EPS)

    def test_append_generated_tokens(self):
        """Directly test the append_generated_tokens helper."""
        # Note: position_ids must be a list of 2D arrays
        multimodal_inputs = {
            "input_ids": [1, 2, 3],
            "token_type_ids": [0, 0, 0],
            "position_ids": [np.array([[0, 1, 2]], dtype=np.int64)],
            "cur_position": 3,
        }
        generated_token_ids = [10, 11]
        # Call append_generated_tokens (a PaddleOCRVLProcessor method)
        PaddleOCRVLProcessor.append_generated_tokens(self.processor, multimodal_inputs, generated_token_ids)
        self.assertEqual(multimodal_inputs["input_ids"], [1, 2, 3, 10, 11])
        self.assertEqual(multimodal_inputs["token_type_ids"], [0, 0, 0, 0, 0])
        # Check that position_ids is now [np.array(...), np.array(...)]
        self.assertEqual(len(multimodal_inputs["position_ids"]), 2)
        self.assertTrue(np.array_equal(multimodal_inputs["position_ids"][0], np.array([[0, 1, 2]], dtype=np.int64)))
        self.assertTrue(np.array_equal(multimodal_inputs["position_ids"][1], np.array([[3, 4]], dtype=np.int64)))
        self.assertEqual(multimodal_inputs["cur_position"], 5)

    def test_pack_outputs_real_no_images(self):
        """Test the real pack_outputs method (no images)."""
        outputs = {
            "input_ids": [1, 2, 3],
            "token_type_ids": [0, 0, 0],
            # Note: position_ids must be a list of 2D arrays
            "position_ids": [np.array([[0, 1, 2]], dtype=np.int64)],
            "images": [],  # empty list
            "grid_thw": [],
            "image_type_ids": [],
            "cur_position": 3,
        }
        # Call the real class method rather than the instance mock installed in setUp
        result = PaddleOCRVLProcessor.pack_outputs(self.processor, outputs)
        self.assertIsNone(result["images"])
        self.assertIsNone(result["grid_thw"])
        self.assertIsNone(result["image_type_ids"])
        self.assertTrue(np.array_equal(result["input_ids"], np.array([1, 2, 3], dtype=np.int64)))
        # position_ids is concatenated and transposed:
        # input [array([[0, 1, 2]])] -> concat: array([[0, 1, 2]]) (shape 1, 3) -> transpose: array([[0], [1], [2]]) (shape 3, 1)
        self.assertTrue(np.array_equal(result["position_ids"], np.array([[0], [1], [2]], dtype=np.int64)))
        self.assertEqual(result["image_patch_id"], 100)
        self.assertEqual(result["video_patch_id"], 101)

    def test_pack_outputs_real_with_images(self):
        """Test the real pack_outputs method (with images)."""
        image_feature = np.array([[0.1, 0.2]])
        grid_feature = np.array([[1, 2, 3]])
        outputs = {
            "input_ids": [1, 2, 3],
            "token_type_ids": [0, 0, 0],
            # Note: position_ids must be a list of 2D arrays
            "position_ids": [np.array([[0, 1, 2]], dtype=np.int64)],
            "images": [image_feature],
            "grid_thw": [grid_feature],
            "image_type_ids": [0],
            "cur_position": 3,
        }
        result = PaddleOCRVLProcessor.pack_outputs(self.processor, outputs)
        self.assertTrue(np.array_equal(result["images"], image_feature))
        self.assertTrue(np.array_equal(result["grid_thw"], grid_feature))
        self.assertTrue(np.array_equal(result["image_type_ids"], np.array([0])))
        self.assertTrue(np.array_equal(result["position_ids"], np.array([[0], [1], [2]], dtype=np.int64)))


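# A typical way to run this module locally (assuming the FastDeploy test layout
# shown in the header; adjust the path to your checkout):
#   python -m pytest tests/input/test_paddleocr_vl_processor.py -v
# or use the standard-library runner via the __main__ guard below.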
if __name__ == "__main__":
    unittest.main()