"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from unittest.mock import MagicMock, patch
import numpy as np
from PIL import Image
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor
from fastdeploy.input.ernie4_5_vl_processor.image_preprocessor.image_preprocessor_adaptive import (
AdaptiveImageProcessor,
)
from fastdeploy.input.ernie4_5_vl_processor.process import DataProcessor
from fastdeploy.input.utils import IDS_TYPE_FLAG
class MockReasoningParser:
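"""Reasoning parser stub whose model status is always "think_start"."""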
def get_model_status(self, prompt_token_ids):
return "think_start"
class TestErnie4_5VLProcessorProcessResponseDictStreaming(unittest.TestCase):
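"""Tests for Ernie4_5_VLProcessor initialization, request pre-processing, and response post-processing with mocked dependencies."""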
def setUp(self):
# Create an Ernie4_5_VLProcessor instance with __init__ patched out
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None) as mock_init:
self.processor = Ernie4_5_VLProcessor("model_path")
mock_init.side_effect = lambda *args, **kwargs: print(f"__init__ called with {args}, {kwargs}")
# Set necessary attributes
self.processor.tokenizer = MagicMock()
self.processor.tokenizer.eos_token_id = 1
self.processor.decode_status = {"test": []}
self.processor.reasoning_end_dict = {}
self.processor.tool_parser_dict = {}
self.processor.generation_config = MagicMock()
self.processor.eos_token_ids = [1]
self.processor.reasoning_parser = MockReasoningParser()
self.processor.model_status_dict = {"test": "think_start"}
self.processor.ernie4_5_processor = MagicMock()
# Mock ids2tokens method
def mock_ids2tokens(token_ids, task_id):
return "delta_text", [2, 3], "previous_texts"
self.processor.ids2tokens = mock_ids2tokens
def mock_request2ids(request, **kwargs):
return {"input_ids": np.array([1, 2, 3]), "prompt_token_ids": [0]}
def mock_check_mm_limits(item):
pass
def mock_apply_default_parameters(request):
return request
def mock_pack_outputs(outputs):
# Convert list-valued fields to numpy arrays if present
result = outputs.copy() if isinstance(outputs, dict) else outputs
if isinstance(result, dict):
if "input_ids" in result and isinstance(result["input_ids"], list):
result["input_ids"] = np.array(result["input_ids"])
if "token_type_ids" in result and isinstance(result["token_type_ids"], list):
result["token_type_ids"] = np.array(result["token_type_ids"])
if "position_ids" in result and isinstance(result["position_ids"], list):
result["position_ids"] = np.array(result["position_ids"])
return result
def mock_prompt_token_ids2outputs(request):
return {
"input_ids": np.array([1, 1, 1]),
"token_type_ids": np.array([0, 0, 0]),
"position_ids": np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]]),
"images": [],
"grid_thw": [],
"image_type_ids": [],
"cur_position": 3,
}
self.processor._apply_default_parameters = mock_apply_default_parameters
self.processor._check_mm_limits = mock_check_mm_limits
self.processor.ernie4_5_processor.request2ids = mock_request2ids
self.processor.ernie4_5_processor.prompt_token_ids2outputs = mock_prompt_token_ids2outputs
self.processor.pack_outputs = mock_pack_outputs
# Mock reasoning parser
self.mock_reasoning_parser = MagicMock()
self.mock_reasoning_parser.extract_reasoning_content_streaming.return_value = None
self.processor.reasoning_parser = self.mock_reasoning_parser
# Mock tool parser
self.mock_tool_parser = MagicMock()
self.mock_tool_parser.extract_tool_calls_streaming.return_value = None
self.mock_tool_parser_obj = MagicMock()
self.mock_tool_parser_obj.return_value = self.mock_tool_parser
self.processor.tool_parser_obj = self.mock_tool_parser_obj
def test_think_status(self):
"""测试 思考机制"""
request = {
"prompt": "hello",
"request_id": "test_1",
"prompt_token_ids": [1, 2, 3],
}
self.processor.reasoning_parser = MagicMock()
self.processor.reasoning_parser.get_model_status.return_value = "think_start"
self.processor.model_status_dict = {}
self.processor.process_request_dict(request, max_model_len=512)
self.assertEqual(request["enable_thinking"], True)
request = {
"prompt": "hello",
"request_id": "test",
"prompt_token_ids": [1, 2, 3],
}
self.processor.process_request_dict(request, max_model_len=512)
self.assertEqual(request["enable_thinking"], True)
def test_init(self):
"""Test __init__ method"""
with patch("fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor.data_processor_logger"):
mock_dp = MagicMock()
mock_dp.image_patch_id = 1001
mock_dp.spatial_conv_size = 14
mock_dp.tokenizer = MagicMock()
mock_dp.tokenizer.pad_token_id = 0
mock_dp.eval = MagicMock()
with patch("fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor.DataProcessor") as mock_dp_class:
mock_dp_class.return_value = mock_dp
with patch(
"fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor.GenerationConfig"
) as mock_gen_config:
mock_gen_config.from_pretrained.return_value = MagicMock()
with patch("paddleformers.trl.llm_utils.get_eos_token_id") as mock_get_eos:
mock_get_eos.return_value = [1, 2]
# Test normal initialization
mock_reasoning_parser_class = MagicMock()
processor = Ernie4_5_VLProcessor(
"model_path",
limit_mm_per_prompt={"image": 2, "video": 1},
mm_processor_kwargs={"spatial_conv_size": 14},
reasoning_parser_obj=lambda tokenizer: mock_reasoning_parser_class,
tool_parser_obj=MagicMock(),
enable_processor_cache=True,
)
self.assertEqual(processor.image_patch_id, 1001)
self.assertEqual(processor.spatial_conv_size, 14)
self.assertIsNotNone(processor.tokenizer)
self.assertIsNotNone(processor.generation_config)
self.assertEqual(processor.eos_token_ids, [1, 2])
self.assertEqual(processor.limit_mm_per_prompt["image"], 2)
self.assertEqual(processor.limit_mm_per_prompt["video"], 1)
mock_dp.eval.assert_called_once()
# Test with generation config exception
mock_gen_config.from_pretrained.side_effect = Exception("Config not found")
processor2 = Ernie4_5_VLProcessor("model_path")
self.assertIsNone(processor2.generation_config)
# Test with reasoning_parser_obj
mock_reasoning_parser = MagicMock()
processor3 = Ernie4_5_VLProcessor(
"model_path", reasoning_parser_obj=lambda tokenizer: mock_reasoning_parser
)
self.assertIsNotNone(processor3.reasoning_parser)
def test_parse_processor_kwargs(self):
"""Test _parse_processor_kwargs with various inputs"""
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
processor = Ernie4_5_VLProcessor("model_path")
processor._parse_processor_kwargs = Ernie4_5_VLProcessor._parse_processor_kwargs.__get__(
processor, Ernie4_5_VLProcessor
)
# Test with valid kwargs
valid_kwargs = {
"spatial_conv_size": 14,
"temporal_conv_size": 2,
"image_min_pixels": 1000,
"image_max_pixels": 10000,
}
result = processor._parse_processor_kwargs(valid_kwargs)
self.assertEqual(result, valid_kwargs)
# Test with invalid type (implementation catches exception and returns empty dict)
invalid_kwargs = {"spatial_conv_size": "invalid"} # Should be int
result = Ernie4_5_VLProcessor._parse_processor_kwargs(processor, invalid_kwargs)
self.assertEqual(result, {})
# Test with non-dict input (implementation catches exception and returns empty dict)
result = Ernie4_5_VLProcessor._parse_processor_kwargs(processor, "not a dict")
self.assertEqual(result, {})
# Test exception handling with None
with patch("fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor.data_processor_logger"):
result = processor._parse_processor_kwargs(None)
self.assertEqual(result, {})
def test_parse_limits(self):
"""Test _parse_limits with various inputs"""
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
processor = Ernie4_5_VLProcessor("model_path")
processor._parse_limits = Ernie4_5_VLProcessor._parse_limits.__get__(processor, Ernie4_5_VLProcessor)
# Test with valid limits
valid_limits = {"image": 5, "video": 3}
result = processor._parse_limits(valid_limits)
self.assertEqual(result["image"], 5)
self.assertEqual(result["video"], 3)
self.assertEqual(result["audio"], 1) # Default value
# Test with empty input (None)
result = processor._parse_limits(None)
self.assertEqual(result["image"], 1)
self.assertEqual(result["video"], 1)
self.assertEqual(result["audio"], 1)
# Test with invalid type (implementation catches exception and returns default limits)
result = Ernie4_5_VLProcessor._parse_limits(processor, "not a dict")
self.assertEqual(result["image"], 1)
self.assertEqual(result["video"], 1)
self.assertEqual(result["audio"], 1)
def test_check_mm_limits(self):
"""Test _check_mm_limits with various inputs"""
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
processor = Ernie4_5_VLProcessor("model_path")
processor._check_mm_limits = Ernie4_5_VLProcessor._check_mm_limits.__get__(processor, Ernie4_5_VLProcessor)
# Test with dict input (should not raise)
processor.limit_mm_per_prompt = {"image": 2, "video": 1}
mm_data = {"image": [1, 2], "video": [1]}
processor._check_mm_limits(mm_data)
# Test with messages input (should not raise)
messages = [
{"role": "user", "content": [{"type": "image", "data": "img1"}]},
{"role": "user", "content": [{"type": "video", "data": "vid1"}]},
]
processor._check_mm_limits(messages)
# Test when limit is exceeded (should raise ValueError)
processor.limit_mm_per_prompt = {"image": 1, "video": 1}
mm_data = {"image": [1, 2, 3], "video": []} # 3 images, limit is 1
with self.assertRaises(ValueError) as context:
processor._check_mm_limits(mm_data)
self.assertIn("Too many image items", str(context.exception))
def test_process_request(self):
"""Test process_request method"""
from fastdeploy.engine.request import Request
# Mock the process_request_dict method
self.processor.process_request_dict = MagicMock()
# Create a mock Request object
mock_request = MagicMock(spec=Request)
mock_request.to_dict.return_value = {"messages": [{"role": "user", "content": "Hello"}]}
# Mock Request.from_dict to return a mock request
with patch.object(Request, "from_dict") as mock_from_dict:
mock_result_request = MagicMock(spec=Request)
mock_from_dict.return_value = mock_result_request
self.processor.process_request(mock_request, max_model_len=100, chat_template_kwargs={"key": "value"})
# Verify to_dict was called
mock_request.to_dict.assert_called_once()
# Verify process_request_dict was called
self.processor.process_request_dict.assert_called_once()
# Verify from_dict was called
mock_from_dict.assert_called_once()
def test_get_pad_id(self):
"""Test get_pad_id method"""
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
processor = Ernie4_5_VLProcessor("model_path")
processor.tokenizer = MagicMock()
processor.tokenizer.pad_token_id = 100
processor.get_pad_id = Ernie4_5_VLProcessor.get_pad_id.__get__(processor, Ernie4_5_VLProcessor)
result = processor.get_pad_id()
self.assertEqual(result, 100)
def test_load_tokenizer(self):
"""Test _load_tokenizer method"""
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
processor = Ernie4_5_VLProcessor("model_path")
mock_tokenizer = MagicMock()
processor.ernie4_5_processor = MagicMock()
processor.ernie4_5_processor.tokenizer = mock_tokenizer
processor._load_tokenizer = Ernie4_5_VLProcessor._load_tokenizer.__get__(processor, Ernie4_5_VLProcessor)
processor._load_tokenizer()
self.assertEqual(processor.tokenizer, mock_tokenizer)
def test_append_completion_tokens(self):
"""Test append_completion_tokens method"""
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
processor = Ernie4_5_VLProcessor("model_path")
processor.append_completion_tokens = Ernie4_5_VLProcessor.append_completion_tokens.__get__(
processor, Ernie4_5_VLProcessor
)
multimodal_inputs = {
"input_ids": [1, 2, 3],
"token_type_ids": [0, 0, 0],
"position_ids": [[0, 0, 0], [1, 1, 1], [2, 2, 2]],
"cur_position": 3,
}
completion_token_ids = [10, 11, 12]
processor.append_completion_tokens(multimodal_inputs, completion_token_ids)
self.assertEqual(multimodal_inputs["input_ids"], [1, 2, 3, 10, 11, 12])
self.assertEqual(multimodal_inputs["token_type_ids"], [0, 0, 0, 0, 0, 0])
self.assertEqual(len(multimodal_inputs["position_ids"]), 6)
self.assertEqual(multimodal_inputs["cur_position"], 6)
def test_pack_outputs(self):
"""Test pack_outputs with and without images"""
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
processor = Ernie4_5_VLProcessor("model_path")
processor.image_patch_id = 1001
processor.pack_outputs = Ernie4_5_VLProcessor.pack_outputs.__get__(processor, Ernie4_5_VLProcessor)
# Test with images
outs_with_images = {
"input_ids": [1, 2, 3],
"token_type_ids": [0, 0, 0],
"position_ids": [[0, 0, 0], [1, 1, 1], [2, 2, 2]],
"images": [np.array([[1, 2], [3, 4]])],
"grid_thw": [np.array([[1, 2, 2]])],
"image_type_ids": [0],
}
result = processor.pack_outputs(outs_with_images)
self.assertIsNotNone(result["images"])
self.assertIsNotNone(result["grid_thw"])
self.assertIsNotNone(result["image_type_ids"])
self.assertEqual(result["image_patch_id"], 1001)
self.assertIsInstance(result["input_ids"], np.ndarray)
self.assertIsInstance(result["token_type_ids"], np.ndarray)
self.assertIsInstance(result["position_ids"], np.ndarray)
# Test without images
outs_without_images = {
"input_ids": [1, 2, 3],
"token_type_ids": [0, 0, 0],
"position_ids": [[0, 0, 0], [1, 1, 1], [2, 2, 2]],
"images": [],
"grid_thw": [],
"image_type_ids": [],
}
result = processor.pack_outputs(outs_without_images)
self.assertIsNone(result["images"])
self.assertIsNone(result["grid_thw"])
self.assertIsNone(result["image_type_ids"])
def test_process_response_dict(self):
"""Test process_response_dict with different parameters"""
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
processor = Ernie4_5_VLProcessor("model_path")
processor.process_response_dict = Ernie4_5_VLProcessor.process_response_dict.__get__(
processor, Ernie4_5_VLProcessor
)
# Test with stream=True
processor.process_response_dict_streaming = MagicMock(return_value={"text": "response"})
response_dict = {"ids": [1, 2, 3]}
result = processor.process_response_dict(response_dict, stream=True)
processor.process_response_dict_streaming.assert_called_once()
self.assertEqual(result, {"text": "response"})
# Test with stream=False
processor.process_response_dict_normal = MagicMock(return_value={"text": "response"})
response_dict = {"ids": [1, 2, 3]}
result = processor.process_response_dict(response_dict, stream=False)
processor.process_response_dict_normal.assert_called_once()
self.assertEqual(result, {"text": "response"})
def test_apply_default_parameters(self):
"""Test _apply_default_parameters with dict and object request"""
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
processor = Ernie4_5_VLProcessor("model_path")
processor.generation_config = MagicMock()
processor.generation_config.top_p = 0.8
processor.generation_config.temperature = 0.9
processor._apply_default_parameters = Ernie4_5_VLProcessor._apply_default_parameters.__get__(
processor, Ernie4_5_VLProcessor
)
# Test with dict request
request = {}
result = processor._apply_default_parameters(request)
self.assertEqual(result["top_p"], 0.8)
self.assertEqual(result["temperature"], 0.9)
# Test with object request
class MockRequest:
def __init__(self):
self.top_p = None
self.temperature = None
def get(self, key):
return getattr(self, key, None)
def set(self, key, value):
setattr(self, key, value)
request = MockRequest()
result = processor._apply_default_parameters(request)
self.assertEqual(result.top_p, 0.8)
class TestDataProcessorTargetMethods(unittest.TestCase):
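"""Tests for DataProcessor.prompt_token_ids2outputs and extract_mm_items."""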
def setUp(self):
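"""Build a DataProcessor with a mocked tokenizer and image preprocessor, then fix the special-token ids and mock extract_mm_items."""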
self.mock_tokenizer = MagicMock(spec=Ernie4_5Tokenizer)
self.mock_tokenizer.ignored_index = -100
self.mock_tokenizer.convert_tokens_to_ids.side_effect = self._mock_convert_tokens_to_ids
self.mock_tokenizer.chat_template = "mock_template"
self.mock_tokenizer.apply_chat_template.return_value = "User: Hello<|image@placeholder|>"
# Mock encode method for _add_text
self.mock_tokenizer.encode = MagicMock(return_value={"input_ids": [1, 2, 3]})
def mock_load_tokenizer(dp_instance):
dp_instance.tokenizer = self.mock_tokenizer
with patch.object(DataProcessor, "_load_tokenizer", side_effect=mock_load_tokenizer, autospec=True):
with patch.object(AdaptiveImageProcessor, "from_pretrained") as mock_image_preprocessor:
mock_image_preprocessor.return_value = MagicMock()
self.data_processor = DataProcessor(
tokenizer_name="mock_tokenizer",
image_preprocessor_name="mock_image_preprocessor",
enable_processor_cache=False,
)
self.data_processor.image_patch_id = 1001
self.data_processor.image_start_id = 1002
self.data_processor.image_end_id = 1003
self.data_processor.video_start_id = 1004
self.data_processor.video_end_id = 1005
self.data_processor.role_prefixes = {"user": "User: ", "assistant": "Assistant: "}
self.data_processor.enable_processor_cache = False
# extract_mm_items is mocked by default; tests that need the real
# implementation restore it via _restore_real_extract_mm_items()
self.data_processor.extract_mm_items = MagicMock(return_value=([], [], [], [], None, [], []))
def _restore_real_extract_mm_items(self):
"""Helper method to restore real extract_mm_items method for testing"""
from fastdeploy.input.ernie4_5_vl_processor.process import DataProcessor
original_extract_mm_items = DataProcessor.extract_mm_items
self.data_processor.extract_mm_items = original_extract_mm_items.__get__(self.data_processor, DataProcessor)
def _mock_convert_tokens_to_ids(self, token):
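"""Map known special tokens to fixed ids; any other token maps to 999."""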
token_id_map = {
"<|begin_of_sentence|>": 101,
"<|end_of_sentence|>": 102,
"</s>": 103,
"<|IMAGE_PLACEHOLDER|>": 1001,
"<|IMAGE_START|>": 1002,
"<|IMAGE_END|>": 1003,
"<|VIDEO_START|>": 1004,
"<|VIDEO_END|>": 1005,
}
return token_id_map.get(token, 999)
def test_prompt_token_ids2outputs_only_prompt_token_ids(self):
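"""A request carrying only prompt_token_ids yields text-type tokens and no multimodal outputs."""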
test_prompt_token_ids = [101, 999, 998, 997, 102]
request = {
"prompt_token_ids": test_prompt_token_ids,
}
outputs = self.data_processor.prompt_token_ids2outputs(request)
prompt_len = len(test_prompt_token_ids)
self.assertEqual(
outputs["input_ids"],
test_prompt_token_ids,
f"input_ids mismatch: actual {outputs['input_ids']}, expected {test_prompt_token_ids}",
)
self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len)
expected_position_ids = [[i] * 3 for i in range(prompt_len)]
self.assertEqual(outputs["position_ids"], expected_position_ids)
self.assertEqual(outputs["cur_position"], prompt_len)
self.assertEqual(len(outputs["images"]), 0)
self.assertEqual(len(outputs["grid_thw"]), 0)
self.assertEqual(len(outputs["mm_positions"]), 0)
self.assertEqual(len(outputs["mm_hashes"]), 0)
self.assertEqual(outputs["video_cnt"], 0)
self.assertEqual(outputs["num_input_image_tokens"], 0)
self.assertEqual(outputs["num_input_video_tokens"], 0)
def test_prompt_token_ids2outputs_with_messages_no_mm(self):
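"""prompt_token_ids with text-only messages still yield text-type tokens and no multimodal outputs."""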
test_prompt_token_ids = [101, 999, 998, 997, 102]
request = {
"prompt_token_ids": test_prompt_token_ids,
"messages": [{"role": "user", "content": "Hello World"}],
}
self.data_processor.extract_mm_items.return_value = ([], [], [], [], None, [], [])
outputs = self.data_processor.prompt_token_ids2outputs(request)
prompt_len = len(test_prompt_token_ids)
self.assertEqual(outputs["input_ids"], test_prompt_token_ids)
self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len)
expected_position_ids = [[i] * 3 for i in range(prompt_len)]
self.assertEqual(outputs["position_ids"], expected_position_ids)
self.assertEqual(outputs["cur_position"], prompt_len)
self.assertEqual(len(outputs["images"]), 0)
self.assertEqual(outputs["video_cnt"], 0)
self.assertEqual(outputs["num_input_image_tokens"], 0)
def test_prompt_token_ids2outputs_add_image(self):
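"""Image placeholder tokens are matched against a raw image item and run through the (mocked) image preprocessor."""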
test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102]
mock_img = MagicMock()
mock_img.height = 224
mock_img.width = 224
mock_img.convert.return_value = mock_img
request = {
"prompt_token_ids": test_prompt_token_ids,
"messages": [
{"role": "user", "content": [{"type": "image_url", "image_url": mock_img, "uuid": "img_uuid"}]}
],
}
self.data_processor.extract_mm_items.return_value = (
[mock_img],
[],
["img_uuid"],
[],
None,
[],
[{"type": "image", "data": mock_img}],
)
mock_resize = (None, (2, 4))
self.data_processor.image_preprocessor.get_smarted_resize.return_value = mock_resize
mock_preprocess = {"pixel_values": np.random.randn(1, 16, 16, 3), "image_grid_thw": np.array([[2, 4]])}
self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
# self.data_processor._compute_3d_positions = MagicMock(return_value=[[i]*3 for i in range(4)])
outputs = self.data_processor.prompt_token_ids2outputs(request)
self.assertEqual(outputs["input_ids"], [101, 1002, 1001, 1001, 1003, 102])
self.assertEqual(
outputs["token_type_ids"],
[
IDS_TYPE_FLAG["text"],
IDS_TYPE_FLAG["text"],
IDS_TYPE_FLAG["image"],
IDS_TYPE_FLAG["image"],
IDS_TYPE_FLAG["text"],
IDS_TYPE_FLAG["text"],
],
)
self.assertEqual(len(outputs["position_ids"]), 6)
self.assertEqual(outputs["cur_position"], 6)
self.assertEqual(len(outputs["images"]), 1)
self.assertIsNotNone(outputs["images"][0])
self.assertEqual(outputs["num_input_image_tokens"], 2)
self.assertEqual(len(outputs["mm_positions"]), 1)
self.assertEqual(len(outputs["mm_hashes"]), 1)
self.assertEqual(len(outputs["grid_thw"]), 1)
self.assertEqual(len(outputs["image_type_ids"]), 1)
def test_prompt_token_ids2outputs_add_processed_image(self):
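"""Image placeholder tokens are matched against an already-preprocessed (cached) image item."""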
test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102]
mock_img_data = np.random.randn(8, 28, 28)
mock_img_cache = (mock_img_data, {"thw": (1, 8, 8)})
request = {
"prompt_token_ids": test_prompt_token_ids,
"messages": [
{"role": "user", "content": [{"type": "image_url", "image_url": mock_img_cache, "uuid": "img_uuid"}]}
],
}
self.data_processor.extract_mm_items.return_value = (
[mock_img_cache],
[],
["img_uuid"],
[],
None,
[],
[{"type": "image", "data": mock_img_cache}],
)
outputs = self.data_processor.prompt_token_ids2outputs(request)
self.assertEqual(outputs["input_ids"], [101, 1002, 1001, 1001, 1003, 102])
self.assertEqual(
outputs["token_type_ids"],
[
IDS_TYPE_FLAG["text"],
IDS_TYPE_FLAG["text"],
IDS_TYPE_FLAG["image"],
IDS_TYPE_FLAG["image"],
IDS_TYPE_FLAG["text"],
IDS_TYPE_FLAG["text"],
],
)
self.assertEqual(len(outputs["position_ids"]), 20)
self.assertEqual(outputs["cur_position"], 8)
self.assertEqual(len(outputs["images"]), 1)
self.assertIsNotNone(outputs["images"][0])
self.assertEqual(len(outputs["mm_positions"]), 1)
self.assertEqual(outputs["mm_hashes"][0], "img_uuid")
self.assertEqual(len(outputs["grid_thw"]), 1)
self.assertEqual(len(outputs["image_type_ids"]), 1)
def test_prompt_token_ids2outputs_add_video(self):
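"""Video placeholder tokens are matched against raw video frames and run through the (mocked) preprocessing pipeline."""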
test_prompt_token_ids = [101, 1004, 1001, 1001, 1001, 1001, 1005, 102]
mock_frame1 = MagicMock()
mock_frame1.height = 224
mock_frame1.width = 224
mock_frame1.convert.return_value = mock_frame1
mock_frame2 = MagicMock()
mock_frame2.height = 224
mock_frame2.width = 224
mock_frame2.convert.return_value = mock_frame2
frames = [mock_frame1, mock_frame2]
request = {
"prompt_token_ids": test_prompt_token_ids,
"messages": [
{"role": "user", "content": [{"type": "video_url", "video_url": frames, "uuid": "vid_uuid"}]}
],
}
self.data_processor.extract_mm_items.return_value = (
[],
[frames],
[],
["vid_uuid"],
None,
[],
[{"type": "video", "data": frames}],
)
self.data_processor._load_and_process_video = MagicMock(return_value=frames)
patches_h, patches_w = 4, 4
self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w))
mock_preprocess = {
"pixel_values_videos": np.random.randn(2, patches_h, patches_w, 3),
"video_grid_thw": np.array([[patches_h, patches_w]] * 2),
}
self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
outputs = self.data_processor.prompt_token_ids2outputs(request)
self.assertEqual(outputs["input_ids"], [101, 1004, 1001, 1001, 1001, 1001, 1005, 102])
self.assertEqual(
outputs["token_type_ids"],
[
IDS_TYPE_FLAG["text"],
IDS_TYPE_FLAG["text"],
IDS_TYPE_FLAG["video"],
IDS_TYPE_FLAG["video"],
IDS_TYPE_FLAG["video"],
IDS_TYPE_FLAG["video"],
IDS_TYPE_FLAG["text"],
IDS_TYPE_FLAG["text"],
],
)
self.assertEqual(len(outputs["position_ids"]), 8)
self.assertEqual(outputs["cur_position"], 6)
self.assertEqual(len(outputs["images"]), 1)
self.assertIsNotNone(outputs["images"][0])
self.assertEqual(len(outputs["mm_positions"]), 1)
self.assertEqual(outputs["mm_hashes"][0], "vid_uuid")
self.assertEqual(len(outputs["grid_thw"]), 1)
self.assertEqual(len(outputs["image_type_ids"]), 2)
self.assertEqual(outputs["num_input_video_tokens"], 4)
def test_prompt_token_ids2outputs_add_processed_video(self):
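"""Video placeholder tokens are matched against an already-preprocessed (cached) video item."""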
test_prompt_token_ids = [101, 1004, 1001, 1001, 1001, 1001, 1005, 102]
t, h, w = 2, 4, 4
spatial_conv_size = self.data_processor.spatial_conv_size
temporal_conv_size = self.data_processor.temporal_conv_size
token_per_frame = (h // spatial_conv_size) * (w // spatial_conv_size)
num_tokens = (t // temporal_conv_size) * token_per_frame
mock_frames_data = np.random.randn(num_tokens * spatial_conv_size**2 * temporal_conv_size, 28, 28)
mock_frames_cache = (mock_frames_data, {"thw": (t, h, w)})
request = {
"prompt_token_ids": test_prompt_token_ids,
"messages": [
{"role": "user", "content": [{"type": "video", "data": mock_frames_cache, "uuid": "vid_uuid"}]}
],
}
self.data_processor.extract_mm_items.return_value = (
[],
[mock_frames_cache],
[],
["vid_uuid"],
None,
[],
[{"type": "video", "data": mock_frames_cache}],
)
outputs = self.data_processor.prompt_token_ids2outputs(request)
self.assertEqual(outputs["input_ids"], [101, 1004, 1001, 1001, 1001, 1001, 1005, 102])
self.assertEqual(
outputs["token_type_ids"],
[
IDS_TYPE_FLAG["text"],
IDS_TYPE_FLAG["text"],
IDS_TYPE_FLAG["video"],
IDS_TYPE_FLAG["video"],
IDS_TYPE_FLAG["video"],
IDS_TYPE_FLAG["video"],
IDS_TYPE_FLAG["text"],
IDS_TYPE_FLAG["text"],
],
)
self.assertEqual(len(outputs["position_ids"]), 8)
self.assertEqual(outputs["cur_position"], 6)
self.assertEqual(len(outputs["images"]), 1)
self.assertIsNotNone(outputs["images"][0])
self.assertEqual(len(outputs["mm_positions"]), 1)
self.assertEqual(outputs["mm_hashes"][0], "vid_uuid")
self.assertEqual(len(outputs["grid_thw"]), 1)
self.assertEqual(len(outputs["image_type_ids"]), 2)
def test_prompt_token_ids2outputs_add_image_token_len_mismatch(self):
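"""Raises ValueError when the number of image placeholder tokens does not match the preprocessed image size."""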
test_prompt_token_ids = [101, 1002, 1001, 1001, 1001, 1003, 102]
mock_img = MagicMock()
mock_img.height = 224
mock_img.width = 224
mock_img.convert.return_value = mock_img
request = {
"prompt_token_ids": test_prompt_token_ids,
"messages": [
{"role": "user", "content": [{"type": "image_url", "image_url": mock_img, "uuid": "img_uuid"}]}
],
}
self.data_processor.extract_mm_items.return_value = (
[mock_img],
[],
["img_uuid"],
[],
None,
[],
[{"type": "image", "data": mock_img}],
)
patches_h, patches_w = 8, 8
self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w))
mock_preprocess = {
"pixel_values": np.random.randn(1, patches_h, patches_w, 3),
"image_grid_thw": np.array([[patches_h, patches_w]]),
}
self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
with self.assertRaises(ValueError) as ctx:
self.data_processor.prompt_token_ids2outputs(request)
self.assertIn("image tokens num not match the size", str(ctx.exception))
def test_prompt_token_ids2outputs_add_processed_image_token_len_mismatch(self):
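"""Raises ValueError when the number of image placeholder tokens does not match a cached image's size."""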
test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102]
spatial_conv_size = self.data_processor.spatial_conv_size
num_tokens = 4
mock_img_data = np.random.randn(num_tokens * (spatial_conv_size**2), 28, 28)
mock_img_cache = (mock_img_data, {"thw": (1, 8, 8)})
request = {
"prompt_token_ids": test_prompt_token_ids,
"messages": [
{"role": "user", "content": [{"type": "image_url", "image_url": mock_img_cache, "uuid": "img_uuid"}]}
],
}
self.data_processor.extract_mm_items.return_value = (
[mock_img_cache],
[],
["img_uuid"],
[],
None,
[],
[{"type": "image", "data": mock_img_cache}],
)
with self.assertRaises(ValueError) as ctx:
self.data_processor.prompt_token_ids2outputs(request)
self.assertIn("image tokens num not match the size", str(ctx.exception))
def test_prompt_token_ids2outputs_add_video_token_len_mismatch(self):
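"""Raises ValueError when the number of video placeholder tokens does not match the preprocessed video size."""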
test_prompt_token_ids = [101, 1004, 1001, 1001, 1005, 102]
mock_frame1 = MagicMock()
mock_frame1.height = 224
mock_frame1.width = 224
mock_frame1.convert.return_value = mock_frame1
mock_frame2 = MagicMock()
mock_frame2.height = 224
mock_frame2.width = 224
mock_frame2.convert.return_value = mock_frame2
frames = [mock_frame1, mock_frame2]
request = {
"prompt_token_ids": test_prompt_token_ids,
"messages": [
{"role": "user", "content": [{"type": "video_url", "video_url": frames, "uuid": "vid_uuid"}]}
],
}
self.data_processor.extract_mm_items.return_value = (
[],
[frames],
[],
["vid_uuid"],
None,
[],
[{"type": "video", "data": frames}],
)
self.data_processor._load_and_process_video = MagicMock(return_value=frames)
patches_h, patches_w = 8, 8
self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w))
mock_preprocess = {
"pixel_values_videos": np.random.randn(2, patches_h, patches_w, 3),
"video_grid_thw": np.array([[patches_h, patches_w]] * 2),
}
self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
with self.assertRaises(ValueError) as ctx:
self.data_processor.prompt_token_ids2outputs(request)
self.assertIn("video tokens num not match the size", str(ctx.exception))
def test_prompt_token_ids2outputs_add_processed_video_token_len_mismatch(self):
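"""Raises ValueError when the number of video placeholder tokens does not match a cached video's size."""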
test_prompt_token_ids = [101, 1004, 1001, 1005, 102]
t, h, w = 2, 8, 8
spatial_conv_size = self.data_processor.spatial_conv_size
temporal_conv_size = self.data_processor.temporal_conv_size
num_tokens = 4
mock_frames_data = np.random.randn(num_tokens * spatial_conv_size**2 * temporal_conv_size, 28, 28)
mock_frames_cache = (mock_frames_data, {"thw": (t, h, w)})
request = {
"prompt_token_ids": test_prompt_token_ids,
"messages": [
{"role": "user", "content": [{"type": "video", "data": mock_frames_cache, "uuid": "vid_uuid"}]}
],
}
self.data_processor.extract_mm_items.return_value = (
[],
[mock_frames_cache],
[],
["vid_uuid"],
None,
[],
[{"type": "video", "data": mock_frames_cache}],
)
with self.assertRaises(ValueError) as ctx:
self.data_processor.prompt_token_ids2outputs(request)
self.assertIn("video tokens num not match the size", str(ctx.exception))
def test_extract_mm_items(self):
"""Test extract_mm_items with various scenarios: basic items, video, and missing data error"""
self._restore_real_extract_mm_items()
# Test basic multimodal items (image + video)
request = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Hello"},
{"type": "image", "data": Image.new("RGB", (224, 224)), "uuid": "img1"},
{"type": "video", "data": [Image.new("RGB", (224, 224))], "uuid": "vid1"},
],
}
]
}
with patch("fastdeploy.input.ernie4_5_vl_processor.process.parse_chat_messages") as mock_parse:
mock_parse.return_value = request["messages"]
images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = (
self.data_processor.extract_mm_items(request)
)
self.assertEqual(len(images), 1)
self.assertEqual(len(videos), 1)
self.assertEqual(image_uuid[0], "img1")
self.assertEqual(video_uuid[0], "vid1")
self.assertEqual(len(mm_items), 2)
# Test missing data error when cache is disabled
self.data_processor.enable_processor_cache = False
request = {"messages": [{"role": "user", "content": [{"type": "image", "uuid": "img1"}]}]}
with patch("fastdeploy.input.ernie4_5_vl_processor.process.parse_chat_messages") as mock_parse:
mock_parse.return_value = request["messages"]
with self.assertRaises(ValueError) as ctx:
self.data_processor.extract_mm_items(request)
self.assertIn("Missing items cannot be retrieved", str(ctx.exception))
class TestDataProcessor(unittest.TestCase):
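"""Tests for DataProcessor text/ids conversion, video frame arguments, label extraction, and processor cache operations."""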
def setUp(self):
"""Set up test environment"""
self.mock_tokenizer = MagicMock()
def mock_convert_tokens_to_ids(x):
if isinstance(x, list):
return [hash(str(token)) % 10000 for token in x]
return hash(str(x)) % 10000
self.mock_tokenizer.convert_tokens_to_ids = MagicMock(side_effect=mock_convert_tokens_to_ids)
self.mock_tokenizer.encode = MagicMock(return_value={"input_ids": [1, 2, 3]})
self.mock_tokenizer.decode = MagicMock(return_value="decoded_text")
self.mock_tokenizer.tokenize = MagicMock(return_value=["token1", "token2"])
self.mock_tokenizer.ignored_index = -100
self.mock_tokenizer.chat_template = MagicMock()
self.mock_tokenizer.apply_chat_template = MagicMock(return_value="formatted_prompt")
self.mock_image_preprocessor = MagicMock()
self.mock_image_preprocessor.get_smarted_resize = MagicMock(return_value=((224, 224), (16, 16)))
self.mock_image_preprocessor.preprocess = MagicMock(
return_value={
"pixel_values": np.random.rand(256, 3 * 14 * 14).astype(np.float32),
"image_grid_thw": np.array([[1, 16, 16]]),
}
)
self.mock_image_preprocessor.from_pretrained = MagicMock(return_value=self.mock_image_preprocessor)
with patch(
"fastdeploy.input.ernie4_5_vl_processor.process.AdaptiveImageProcessor",
self.mock_image_preprocessor,
):
with patch("fastdeploy.input.ernie4_5_vl_processor.process.Ernie4_5Tokenizer") as mock_tokenizer_class:
mock_tokenizer_class.from_pretrained = MagicMock(return_value=self.mock_tokenizer)
mock_tokenizer_class.resource_files_names = {"vocab_file": "tokenizer.model"}
with patch("os.path.exists", return_value=True):
self.processor = DataProcessor(
tokenizer_name="test_model",
image_preprocessor_name="test_model",
)
def _create_outputs(self):
"""Helper to create outputs dict"""
return {
"input_ids": [],
"token_type_ids": [],
"position_ids": [],
"images": [],
"grid_thw": [],
"image_type_ids": [],
"mm_positions": [],
"mm_hashes": [],
"cur_position": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
}
def _mock_video_processing(self, mock_frames=None):
"""Helper to mock video processing"""
if mock_frames is None:
mock_frames = [Image.new("RGB", (224, 224)) for _ in range(4)]
mock_read = patch("fastdeploy.input.ernie4_5_vl_processor.process.read_video_decord")
mock_frames_read = patch("fastdeploy.input.ernie4_5_vl_processor.process.read_frames_decord")
mock_render = patch("fastdeploy.input.ernie4_5_vl_processor.process.render_frame_timestamp")
return mock_read, mock_frames_read, mock_render, mock_frames
def _setup_video_mocks(self, mock_read, mock_frames_read, mock_render, mock_frames):
"""Setup video processing mocks"""
mock_read.return_value = (None, {"duration": 2.0}, "test_path")
mock_frames_read.return_value = (
[np.array(f) for f in mock_frames],
None,
[0.0, 0.5, 1.0, 1.5] if len(mock_frames) == 4 else [float(i) * 0.5 for i in range(len(mock_frames))],
)
mock_render.side_effect = lambda img, ts: (Image.fromarray(img) if isinstance(img, np.ndarray) else img)
self.mock_image_preprocessor.preprocess.return_value = {
"pixel_values_videos": np.random.rand(len(mock_frames), 256, 3 * 14 * 14).astype(np.float32),
"video_grid_thw": np.array([[len(mock_frames), 16, 16]]),
}
def test_train_and_eval(self):
"""Test training and evaluation mode switching"""
self.assertTrue(self.processor.is_training)
self.processor.eval()
self.assertFalse(self.processor.is_training)
self.processor.train()
self.assertTrue(self.processor.is_training)
def test_build_token_type_mapping(self):
"""Test token type mapping construction"""
mapping = self.processor._build_token_type_mapping()
for token in [
self.processor.IMG_START,
self.processor.IMG_END,
self.processor.VID_START,
self.processor.VID_END,
]:
self.assertEqual(mapping[token], IDS_TYPE_FLAG["image"])
self.assertEqual(mapping[self.processor.image_patch_id], IDS_TYPE_FLAG["image"])
def test_add_text_and_special_token(self):
"""Test adding text and special tokens"""
outputs = self._create_outputs()
self.processor._add_text("hello", outputs)
self.assertEqual(len(outputs["input_ids"]), 3)
self.assertEqual(outputs["cur_position"], 3)
outputs2 = self._create_outputs()
self.processor._add_text([1, 2, 3, 4, 5], outputs2)
self.assertEqual(len(outputs2["input_ids"]), 5)
outputs3 = self._create_outputs()
self.processor._add_special_token("<|begin_of_sentence|>", outputs3)
self.processor._add_special_token(12345, outputs3)
self.assertEqual(len(outputs3["input_ids"]), 2)
def test_compute_3d_positions(self):
"""Test 3D position computation"""
pos_ids = self.processor._compute_3d_positions(t=2, h=16, w=16, start_idx=10)
self.assertIsInstance(pos_ids, list)
self.assertGreater(len(pos_ids), 0)
self.assertEqual(len(pos_ids[0]), 3)
pos_ids2 = self.processor._compute_3d_positions(t=1, h=16, w=16, start_idx=0)
expected_len = 1 * (16 // self.processor.spatial_conv_size) ** 2
self.assertEqual(len(pos_ids2), expected_len)
def test_set_video_frame_args_comprehensive(self):
"""Test _set_video_frame_args with various scenarios"""
# Valid cases
result = self.processor._set_video_frame_args(
{
"target_frames": 32,
"fps": -1,
"min_frames": 16,
"max_frames": 64,
"frames_sample": "leading",
},
{"duration": 10.0},
)
self.assertEqual(result["target_frames"], 32)
result = self.processor._set_video_frame_args(
{
"target_frames": -1,
"fps": 2,
"min_frames": 16,
"max_frames": 64,
"frames_sample": "leading",
},
{"duration": 10.0},
)
self.assertIsNotNone(result)
# Error cases
with self.assertRaises(ValueError):
self.processor._set_video_frame_args(
{
"target_frames": -1,
"fps": -1,
"min_frames": 16,
"max_frames": 64,
"frames_sample": "leading",
},
{"duration": 10.0},
)
with self.assertRaises(ValueError):
self.processor._set_video_frame_args(
{
"target_frames": 10,
"fps": 2,
"min_frames": 1,
"max_frames": 100,
"frames_sample": "leading",
},
{"duration": 10.0},
)
with self.assertRaises(ValueError):
self.processor._set_video_frame_args(
{
"target_frames": 5,
"fps": -1,
"min_frames": 10,
"max_frames": 100,
"frames_sample": "leading",
},
{"duration": 10.0},
)
with self.assertRaises(ValueError):
self.processor._set_video_frame_args(
{
"target_frames": 200,
"fps": -1,
"min_frames": 1,
"max_frames": 100,
"frames_sample": "leading",
},
{"duration": 10.0},
)
with self.assertRaises(ValueError):
self.processor._set_video_frame_args(
{
"target_frames": -1,
"fps": 2,
"min_frames": 100,
"max_frames": 10,
"frames_sample": "leading",
},
{"duration": 10.0},
)
# Adjustment cases
result = self.processor._set_video_frame_args(
{
"target_frames": -1,
"fps": 1,
"min_frames": 10,
"max_frames": 100,
"frames_sample": "leading",
},
{"duration": 1.0},
)
self.assertEqual(result["target_frames"], 10)
self.assertEqual(result["fps"], -1)
result = self.processor._set_video_frame_args(
{
"target_frames": -1,
"fps": 10,
"min_frames": 1,
"max_frames": 100,
"frames_sample": "leading",
},
{"duration": 100.0},
)
self.assertEqual(result["target_frames"], 100)
self.assertEqual(result["fps"], -1)
def test_text2ids_comprehensive(self):
"""Test text2ids with various scenarios"""
# Text only
outputs = self.processor.text2ids("Hello world")
self.assertIn("input_ids", outputs)
self.assertEqual(len(outputs["images"]), 0)
# Empty text
outputs = self.processor.text2ids("")
self.assertEqual(len(outputs["input_ids"]), 0)
# With image placeholder
mock_image = Image.new("RGB", (224, 224))
outputs = self.processor.text2ids("Hello <|image@placeholder|> world", images=[mock_image])
self.assertGreater(len(outputs["input_ids"]), 0)
self.assertGreater(len(outputs["images"]), 0)
# With cached image
cached_image = (
np.random.rand(256, 3 * 14 * 14).astype(np.float32),
{"thw": (1, 16, 16)},
)
outputs = self.processor.text2ids(
"Hello <|image@placeholder|> world",
images=[cached_image],
image_uuid=["uuid"],
)
self.assertGreater(len(outputs["input_ids"]), 0)
# Multiple images
outputs = self.processor.text2ids(
"Hello <|image@placeholder|> world <|image@placeholder|> end",
images=[mock_image, mock_image],
)
self.assertEqual(len(outputs["images"]), 2)
# With video placeholder
mock_read, mock_frames_read, mock_render, mock_frames = self._mock_video_processing()
with mock_read as mr, mock_frames_read as mfr, mock_render as mren:
mr.return_value = (None, {"duration": 2.0}, "test_path")
mfr.return_value = (
[np.array(f) for f in mock_frames],
None,
[0.0, 0.5, 1.0, 1.5],
)
mren.side_effect = lambda img, ts: (Image.fromarray(img) if isinstance(img, np.ndarray) else img)
self.mock_image_preprocessor.preprocess.return_value = {
"pixel_values_videos": np.random.rand(4, 256, 3 * 14 * 14).astype(np.float32),
"video_grid_thw": np.array([[4, 16, 16]]),
}
outputs = self.processor.text2ids("Hello <|video@placeholder|> world", videos=["test_video.mp4"])
self.assertGreater(len(outputs["input_ids"]), 0)
# Cached video
cached_video = (
np.random.rand(256, 3 * 14 * 14).astype(np.float32),
{"thw": (4, 16, 16)},
)
outputs = self.processor.text2ids(
"Hello <|video@placeholder|> world",
videos=[cached_video],
video_uuid=["uuid"],
)
self.assertGreater(len(outputs["input_ids"]), 0)
# Video dict format
mock_read, mock_frames_read, mock_render, mock_frames = self._mock_video_processing()
with mock_read as mr, mock_frames_read as mfr, mock_render as mren:
mr.return_value = (None, {"duration": 2.0}, "test_path")
mfr.return_value = (
[np.array(f) for f in mock_frames],
None,
[0.0, 0.5, 1.0, 1.5],
)
mren.side_effect = lambda img, ts: (Image.fromarray(img) if isinstance(img, np.ndarray) else img)
self.mock_image_preprocessor.preprocess.return_value = {
"pixel_values_videos": np.random.rand(4, 256, 3 * 14 * 14).astype(np.float32),
"video_grid_thw": np.array([[4, 16, 16]]),
}
outputs = self.processor.text2ids(
"Hello <|video@placeholder|> world",
videos=[{"video": "test.mp4", "fps": 2}],
)
self.assertGreater(len(outputs["input_ids"]), 0)
# Image and video together
mock_read, mock_frames_read, mock_render, mock_frames = self._mock_video_processing()
with mock_read as mr, mock_frames_read as mfr, mock_render as mren:
mr.return_value = (None, {"duration": 2.0}, "test_path")
mfr.return_value = (
[np.array(f) for f in mock_frames],
None,
[0.0, 0.5, 1.0, 1.5],
)
mren.side_effect = lambda img, ts: (Image.fromarray(img) if isinstance(img, np.ndarray) else img)
self.mock_image_preprocessor.preprocess.side_effect = [
{
"pixel_values": np.random.rand(256, 3 * 14 * 14).astype(np.float32),
"image_grid_thw": np.array([[1, 16, 16]]),
},
{
"pixel_values_videos": np.random.rand(4, 256, 3 * 14 * 14).astype(np.float32),
"video_grid_thw": np.array([[4, 16, 16]]),
},
]
outputs = self.processor.text2ids(
"Hello <|image@placeholder|> world <|video@placeholder|> end",
images=[mock_image],
videos=["test_video.mp4"],
)
self.assertGreater(len(outputs["input_ids"]), 0)
self.mock_image_preprocessor.preprocess.side_effect = None
def test_request2ids_comprehensive(self):
"""Test request2ids with various scenarios"""
self.processor.is_training = False
# Basic request with multimodal content - covers both text and image branches in one call
mock_image = Image.new("RGB", (224, 224))
request = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{"type": "image", "data": mock_image, "uuid": "img1"},
],
}
],
"add_generation_prompt": True,
}
with patch("fastdeploy.input.ernie4_5_vl_processor.process.parse_chat_messages") as mock_parse:
mock_parse.return_value = request["messages"]
outputs = self.processor.request2ids(request)
self.assertIn("input_ids", outputs)
# Error case: missing chat_template
self.processor.tokenizer.chat_template = None
with patch("fastdeploy.input.ernie4_5_vl_processor.process.parse_chat_messages") as mock_parse:
mock_parse.return_value = [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]
with self.assertRaises(ValueError):
self.processor.request2ids(request)
self.processor.tokenizer.chat_template = MagicMock()
# Error case: unsupported role
request = {
"messages": [{"role": "invalid_role", "content": "Hello"}],
"add_generation_prompt": True,
}
with patch("fastdeploy.input.ernie4_5_vl_processor.process.parse_chat_messages") as mock_parse:
mock_parse.return_value = [{"role": "invalid_role", "content": [{"type": "text", "text": "Hello"}]}]
with self.assertRaises(AssertionError):
self.processor.request2ids(request)
# Error case: missing cache when cache is disabled
self.processor.enable_processor_cache = False
request = {"messages": [{"role": "user", "content": [{"type": "image", "uuid": "img1"}]}]}
with patch("fastdeploy.input.ernie4_5_vl_processor.process.parse_chat_messages") as mock_parse:
mock_parse.return_value = request["messages"]
with self.assertRaises(ValueError):
self.processor.request2ids(request)
def test_extract_labels(self):
"""Test label extraction"""
outputs = {"input_ids": [1, 2, 3, self.processor.sep_token_id, 4, 5], "labels": []}
self.processor.is_training = True
self.processor._extract_labels(outputs, ["target text"])
self.assertEqual(len(outputs["labels"]), len(outputs["input_ids"]))
# Multiple targets
outputs2 = {
"input_ids": [1, 2, 3, self.processor.sep_token_id, 4, 5, self.processor.sep_token_id, 6, 7],
"labels": [],
}
self.processor._extract_labels(outputs2, ["target1", "target2"])
self.assertEqual(len(outputs2["labels"]), len(outputs2["input_ids"]))
# Error case
outputs3 = {"input_ids": [1, 2, 3, self.processor.sep_token_id], "labels": []}
with self.assertRaises(AssertionError):
self.processor._extract_labels(outputs3, ["target1", "target2"])
def test_fancy_print(self):
"""Test fancy_print function"""
from fastdeploy.input.ernie4_5_vl_processor.process import fancy_print
test_cases = [
([1, 2, 3, self.processor.image_patch_id, 4, 5], self.processor.image_patch_id, None),
(
[
1,
2,
self.processor.image_patch_id,
self.processor.image_patch_id,
self.processor.image_patch_id,
4,
5,
],
self.processor.image_patch_id,
"<|IMAGE@",
),
([1, 2, 3, 4, 5], self.processor.image_patch_id, None),
]
for input_ids, image_patch_id, expected_contains in test_cases:
result = fancy_print(input_ids, self.mock_tokenizer, image_patch_id)
self.assertIsInstance(result, str)
if expected_contains:
self.assertIn(expected_contains, result)
def test_processor_cache_operations(self):
"""Test processor cache get/update and request2ids with cache"""
# Test get_processor_cache
mock_socket = MagicMock()
mock_socket.recv_multipart = MagicMock(return_value=(b"", b"pickled_data"))
with patch("fastdeploy.input.ernie4_5_vl_processor.process.pickle") as mock_pickle:
mock_pickle.loads = MagicMock(return_value=[{"data": "cached_item"}])
result = self.processor.get_processor_cache(mock_socket, ["hash1", "hash2"])
self.assertEqual(len(result), 1)
# Test update_processor_cache
mock_socket2 = MagicMock()
with patch("fastdeploy.input.ernie4_5_vl_processor.process.pickle"):
self.processor.update_processor_cache(
mock_socket2,
["hash1"],
[(np.array([1, 2, 3]), {"meta": "data"})],
)
mock_socket2.send_multipart.assert_called_once()
# Test request2ids with processor cache update
self.processor.is_training = False
self.processor.enable_processor_cache = True
mock_image = Image.new("RGB", (224, 224))
request = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Hello"},
{"type": "image", "data": mock_image, "uuid": "img1"},
],
}
],
"add_generation_prompt": True,
}
with patch("fastdeploy.input.ernie4_5_vl_processor.process.zmq") as mock_zmq:
mock_context = MagicMock()
mock_socket = MagicMock()
mock_socket.recv_multipart = MagicMock(return_value=(b"", b"pickled_data"))
mock_context.socket.return_value = mock_socket
mock_zmq.Context.return_value = mock_context
with patch("fastdeploy.input.ernie4_5_vl_processor.process.parse_chat_messages") as mock_parse:
mock_parse.return_value = request["messages"]
with patch("fastdeploy.input.ernie4_5_vl_processor.process.pickle") as mock_pickle:
mock_pickle.loads = MagicMock(return_value=[])
with patch.object(self.processor, "text2ids") as mock_text2ids:
mock_text2ids.return_value = {
"input_ids": [1, 2, 3],
"token_type_ids": [0] * 3,
"position_ids": [[i] * 3 for i in range(3)],
"images": [np.random.rand(256, 3 * 14 * 14).astype(np.float32)],
"grid_thw": [np.array([[1, 16, 16]])],
"image_type_ids": [0],
"cur_position": 3,
"video_cnt": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
"mm_positions": [],
"mm_hashes": ["hash1"],
}
with patch.object(self.processor, "update_processor_cache") as mock_update:
self.processor.request2ids(request)
mock_update.assert_called_once()
self.processor.enable_processor_cache = False
if __name__ == "__main__":
unittest.main()