From 2c281e617c5da1d99328bb4a8e9fd9723c237631 Mon Sep 17 00:00:00 2001
From: Haonan Luo <466107905@qq.com>
Date: Tue, 4 Nov 2025 22:40:15 +0800
Subject: [PATCH] Update Unit Test for PaddleOCR-VL (#4802)

* fix paddleocr prefix cache bug

* add test for paddleocr_vl

* disable prefix-caching in ocr

* add test for paddleocr_vl

* Fix top_p for rejection sampling

* add test for ocr processor; fix top_p for rejection sampling

* add test for ocr processor; fix top_p for rejection sampling

* add test for ocr processor; fix top_p for rejection sampling

* add test for ocr processor; fix top_p for rejection sampling

* add test for ocr processor; fix top_p for rejection sampling

---------

Co-authored-by: ming1753
Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com>
---
 .../ernie4_5_vl_processor.py               |    5 +
 .../paddleocr_vl_processor.py              |    6 +-
 tests/e2e/test_paddleocr_vl_serving.py     |   17 +-
 tests/input/test_paddleocr_vl_processor.py | 1146 +++++++++++++++++
 4 files changed, 1160 insertions(+), 14 deletions(-)
 create mode 100644 tests/input/test_paddleocr_vl_processor.py

diff --git a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
index 862709757..d86eb86c5 100644
--- a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
+++ b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
@@ -26,6 +26,8 @@ from fastdeploy.utils import data_processor_logger
 
 from .process import DataProcessor
 
+_SAMPLING_EPS = 1e-5
+
 
 class Ernie4_5_VLProcessor(Ernie4_5Processor):
     """The processor class for ERNIE MoE VL models."""
@@ -268,6 +270,9 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
             request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
         data_processor_logger.info(f"Processed request {request}")
 
+        if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
+            request["top_p"] = _SAMPLING_EPS
+
         return request
 
     def append_completion_tokens(self, multimodal_inputs, completion_token_ids):
diff --git a/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py b/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py
index 2e9e680c0..a5335fd0c 100644
--- a/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py
+++ b/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py
@@ -22,6 +22,8 @@ from fastdeploy.utils import data_processor_logger
 
 from .process import DataProcessor
 
+_SAMPLING_EPS = 1e-5
+
 
 class PaddleOCRVLProcessor(TextProcessor):
     """
@@ -61,7 +63,6 @@ class PaddleOCRVLProcessor(TextProcessor):
             tool_parser_obj: Tool parser instance
         """
         super().__init__(model_name_or_path, reasoning_parser_obj, tool_parser_obj)
-        data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
 
         processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs)
         self.processor = DataProcessor(
@@ -252,6 +253,9 @@ class PaddleOCRVLProcessor(TextProcessor):
         if request.get("max_tokens") is None:
             request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))  # Ensure at least 1 token
 
+        if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
+            request["top_p"] = _SAMPLING_EPS
+
         return request
 
     def append_generated_tokens(self, multimodal_inputs, generated_token_ids):
diff --git a/tests/e2e/test_paddleocr_vl_serving.py b/tests/e2e/test_paddleocr_vl_serving.py
index 997bbddf8..5d69ae68f 100644
--- a/tests/e2e/test_paddleocr_vl_serving.py
+++ b/tests/e2e/test_paddleocr_vl_serving.py
@@ -12,7 +12,6 @@
 # See the
License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import os
 import shutil
 import signal
@@ -99,7 +98,6 @@ def setup_and_run_server():
     model_path = "./PaddleOCR-VL-0.9B"
     log_path = "server.log"
 
-    limit_mm_str = json.dumps({"image": 100, "video": 100})
 
     cmd = [
         sys.executable,
@@ -109,8 +107,6 @@ def setup_and_run_server():
         model_path,
         "--port",
         str(FD_API_PORT),
-        "--tensor-parallel-size",
-        "2",
         "--engine-worker-queue-port",
         str(FD_ENGINE_QUEUE_PORT),
         "--metrics-port",
@@ -119,18 +115,13 @@ def setup_and_run_server():
         str(FD_CACHE_QUEUE_PORT),
         "--enable-mm",
         "--max-model-len",
-        "32768",
+        "16384",
         "--max-num-batched-tokens",
-        "384",
+        "16384",
         "--max-num-seqs",
         "128",
-        "--limit-mm-per-prompt",
-        limit_mm_str,
-        "--enable-chunked-prefill",
-        "--kv-cache-ratio",
-        "0.71",
-        "--quantization",
-        "wint4",
+        "--gpu-memory-utilization",
+        "0.9",
         "--graph-optimization-config",
         '{"graph_opt_level":0, "use_cudagraph":true}',
     ]
diff --git a/tests/input/test_paddleocr_vl_processor.py b/tests/input/test_paddleocr_vl_processor.py
new file mode 100644
index 000000000..62b58db26
--- /dev/null
+++ b/tests/input/test_paddleocr_vl_processor.py
@@ -0,0 +1,1146 @@
+import pickle
+import unittest
+from unittest.mock import ANY, MagicMock, patch
+
+import numpy as np
+import zmq
+from PIL import Image
+
+from fastdeploy.input.paddleocr_vl_processor.image_processor import (
+    ImageProcessor,
+    smart_resize,
+)
+from fastdeploy.input.paddleocr_vl_processor.paddleocr_vl_processor import (
+    PaddleOCRVLProcessor,
+)
+from fastdeploy.input.paddleocr_vl_processor.process import DataProcessor
+from fastdeploy.input.paddleocr_vl_processor.process_video import sample_frames
+
+MODULE_PATH = "fastdeploy.input.paddleocr_vl_processor.process"
+
+
+class TestProcessVideo(unittest.TestCase):
+    def setUp(self):
+        self.metadata = {"num_of_frame": 100, "fps": 25}
+        self.frame_factor = 4
+        self.min_frames = 8
+        self.max_frames = 32
+
+    def test_sample_with_num_frames(self):
+        """Test sampling with the num_frames argument (from the user's original test)"""
+        num_frames = 16
+        indices = sample_frames(
+            frame_factor=self.frame_factor,
+            min_frames=self.min_frames,
+            max_frames=self.max_frames,
+            num_frames=num_frames,
+            fps=0,  # make sure fps is not > 0
+            metadata=self.metadata,
+        )
+        self.assertEqual(len(indices), 16)
+        self.assertEqual(indices[0], 0)
+        self.assertEqual(indices[-1], 93)
+        np.testing.assert_array_equal(indices, np.arange(0, 100, 100 / 16).astype(np.int32))
+
+    def test_error_num_frames_exceeds_total(self):
+        """Test the error raised when num_frames exceeds the total frame count (from the user's original test)"""
+        with self.assertRaises(ValueError) as context:
+            sample_frames(
+                frame_factor=self.frame_factor,
+                min_frames=self.min_frames,
+                max_frames=self.max_frames,
+                num_frames=200,  # exceeds the total of 100 frames
+                fps=0,
+                metadata=self.metadata,
+            )
+        self.assertIn("exceeds", str(context.exception))
+
+    def test_error_mutual_exclusion(self):
+        """New: test that num_frames and fps are mutually exclusive"""
+        with self.assertRaises(ValueError) as context:
+            sample_frames(
+                frame_factor=self.frame_factor,
+                min_frames=self.min_frames,
+                max_frames=self.max_frames,
+                num_frames=16,  # > 0
+                fps=10,  # > 0
+                metadata=self.metadata,
+            )
+        self.assertIn("mutually exclusive", str(context.exception))
+
+    def test_error_fps_without_metadata(self):
+        """New: test fps > 0 but metadata is None"""
+        with self.assertRaises(TypeError) as context:
+            sample_frames(
+                frame_factor=self.frame_factor,
+                min_frames=self.min_frames,
+                max_frames=self.max_frames,
+                num_frames=0,
+                fps=10,
+                metadata=None,  # missing
+            )
+        # verify this is the expected TypeError
+
self.assertIn("'NoneType' object is not subscriptable", str(context.exception))
+
+    def test_num_frames_rounding(self):
+        """New: test rounding of num_frames to a multiple of frame_factor"""
+        num_frames = 17  # not a multiple of 4
+        # logic: round(17 / 4) * 4 = round(4.25) * 4 = 4 * 4 = 16
+        indices = sample_frames(
+            frame_factor=self.frame_factor,
+            min_frames=self.min_frames,
+            max_frames=self.max_frames,
+            num_frames=num_frames,
+            fps=0,
+            metadata=self.metadata,
+        )
+        # should be rounded to 16
+        self.assertEqual(len(indices), 16)
+
+    def test_sample_with_fps_basic(self):
+        """New: test sampling with fps (basic path, capped by max_frames)"""
+        # logic: num_frames_calc = 100 / 25 * 10 = 40
+        # num_frames_clamped = min(max(40, 8), 32) = 32
+        # num_frames_factored = floor(32 / 4) * 4 = 32
+        indices = sample_frames(
+            frame_factor=self.frame_factor,
+            min_frames=self.min_frames,
+            max_frames=self.max_frames,
+            num_frames=0,
+            fps=10,
+            metadata=self.metadata,
+        )
+        # should be capped by max_frames=32
+        self.assertEqual(len(indices), 32)
+        self.assertEqual(indices[-1], 96)
+
+    def test_sample_with_fps_hits_min_frames(self):
+        """New: test sampling with fps (clamped by min_frames)"""
+        # logic: num_frames_calc = 100 / 25 * 1 = 4
+        # num_frames_clamped = min(max(4, 8), 32) = 8
+        # num_frames_factored = floor(8 / 4) * 4 = 8
+        indices = sample_frames(
+            frame_factor=self.frame_factor,
+            min_frames=self.min_frames,
+            max_frames=self.max_frames,
+            num_frames=0,
+            fps=1,
+            metadata=self.metadata,
+        )
+        # should be clamped by min_frames=8
+        self.assertEqual(len(indices), 8)
+        self.assertEqual(indices[-1], 87)
+
+    def test_sample_with_fps_hits_total_frames(self):
+        """New: test sampling with fps (limited by total_num_frames)"""
+        local_max_frames = 200
+
+        # logic: num_frames_calc = 100 / 25 * 50 = 200
+        # num_frames_clamped = min(min(max(200, 8), 200), 100) = 100
+        # num_frames_factored = floor(100 / 4) * 4 = 100
+        indices = sample_frames(
+            frame_factor=self.frame_factor,
+            min_frames=self.min_frames,
+            max_frames=local_max_frames,
+            num_frames=0,
+            fps=50,
+            metadata=self.metadata,
+        )
+        # should be limited by total_num_frames=100
+        self.assertEqual(len(indices), 100)
+        self.assertEqual(indices[-1], 99)  # all frames are sampled
+
+    def test_no_sampling(self):
+        """New: test no sampling (fps=0, num_frames=0)"""
+        indices = sample_frames(
+            frame_factor=self.frame_factor,
+            min_frames=self.min_frames,
+            max_frames=self.max_frames,
+            num_frames=0,
+            fps=0,
+            metadata=self.metadata,
+        )
+        # should return all frames
+        self.assertEqual(len(indices), self.metadata["num_of_frame"])
+        self.assertEqual(len(indices), 100)
+        self.assertEqual(indices[-1], 99)
+        np.testing.assert_array_equal(indices, np.arange(0, 100).astype(np.int32))
+
+
+class Test_DataProcessor(unittest.TestCase):
+    """
+    Unit tests for the DataProcessor class in process.py.
+    """
+
+    def setUp(self):
+
+        # 1. Manually start the patchers
+        patcher1 = patch(f"{MODULE_PATH}.AutoTokenizer.from_pretrained")
+        patcher2 = patch(f"{MODULE_PATH}.ImageProcessor.from_pretrained")
+        patcher_zmq_context = patch(f"{MODULE_PATH}.zmq.Context")
+
+        self.mock_auto_tokenizer_constructor = patcher1.start()
+        self.mock_image_processor_constructor = patcher2.start()
+        self.mock_zmq_context_constructor = patcher_zmq_context.start()
+
+        self.addCleanup(patcher1.stop)
+        self.addCleanup(patcher2.stop)
+        self.addCleanup(patcher_zmq_context.stop)
+
+        # 2. Create mock objects
+        self.mock_tokenizer = MagicMock()
+        self.mock_image_processor = MagicMock()
+        self.mock_zmq_context = MagicMock()
+        self.mock_zmq_socket = MagicMock()
+
+        # 3.
配置 from_pretrained 和 zmq + self.mock_auto_tokenizer_constructor.return_value = self.mock_tokenizer + self.mock_image_processor_constructor.return_value = self.mock_image_processor + self.mock_zmq_context_constructor.return_value = self.mock_zmq_context + self.mock_zmq_context.socket.return_value = self.mock_zmq_socket + + # 4. 配置模拟对象的属性和方法 + self._configure_mocks() + + # 5. 实例化 DataProcessor (默认不启用 cache) + self.processor = DataProcessor(model_path="dummy_model_path") + self._configure_processor_ids() + + # 6. 准备测试用的虚拟数据 + self.dummy_image = Image.fromarray(np.uint8(np.random.rand(224, 224, 3) * 255)) + self.dummy_video_frames = np.uint8(np.random.rand(16, 224, 224, 3) * 255) + self.dummy_video_data = "path/to/dummy_video.mp4" + self.dummy_processed_image_cache = ( + np.random.rand(64, 3, 14, 14).astype(np.float32), + {"thw": (1, 8, 8), "fps": 0}, + ) + self.dummy_processed_video_cache = ( + np.random.rand(256, 3, 14, 14).astype(np.float32), + {"thw": (4, 8, 8), "fps": 30}, + ) + + def _configure_mocks(self): + def mock_convert_tokens_to_ids(tokens): + if tokens == "<|IMAGE_PLACEHOLDER|>": + return 100 + if tokens == "<|video_pad|>": + return 101 + if tokens == "<|IMAGE_START|>": + return 102 + if isinstance(tokens, list): + if tokens == ["Hello", "world"]: + return [983, 984] + if tokens == ["Prompt", "text"]: + return [606, 511] + if tokens == ["Prompt", "", "text"]: + return [606, 511] # 模拟 "Prompt text".split() + return [hash(t) % 1000 for t in tokens] + return hash(tokens) % 1000 + + self.mock_tokenizer.convert_tokens_to_ids.side_effect = mock_convert_tokens_to_ids + self.mock_tokenizer.tokenize.side_effect = lambda s: s.split() + self.mock_tokenizer.ignored_index = -100 + self.mock_tokenizer.chat_template = "dummy_template_string" + + self.mock_image_processor.merge_size = 2 + self.mock_image_processor.temporal_patch_size = 1 + + def _configure_processor_ids(self): + self.processor.image_token_id = 100 + self.processor.video_token_id = 101 + self.processor.image_patch_id = 100 + self.processor.vision_start_id = 102 + + def _get_init_outputs(self): + return { + "input_ids": [], + "token_type_ids": [], + "position_ids": [], + "images": [], + "grid_thw": [], + "image_type_ids": [], + "labels": [], + "cur_position": 0, + "video_cnt": 0, + "num_input_image_tokens": 0, + "num_input_video_tokens": 0, + "fps": [], + "mm_positions": [], + "mm_hashes": [], + "vit_seqlen": [], + "vit_position_ids": [], + } + + def test_init(self): + """测试 DataProcessor 的初始化""" + self.mock_auto_tokenizer_constructor.assert_called_with("dummy_model_path", padding_side="left", use_fast=True) + self.mock_image_processor_constructor.assert_called_with("dummy_model_path") + self.assertEqual(self.processor.image_token, "<|IMAGE_PLACEHOLDER|>") + self.assertEqual(self.processor.video_token_id, 101) + + def test_compute_text_positions(self): + """测试 _compute_text_positions 纯函数""" + pos_ids = self.processor._compute_text_positions(start_pos=5, num_tokens=3) + expected = np.array([[5, 6, 7], [5, 6, 7], [5, 6, 7]]) + np.testing.assert_array_equal(pos_ids, expected) + + def test_compute_vision_positions(self): + """测试 _compute_vision_positions 纯函数""" + pos_ids = self.processor._compute_vision_positions(start_pos=10, t=2, h=4, w=4, second_per_grid_t=1.0) + self.assertEqual(pos_ids.shape, (3, 8)) + expected_t = np.array([0, 0, 0, 0, 2, 2, 2, 2]) + expected_h = np.array([0, 0, 1, 1, 0, 0, 1, 1]) + expected_w = np.array([0, 1, 0, 1, 0, 1, 0, 1]) + expected = np.stack([expected_t, expected_h, expected_w]) + 10 + 
np.testing.assert_array_equal(pos_ids, expected) + + @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0, "image": 1, "video": 2}) + def test_add_text(self): + """测试 _add_text 辅助函数""" + outputs = self._get_init_outputs() + self.mock_tokenizer.tokenize.return_value = ["Hello", "world"] + self.mock_tokenizer.convert_tokens_to_ids.side_effect = None + self.mock_tokenizer.convert_tokens_to_ids.return_value = [10, 11] + + self.processor._add_text("Hello world", outputs) + + self.assertEqual(outputs["input_ids"], [10, 11]) + self.assertEqual(outputs["token_type_ids"], [0, 0]) + self.assertEqual(outputs["cur_position"], 2) + + @patch(f"{MODULE_PATH}.MultimodalHasher.hash_features", return_value="dummy_hash_123") + @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0, "image": 1, "video": 2}) + def test_add_image_autohash(self, mock_hasher): + """测试 _add_image 辅助函数 (自动哈希)""" + outputs = self._get_init_outputs() + outputs["cur_position"] = 5 + + num_patches_hw = 8 * 8 + num_tokens = 16 + mock_preprocess_return = { + "pixel_values": np.random.rand(num_patches_hw, 3, 14, 14), + "grid_thw": np.array([1, 8, 8]), + } + self.mock_image_processor.preprocess.return_value = mock_preprocess_return + + self.processor._add_image(self.dummy_image, outputs, uuid=None) + + self.assertEqual(len(outputs["input_ids"]), num_tokens) + self.assertEqual(outputs["num_input_image_tokens"], num_tokens) + mock_hasher.assert_called_once_with(mock_preprocess_return["pixel_values"]) + self.assertEqual(outputs["mm_hashes"][0], "dummy_hash_123") + self.assertEqual(outputs["cur_position"], 9) + + @patch(f"{MODULE_PATH}.MultimodalHasher.hash_features") + @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0, "image": 1, "video": 2}) + def test_add_video_with_uuid(self, mock_hasher): + """测试 _add_video 辅助函数 (使用 uuid)""" + outputs = self._get_init_outputs() + outputs["cur_position"] = 10 + meta = {"fps": 30} + + num_patches_total = 256 + num_tokens = 64 + + mock_preprocess_return = { + "pixel_values": np.random.rand(num_patches_total, 3, 14, 14), + "image_grid_thw": np.array([4, 8, 8]), + } + self.mock_image_processor.preprocess.return_value = mock_preprocess_return + + self.processor._add_video(self.dummy_video_frames, meta, outputs, uuid="custom_vid_uuid") + + self.assertEqual(len(outputs["input_ids"]), num_tokens) + self.assertEqual(outputs["token_type_ids"], [2] * num_tokens) + mock_hasher.assert_not_called() + self.assertEqual(outputs["mm_hashes"][0], "custom_vid_uuid") + self.assertEqual(outputs["image_type_ids"], [1, 1, 1, 1]) + + @patch.object(DataProcessor, "_add_text", MagicMock()) + @patch.object(DataProcessor, "_add_image", MagicMock()) + @patch.object(DataProcessor, "_add_video", MagicMock()) + @patch.object(DataProcessor, "_load_and_process_video") + def test_text2ids_parsing(self, mock_load_video): + """测试 text2ids 的解析和分支逻辑""" + mock_load_video.return_value = (self.dummy_video_frames, {"fps": 30}) + text = "Text1 <|IMAGE_PLACEHOLDER|> Text2 <|video_pad|> Text3" + images = [self.dummy_image] + videos = [self.dummy_video_data] + image_uuid = ["img_uuid_1"] + video_uuid = ["vid_uuid_1"] + + outputs = self.processor.text2ids(text, images, videos, image_uuid, video_uuid) + + self.processor._add_text.assert_any_call("Text1 ", outputs) + self.processor._add_image.assert_called_once_with(self.dummy_image, outputs, "img_uuid_1") + self.processor._add_video.assert_called_once_with(self.dummy_video_frames, {"fps": 30}, outputs, "vid_uuid_1") + + @patch(f"{MODULE_PATH}.parse_chat_messages") + @patch.object(DataProcessor, "text2ids", 
return_value="final_output") + def test_request2ids(self, mock_text2ids, mock_parse_chat): + """测试 request2ids 的 chat 模板逻辑""" + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + {"type": "image", "data": self.dummy_image, "uuid": "img1"}, + ], + } + ] + request = {"messages": messages, "add_generation_prompt": True} + mock_parse_chat.return_value = messages + parsed_prompt = "User: Hello <|IMAGE_PLACEHOLDER|> Assistant:" + self.mock_tokenizer.apply_chat_template.return_value = parsed_prompt + + result = self.processor.request2ids(request) + + self.mock_tokenizer.apply_chat_template.assert_called_once() + mock_text2ids.assert_called_once_with(parsed_prompt, [self.dummy_image], [], ["img1"], []) + self.assertEqual(result, "final_output") + + @patch(f"{MODULE_PATH}.sample_frames") + @patch(f"{MODULE_PATH}.read_video_decord") + def test_load_and_process_video(self, mock_read_video, mock_sample_frames): + """测试 _load_and_process_video 的帧采样逻辑""" + mock_reader = MagicMock() + mock_reader.__getitem__.return_value.asnumpy.return_value = np.random.randint( + 0, 255, (100, 100, 3), dtype=np.uint8 + ) + mock_meta = {"num_of_frame": 100, "duration": 10.0, "fps": 10.0} + mock_read_video.return_value = (mock_reader, mock_meta, None) + mock_sample_frames.return_value = [0, 10, 20, 30, 40] + self.processor.fps = 1 + + frames, meta = self.processor._load_and_process_video("dummy_url", {"min_frames": 2, "max_frames": 10}) + + mock_sample_frames.assert_called_once_with( + frame_factor=ANY, + min_frames=2, + max_frames=10, + metadata=mock_meta, + fps=self.processor.fps, + num_frames=self.processor.target_frames, + ) + self.assertEqual(frames.shape, (5, 100, 100, 3)) + self.assertEqual(meta["fps"], 1) + + def test_init_with_external_tokenizer(self): + """新增:测试使用外部传入的 tokenizer 初始化""" + self.mock_auto_tokenizer_constructor.reset_mock() + + external_tokenizer = MagicMock() + processor = DataProcessor(model_path="dummy", tokenizer=external_tokenizer) + + self.mock_auto_tokenizer_constructor.assert_not_called() + self.assertIs(processor.tokenizer, external_tokenizer) + + def test_add_text_empty(self): + """新增:测试 _add_text 传入空字符串""" + outputs = self._get_init_outputs() + self.processor._add_text("", outputs) + self.assertEqual(outputs["input_ids"], []) + self.assertEqual(outputs["cur_position"], 0) + + @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0}) + def test_add_text_pre_tokenized(self): + """新增:测试 _add_text 传入已 tokenized 的 IDs""" + outputs = self._get_init_outputs() + token_ids = [10, 11, 12] + self.processor._add_text(token_ids, outputs) + + self.mock_tokenizer.tokenize.assert_not_called() + self.assertEqual(outputs["input_ids"], [10, 11, 12]) + self.assertEqual(outputs["token_type_ids"], [0, 0, 0]) + self.assertEqual(outputs["cur_position"], 3) + + @patch(f"{MODULE_PATH}.MultimodalHasher.hash_features", return_value="dummy_hash_456") + @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0, "image": 1, "video": 2}) + def test_add_video_no_uuid(self, mock_hasher): + """新增:测试 _add_video 在 uuid 为 None 时自动哈希""" + outputs = self._get_init_outputs() + meta = {"fps": 30} + mock_preprocess_return = { + "pixel_values": np.random.rand(256, 3, 14, 14), + "image_grid_thw": np.array([4, 8, 8]), + } + self.mock_image_processor.preprocess.return_value = mock_preprocess_return + + self.processor._add_video(self.dummy_video_frames, meta, outputs, uuid=None) + + mock_hasher.assert_called_once_with(mock_preprocess_return["pixel_values"]) + self.assertEqual(outputs["mm_hashes"][0], 
"dummy_hash_456") + + @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0, "image": 1, "video": 2}) + def test_add_processed_image(self): + """新增:测试 _add_processed_image 处理缓存数据""" + outputs = self._get_init_outputs() + outputs["cur_position"] = 3 + + self.processor._add_processed_image(self.dummy_processed_image_cache, outputs, "cached_img_uuid") + + num_tokens = 16 + self.assertEqual(len(outputs["input_ids"]), num_tokens) + self.assertEqual(outputs["input_ids"][0], self.processor.image_patch_id) + + np.testing.assert_array_equal(outputs["images"][0], self.dummy_processed_image_cache[0]) + + self.assertEqual(outputs["mm_hashes"][0], "cached_img_uuid") + self.assertEqual(outputs["cur_position"], 7) + + @patch(f"{MODULE_PATH}.IDS_TYPE_FLAG", {"text": 0, "image": 1, "video": 2}) + def test_add_processed_video(self): + """新增:测试 _add_processed_video 处理缓存数据""" + outputs = self._get_init_outputs() + outputs["cur_position"] = 5 + + self.processor._add_processed_video(self.dummy_processed_video_cache, outputs, "cached_vid_uuid") + + num_tokens = 64 + t, h, w = self.dummy_processed_video_cache[1]["thw"] + + self.assertEqual(len(outputs["input_ids"]), num_tokens) + self.assertEqual(outputs["token_type_ids"], [2] * num_tokens) + + np.testing.assert_array_equal(outputs["images"][0], self.dummy_processed_video_cache[0]) + + self.assertEqual(outputs["mm_hashes"][0], "cached_vid_uuid") + self.assertEqual(outputs["image_type_ids"], [1] * t) + self.assertGreater(outputs["cur_position"], 5) + + def test_text2ids_with_processed_data(self): + """新增:测试 text2ids 调用 _add_processed_image 和 _add_processed_video""" + with ( + patch.object(self.processor, "_add_processed_image") as mock_add_proc_img, + patch.object(self.processor, "_add_processed_video") as mock_add_proc_vid, + ): + + text = "<|IMAGE_PLACEHOLDER|><|video_pad|>" + images = [self.dummy_processed_image_cache] + videos = [self.dummy_processed_video_cache] + image_uuid = ["img1"] + video_uuid = ["vid1"] + + self.processor.text2ids(text, images, videos, image_uuid, video_uuid) + + mock_add_proc_img.assert_called_once_with(self.dummy_processed_image_cache, ANY, "img1") + mock_add_proc_vid.assert_called_once_with(self.dummy_processed_video_cache, ANY, "vid1") + + @patch(f"{MODULE_PATH}.sample_frames") + @patch(f"{MODULE_PATH}.read_video_decord") + def test_load_and_process_video_no_sampling(self, mock_read_video, mock_sample_frames): + """新增:测试 _load_and_process_video 不采样(fps=-1)""" + mock_reader = MagicMock() + mock_reader.__getitem__.return_value.asnumpy.return_value = np.random.randint( + 0, 255, (100, 100, 3), dtype=np.uint8 + ) + mock_meta = {"num_of_frame": 10, "duration": 1.0, "fps": 10.0} + mock_read_video.return_value = (mock_reader, mock_meta, None) + + self.processor.fps = -1 + self.processor.target_frames = -1 + + frames, meta = self.processor._load_and_process_video("dummy_url", {}) + + mock_sample_frames.assert_not_called() + self.assertEqual(frames.shape, (10, 100, 100, 3)) + self.assertEqual(meta["num_of_frame"], 10) + + def test_get_processor_cache(self): + """新增:测试 get_processor_cache (zmq)""" + hashes = ["hash1", "hash2"] + expected_items = ["item1", "item2"] + mock_resp = pickle.dumps(expected_items) + self.mock_zmq_socket.recv_multipart.return_value = (b"", mock_resp) + + items = self.processor.get_processor_cache(self.mock_zmq_socket, hashes) + + self.mock_zmq_socket.send_multipart.assert_called_once_with([b"", pickle.dumps(hashes)]) + self.assertEqual(items, expected_items) + + def test_update_processor_cache(self): + """新增:测试 
update_processor_cache (zmq)""" + hashes = ["hash1"] + items = ["item1"] + + self.processor.update_processor_cache(self.mock_zmq_socket, hashes, items) + + expected_req = pickle.dumps((hashes, items)) + self.mock_zmq_socket.send_multipart.assert_called_once_with([b"", expected_req]) + + def test_apply_chat_template(self): + """新增:测试 apply_chat_template 核心逻辑""" + request = {"messages": ["msg1"], "add_generation_prompt": True, "request_id": "req123"} + self.mock_tokenizer.apply_chat_template.return_value = "Prompt <|IMAGE_PLACEHOLDER|> text" + self.mock_tokenizer.tokenize.return_value = ["Prompt", "text"] + + self.mock_tokenizer.convert_tokens_to_ids.side_effect = None + self.mock_tokenizer.convert_tokens_to_ids.return_value = [10, 11] + + token_ids = self.processor.apply_chat_template(request) + + self.assertEqual(token_ids, [10, 11]) + self.assertEqual(request["text_after_process"], "Prompt <|IMAGE_PLACEHOLDER|> text") + + self.mock_tokenizer.tokenize.assert_called_with("Prompt text") + + def test_apply_chat_template_raises_error(self): + """新增:测试 apply_chat_template 在模板不存在时引发 ValueError""" + self.mock_tokenizer.chat_template = None + with self.assertRaises(ValueError) as context: + self.processor.apply_chat_template({"messages": []}) + self.assertIn("does not support chat_template", str(context.exception)) + + @patch(f"{MODULE_PATH}.parse_chat_messages") + def test_request2ids_cache_miss_raises_error(self, mock_parse_chat): + """新增:测试 request2ids 在缓存关闭时缺少数据引发 ValueError""" + messages = [{"role": "user", "content": [{"type": "image", "uuid": "img1"}]}] + request = {"messages": messages} + + mock_parse_chat.return_value = messages + + with self.assertRaises(ValueError) as context: + self.processor.request2ids(request) + + self.assertIn("Missing items cannot be retrieved without processor cache.", str(context.exception)) + + @patch(f"{MODULE_PATH}.DataProcessor.get_processor_cache") + @patch(f"{MODULE_PATH}.DataProcessor.update_processor_cache") + @patch(f"{MODULE_PATH}.DataProcessor.text2ids") + @patch(f"{MODULE_PATH}.parse_chat_messages") + def test_request2ids_cache_hit_and_update(self, mock_parse_chat, mock_text2ids, mock_update_cache, mock_get_cache): + """新增:测试 request2ids 缓存命中和缓存更新""" + self.processor = DataProcessor(model_path="dummy_model_path", enable_processor_cache=True) + self._configure_processor_ids() + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "uuid": "img_cache_hit"}, + {"type": "image", "data": self.dummy_image, "uuid": "img_to_update"}, + ], + } + ] + request = {"messages": messages} + + mock_parse_chat.return_value = messages + mock_get_cache.return_value = [self.dummy_processed_image_cache] + + mock_text2ids_output = { + "grid_thw": [(1, 8, 8), (1, 8, 8)], + "fps": [0, 0], + "mm_hashes": ["img_cache_hit", "img_to_update"], + "images": [self.dummy_processed_image_cache[0], self.dummy_processed_image_cache[0]], + } + mock_text2ids.return_value = mock_text2ids_output + self.mock_tokenizer.apply_chat_template.return_value = "<|IMAGE_PLACEHOLDER|><|IMAGE_PLACEHOLDER|>" + + self.processor.request2ids(request) + + self.mock_zmq_context.socket.assert_called_with(zmq.DEALER) + mock_get_cache.assert_called_once_with(self.mock_zmq_socket, ["img_cache_hit"]) + + parsed_images = mock_text2ids.call_args[0][1] + self.assertIs(parsed_images[0], self.dummy_processed_image_cache) + self.assertIs(parsed_images[1], self.dummy_image) + + expected_hash_to_cache = ["img_to_update"] + expected_item_to_cache = (self.dummy_processed_image_cache[0], {"thw": (1, 8, 8), 
"fps": 0}) + mock_update_cache.assert_called_once() + self.assertEqual(mock_update_cache.call_args[0][1], expected_hash_to_cache) + self.assertEqual(mock_update_cache.call_args[0][2][0][1], expected_item_to_cache[1]) + np.testing.assert_array_equal(mock_update_cache.call_args[0][2][0][0], expected_item_to_cache[0]) + + @patch(f"{MODULE_PATH}.DataProcessor.text2ids") + @patch(f"{MODULE_PATH}.parse_chat_messages") + def test_request2ids_unsupported_type(self, mock_parse_chat, mock_text2ids): + """新增:测试 request2ids 静默忽略不支持的类型""" + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello"}, {"type": "audio", "data": "...", "uuid": "audio1"}], + } + ] + request = {"messages": messages} + + mock_parse_chat.return_value = messages + self.mock_tokenizer.apply_chat_template.return_value = "User: Hello " + + self.processor.request2ids(request) + + mock_text2ids.assert_called_once() + call_args = mock_text2ids.call_args[0] + self.assertEqual(call_args[1], []) # images + self.assertEqual(call_args[2], []) # videos + self.assertEqual(call_args[3], []) # image_uuid + self.assertEqual(call_args[4], []) # video_uuid + + +class TestPaddleOCR_VL_ImageProcessor(unittest.TestCase): + def setUp(self): + # 初始化默认参数 + self.default_params = { + "do_resize": True, + "resample": 3, + "do_rescale": True, + "rescale_factor": 1 / 255, + "do_normalize": True, + "image_mean": [0.48145466, 0.4578275, 0.40821073], + "image_std": [0.26862954, 0.26130258, 0.27577711], + "do_convert_rgb": True, + "min_pixels": 28 * 28 * 130, + "max_pixels": 28 * 28 * 1280, + "patch_size": 14, + "temporal_patch_size": 1, + "merge_size": 2, + } + + # 创建测试图像 + self.test_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)) + + def test_initialization(self): + """测试初始化参数是否正确设置""" + processor = ImageProcessor(**self.default_params) + + for param, value in self.default_params.items(): + self.assertEqual(getattr(processor, param), value) + + def test_smart_resize(self): + """测试智能调整图像大小功能""" + # 测试正常尺寸调整 + h, w = smart_resize(224, 224, factor=28) + self.assertEqual(h % 28, 0) + self.assertEqual(w % 28, 0) + + # 测试小尺寸调整 + h, w = smart_resize(20, 20, factor=28) + self.assertGreaterEqual(h, 28) + self.assertGreaterEqual(w, 28) + + # 测试超大尺寸调整 + h, w = smart_resize(2000, 2000, factor=28) + self.assertLess(h * w, 28 * 28 * 1280) + + def test_preprocess_single_image(self): + """测试单张图像预处理流程""" + processor = ImageProcessor(**self.default_params) + + # 测试正常预处理 + result = processor.preprocess(self.test_image) + self.assertIn("pixel_values", result) + self.assertIn("grid_thw", result) + self.assertEqual(result["pixel_values"].ndim, 4) # [N, C, H, W] + + # 测试关闭某些预处理步骤 + result = processor.preprocess(self.test_image, do_resize=False, do_normalize=False) + self.assertIn("pixel_values", result) + + def test_preprocess_batch_images(self): + """测试批量图像预处理""" + processor = ImageProcessor(**self.default_params) + batch_images = [self.test_image, self.test_image] + + result = processor.preprocess(batch_images) + expected_shape = 1152 + self.assertEqual(result["pixel_values"].shape[0], expected_shape) + + def test_invalid_input(self): + """测试无效输入处理""" + processor = ImageProcessor(**self.default_params) + + # 测试无效图像 + with self.assertRaises(ValueError): + processor.preprocess("invalid_image") + + # 测试视频输入(暂不支持) + with self.assertRaises(NotImplementedError): + processor.preprocess(self.test_image, videos=["video"]) + + def test_from_pretrained(self): + """测试从预训练模型加载配置""" + with patch("builtins.open", 
unittest.mock.mock_open(read_data='{"do_resize": false}')) as mock_file: + processor = ImageProcessor.from_pretrained("dummy_path") + self.assertFalse(processor.do_resize) + mock_file.assert_called_once() + + +class TestPaddleOCRVLProcessor(unittest.TestCase): + def setUp(self): + # 创建 PaddleOCRVLProcessor 实例的模拟对象 + with patch.object(PaddleOCRVLProcessor, "__init__", return_value=None): + self.processor = PaddleOCRVLProcessor("model_path") + + # 设置必要的属性 + self.processor.tokenizer = MagicMock() + self.processor.tokenizer.eos_token_id = 1 + self.processor.processor = MagicMock() + self.processor.limit_mm_per_prompt = {"image": 1, "video": 1, "audio": 1} + self.processor.eos_token_ids = [1] + + # 模拟 _apply_default_parameters + def mock_apply_default_parameters(request_or_dict): + if isinstance(request_or_dict, dict): + if "top_p" not in request_or_dict: + request_or_dict["top_p"] = 0.9 + return request_or_dict + + if not hasattr(request_or_dict, "top_p"): + request_or_dict.top_p = 0.9 + return request_or_dict + + self.processor._apply_default_parameters = mock_apply_default_parameters + + # 模拟 pack_outputs + def mock_pack_outputs(outputs): + # 简化 position_ids 的处理 + position_ids_list = outputs["position_ids"] + if not position_ids_list: + position_ids = np.array([], dtype=np.int64) + elif isinstance(position_ids_list[0], list): + position_ids = np.array(position_ids_list, dtype=np.int64) + else: + position_ids = np.concatenate(position_ids_list, axis=1, dtype=np.int64) + + if position_ids.ndim == 1: + position_ids = position_ids.reshape(1, -1) + + # 源码的 pack_outputs 会 transpose + position_ids = position_ids.transpose(1, 0) + + return { + "input_ids": np.array(outputs["input_ids"], dtype=np.int64), + "token_type_ids": np.array(outputs["token_type_ids"], dtype=np.int64), + "position_ids": position_ids, + "images": np.vstack(outputs["images"]) if outputs.get("images") else None, + "grid_thw": np.vstack(outputs["grid_thw"]) if outputs.get("grid_thw") else None, + "image_type_ids": np.array(outputs["image_type_ids"]) if outputs.get("image_type_ids") else None, + } + + self.processor.pack_outputs = mock_pack_outputs + self.processor.np = np + + # 模拟 _SAMPLING_EPS 常量 + self.processor._SAMPLING_EPS = 1e-5 + + # 模拟 processor 返回 (position_ids 必须是 2D array 的 list) + self.processor.processor.text2ids.return_value = { + "input_ids": [1, 2, 3], + "token_type_ids": [0, 0, 0], + "position_ids": [np.array([[0, 1, 2]], dtype=np.int64)], # 修正 + "images": ["image_feature"], + "grid_thw": ["grid_feature"], + "image_type_ids": [0], + "cur_position": 3, + } + + self.processor.processor.request2ids.return_value = { + "input_ids": [1, 2, 3], + "token_type_ids": [0, 0, 0], + "position_ids": [np.array([[0, 1, 2]], dtype=np.int64)], # 修正 + "images": ["image_feature"], + "grid_thw": ["grid_feature"], + "image_type_ids": [0], + "cur_position": 3, + } + + # 模拟 _compute_text_positions 方法 (返回 2D array) + self.processor.processor._compute_text_positions = lambda pos, num: np.array( + [list(range(pos, pos + num))], dtype=np.int64 + ) + + # 模拟 update_stop_seq + self.processor.update_stop_seq = MagicMock(return_value=([[99, 98]], [2])) + + # 模拟 pack_outputs 需要的属性 + self.processor.processor.image_token_id = 100 + self.processor.processor.video_token_id = 101 + + def test_process_request_dict_basic(self): + """测试基本请求处理功能""" + request = { + "prompt": "test prompt", + "multimodal_data": {"image": ["image1"]}, + "metadata": {"generated_token_ids": []}, + } + + result = self.processor.process_request_dict(request, max_model_len=512) + 
self.assertEqual(result["prompt_token_ids"], [1, 2, 3]) + self.assertEqual(result["prompt_token_ids_len"], 3) + self.assertTrue("multimodal_inputs" in result) + + def test_process_request_dict_with_messages(self): + """测试 messages 格式的请求处理""" + request = { + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello"}, {"type": "image_url", "url": "image1"}], + } + ], + "metadata": {"generated_token_ids": []}, + } + + result = self.processor.process_request_dict(request, max_model_len=512) + self.assertEqual(result["prompt_token_ids"], [1, 2, 3]) + self.assertTrue("multimodal_inputs" in result) + + def test_process_request_dict_with_max_len(self): + """测试最大长度限制功能""" + request = { + "prompt": "test prompt", + "multimodal_data": {"image": ["image1"]}, + "metadata": {"generated_token_ids": []}, + } + + # 模拟 processor 返回长序列 + self.processor.processor.text2ids.return_value = { + "input_ids": list(range(100)), + "token_type_ids": [0] * 100, + "position_ids": [np.array([list(range(100))], dtype=np.int64)], + "images": ["image_feature"], + "grid_thw": ["grid_feature"], + "image_type_ids": [0], + "cur_position": 100, + } + + max_model_len = 50 + result = self.processor.process_request_dict(request, max_model_len) + # 验证是否截断到 max_model_len - 1 + self.assertEqual(len(result["prompt_token_ids"]), max_model_len - 1) + self.assertEqual(result["prompt_token_ids"], list(range(49))) + # 验证原始输入长度确实超过了限制 + self.assertGreater(len(self.processor.processor.text2ids.return_value["input_ids"]), max_model_len) + + def test_parse_processor_kwargs(self): + """测试处理器参数解析""" + valid_kwargs = {"video_max_frames": 10, "video_min_frames": 1} + result = self.processor._parse_processor_kwargs(valid_kwargs) + self.assertEqual(result, valid_kwargs) + + # 测试无效参数 + invalid_kwargs = {"video_max_frames": "invalid"} + with patch( + "fastdeploy.input.paddleocr_vl_processor.paddleocr_vl_processor.data_processor_logger" + ) as mock_logger: + result = self.processor._parse_processor_kwargs(invalid_kwargs) + self.assertEqual(result, {}) + # 确认警告已被记录 + mock_logger.warning.assert_called() + + def test_parse_limits(self): + """测试输入限制解析""" + custom_limits = {"image": 2, "video": 3} + result = self.processor._parse_limits(custom_limits) + self.assertEqual(result["image"], 2) + self.assertEqual(result["video"], 3) + self.assertEqual(result["audio"], 1) # 默认值 + + def test_check_mm_limits(self): + """测试多模态输入限制检查 (dict path)""" + # 测试不超限 + item = {"image": ["image1"], "video": ["video1"]} + self.processor._check_mm_limits(item) + + # 测试超限 + item_exceeded = {"image": ["image1", "image2"], "video": ["video1"]} + with self.assertRaises(ValueError): + self.processor._check_mm_limits(item_exceeded) + + def test_process_request_wrapper(self): + """测试 process_request 封装方法""" + # 1. 模拟输入 Request 对象 + request_obj = MagicMock() + request_dict = { + "prompt": "test prompt", + "multimodal_data": {"image": ["image1"]}, + "metadata": {"generated_token_ids": []}, + } + request_obj.to_dict.return_value = request_dict + + # 2. patch 'Request' + patch_target = "fastdeploy.input.paddleocr_vl_processor.paddleocr_vl_processor.Request" + with patch(patch_target) as MockRequestCls: + + # 3. 模拟 Request.from_dict 返回一个 mock 对象 + final_mock_request = MagicMock() + MockRequestCls.from_dict.return_value = final_mock_request + + # 4. Call function + result_request = self.processor.process_request(request_obj, max_model_len=512) + + # 5. 
检查 *传递给* Request.from_dict 的字典 + self.assertTrue(MockRequestCls.from_dict.called) + # 获取传递给 from_dict 的第一个位置参数 + processed_task_dict = MockRequestCls.from_dict.call_args[0][0] + + # 这个断言现在应该能通过了 + self.assertEqual(processed_task_dict["prompt_token_ids"], [1, 2, 3]) + + # 6. 检查返回的是否是最终的 Request 对象 + self.assertIs(result_request, final_mock_request) + + def test_parse_processor_kwargs_invalid_type(self): + """测试 _parse_processor_kwargs 传入非字典类型""" + invalid_input = ["video_max_frames", 10] + with patch( + "fastdeploy.input.paddleocr_vl_processor.paddleocr_vl_processor.data_processor_logger" + ) as mock_logger: + result = self.processor._parse_processor_kwargs(invalid_input) + self.assertEqual(result, {}) # 触发 + mock_logger.warning.assert_called() + + def test_parse_limits_invalid_type(self): + """测试 _parse_limits 传入非字典类型""" + invalid_input = ["image", 2] + with patch( + "fastdeploy.input.paddleocr_vl_processor.paddleocr_vl_processor.data_processor_logger" + ) as mock_logger: + result = self.processor._parse_limits(invalid_input) + # 应返回默认值 + self.assertEqual(result, {"image": 1, "video": 1, "audio": 1}) + mock_logger.warning.assert_called() + + def test_check_mm_limits_messages_path(self): + """测试 _check_mm_limits (messages path)""" + messages = [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}, {"type": "image_url", "url": "image1"}]} + ] + self.processor._check_mm_limits(messages) # 不应抛出异常 + + def test_check_mm_limits_messages_exceeded(self): + """测试 _check_mm_limits (messages path) 超限""" + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + {"type": "image_url", "url": "image1"}, + {"type": "image_url", "url": "image2"}, # 超过限制 1 + ], + } + ] + with self.assertRaises(ValueError): + self.processor._check_mm_limits(messages) + + def test_process_request_dict_no_prompt_or_messages(self): + """测试当请求既没有 prompt 也没有 messages 时抛出异常""" + request = {"metadata": {"generated_token_ids": []}} + with self.assertRaises(ValueError): + self.processor.process_request_dict(request, max_model_len=512) + + def test_process_request_dict_with_continuation(self): + """测试续写逻辑 (metadata 包含 generated_token_ids)""" + request = { + "prompt": "test prompt", + "multimodal_data": {"image": ["image1"]}, + "metadata": {"generated_token_ids": [10, 11, 12]}, # 已生成的 token + } + + result = self.processor.process_request_dict(request, max_model_len=512) + self.assertEqual(result["prompt_token_ids"], [1, 2, 3, 10, 11, 12]) + self.assertEqual(result["prompt_token_ids_len"], 6) + + def test_process_request_dict_with_stop_sequences(self): + """测试 stop_sequences 处理""" + request = {"prompt": "test prompt", "stop": ["stop1", "stop2"], "metadata": {"generated_token_ids": []}} + result = self.processor.process_request_dict(request, max_model_len=512) + + # 验证 update_stop_seq 被调用 + self.processor.update_stop_seq.assert_called_with(["stop1", "stop2"]) + # 验证结果被设置到 request 中 + self.assertEqual(result["stop_token_ids"], [[99, 98]]) + self.assertEqual(result["stop_seqs_len"], [2]) + + def test_process_request_dict_default_max_tokens(self): + """测试默认 max_tokens 计算""" + request = {"prompt": "test prompt", "metadata": {"generated_token_ids": []}} # 长度为 3 + max_model_len = 10 + result = self.processor.process_request_dict(request, max_model_len) + + self.assertEqual(result["max_tokens"], 7) + + def test_process_request_dict_top_p_clamping(self): + """测试 top_p 值被修正 (clamping)""" + request = { + "prompt": "test prompt", + "top_p": 0.0, # 低于 _SAMPLING_EPS + "metadata": {"generated_token_ids": 
[]}, + } + result = self.processor.process_request_dict(request, max_model_len=512) + self.assertEqual(result["top_p"], self.processor._SAMPLING_EPS) + + def test_append_generated_tokens(self): + """直接测试 append_generated_tokens 辅助函数""" + # : position_ids 必须是 [2D array] + multimodal_inputs = { + "input_ids": [1, 2, 3], + "token_type_ids": [0, 0, 0], + "position_ids": [np.array([[0, 1, 2]], dtype=np.int64)], + "cur_position": 3, + } + generated_token_ids = [10, 11] + + # 调用 append_generated_tokens (它是 PaddleOCRVLProcessor 的方法) + PaddleOCRVLProcessor.append_generated_tokens(self.processor, multimodal_inputs, generated_token_ids) + + self.assertEqual(multimodal_inputs["input_ids"], [1, 2, 3, 10, 11]) + self.assertEqual(multimodal_inputs["token_type_ids"], [0, 0, 0, 0, 0]) + # : 检查 position_ids 是否为 [np.array(...), np.array(...)] + self.assertEqual(len(multimodal_inputs["position_ids"]), 2) + self.assertTrue(np.array_equal(multimodal_inputs["position_ids"][0], np.array([[0, 1, 2]], dtype=np.int64))) + self.assertTrue(np.array_equal(multimodal_inputs["position_ids"][1], np.array([[3, 4]], dtype=np.int64))) + self.assertEqual(multimodal_inputs["cur_position"], 5) + + def test_pack_outputs_real_no_images(self): + """测试真实的 pack_outputs 方法 (无图像)""" + outputs = { + "input_ids": [1, 2, 3], + "token_type_ids": [0, 0, 0], + # : position_ids 必须是 [2D array] + "position_ids": [np.array([[0, 1, 2]], dtype=np.int64)], + "images": [], # 空列表 + "grid_thw": [], + "image_type_ids": [], + "cur_position": 3, + } + + # 调用真实的类方法,而不是 setUp 中 mock 的实例方法 + result = PaddleOCRVLProcessor.pack_outputs(self.processor, outputs) + + self.assertIsNone(result["images"]) + self.assertIsNone(result["grid_thw"]) + self.assertIsNone(result["image_type_ids"]) + self.assertTrue(np.array_equal(result["input_ids"], np.array([1, 2, 3], dtype=np.int64))) + # 验证 position_ids 被 concatenate 和 transpose + # input: [array([[0, 1, 2]])] -> concat: array([[0, 1, 2]]) (shape 1,3) -> transpose: array([[0], [1], [2]]) (shape 3,1) + self.assertTrue(np.array_equal(result["position_ids"], np.array([[0], [1], [2]], dtype=np.int64))) + self.assertEqual(result["image_patch_id"], 100) + self.assertEqual(result["video_patch_id"], 101) + + def test_pack_outputs_real_with_images(self): + """测试真实的 pack_outputs 方法 (有图像)""" + image_feature = np.array([[0.1, 0.2]]) + grid_feature = np.array([[1, 2, 3]]) + + outputs = { + "input_ids": [1, 2, 3], + "token_type_ids": [0, 0, 0], + # : position_ids 必须是 [2D array] + "position_ids": [np.array([[0, 1, 2]], dtype=np.int64)], + "images": [image_feature], + "grid_thw": [grid_feature], + "image_type_ids": [0], + "cur_position": 3, + } + + result = PaddleOCRVLProcessor.pack_outputs(self.processor, outputs) + + self.assertTrue(np.array_equal(result["images"], image_feature)) + self.assertTrue(np.array_equal(result["grid_thw"], grid_feature)) + self.assertTrue(np.array_equal(result["image_type_ids"], np.array([0]))) + self.assertTrue(np.array_equal(result["position_ids"], np.array([[0], [1], [2]], dtype=np.int64))) + + +if __name__ == "__main__": + unittest.main()