"""Unit tests for ``fastdeploy.input.ernie4_5_vl_processor.process_video``.

Recovered from the new-file patch for ``tests/input/test_process_video.py``.

The tests exercise ``read_video_decord``, ``get_frame_indices`` and
``read_frames_decord`` against a small in-memory mock of the video reader, so
no real video file or decoder is needed.
"""

import io
import math
import os
import tempfile
import unittest
from unittest.mock import patch

import numpy as np
from PIL import Image as PILImage

import fastdeploy.input.ernie4_5_vl_processor.process_video as process_video_module
from fastdeploy.input.ernie4_5_vl_processor.process_video import (
    get_frame_indices,
    read_frames_decord,
    read_video_decord,
)


class _MockFrame:
    """Lightweight frame wrapper exposing only ``asnumpy()``."""

    def __init__(self, arr):
        self._arr = arr

    def asnumpy(self):
        """Return the underlying numpy array."""
        return self._arr


class MockVideoReaderWrapper:
    """Simple in-memory stand-in for a video reader.

    - ``len(reader)`` returns the total number of frames (``vlen``).
    - ``reader[idx]`` returns a ``_MockFrame`` wrapping an ``(h, w, c)`` uint8 array.
    - ``get_avg_fps()`` returns ``fps``.
    - Indices listed in ``fail_indices`` raise ``ValueError`` from ``__getitem__``.
    """

    def __init__(
        self,
        src,
        num_threads=1,
        vlen=12,
        fps=6,
        fail_indices=None,
        h=4,
        w=5,
        c=3,
    ):
        # ``num_threads`` is accepted but never used by the mock; presumably it
        # mirrors the real reader's constructor -- confirm against the module.
        self.src = src
        self._vlen = vlen
        self._fps = fps
        self._fail = set(fail_indices or [])
        self._h, self._w, self._c = h, w, c

    def __len__(self):
        return self._vlen

    def get_avg_fps(self):
        """Return the configured frames-per-second value."""
        return self._fps

    def __getitem__(self, idx):
        if idx < 0 or idx >= self._vlen:
            raise IndexError("index out of range")
        if idx in self._fail:
            raise ValueError(f"forced fail at {idx}")
        # Create a frame whose pixel value encodes the index (for easy debugging).
        arr = np.zeros((self._h, self._w, self._c), dtype=np.uint8)
        arr[:] = idx % 255
        return _MockFrame(arr)


class TestReadVideoDecord(unittest.TestCase):
    """Tests for ``read_video_decord`` with the reader class patched to the mock."""

    def _assert_meta(self, meta, fps, num_frames):
        """Check the ``fps``, ``num_of_frame`` and ``duration`` entries of ``meta``."""
        self.assertEqual(meta["fps"], fps)
        self.assertEqual(meta["num_of_frame"], num_frames)
        expected_duration = num_frames / fps
        self.assertTrue(
            math.isclose(meta["duration"], expected_duration, rel_tol=1e-6),
            msg=f"duration {meta['duration']!r} is not close to {expected_duration!r}",
        )

    def test_read_video_decord_with_wrapper(self):
        """Test passing an existing VideoReaderWrapper instance directly."""
        # Patch VideoReaderWrapper in the target module so isinstance checks use our mock class.
        with patch.object(process_video_module, "VideoReaderWrapper", MockVideoReaderWrapper):
            mock_reader = MockVideoReaderWrapper("dummy", vlen=10, fps=5)
            reader, meta, path = read_video_decord(mock_reader, save_to_disk=False)

            self.assertIs(reader, mock_reader)
            self._assert_meta(meta, fps=5, num_frames=10)
            # The original reader object should be returned unchanged.
            self.assertIs(path, mock_reader)

    def test_read_video_decord_with_bytes(self):
        """Test that bytes input is wrapped into BytesIO and passed to VideoReaderWrapper."""
        with patch.object(process_video_module, "VideoReaderWrapper", MockVideoReaderWrapper):
            data = b"\x00\x01\x02\x03"
            reader, meta, path = read_video_decord(data, save_to_disk=False)

            self.assertIsInstance(reader, MockVideoReaderWrapper)
            # 6 fps / 12 frames are the mock's constructor defaults.
            self._assert_meta(meta, fps=6, num_frames=12)
            self.assertIsInstance(path, io.BytesIO)


class TestGetFrameIndices(unittest.TestCase):
    """Tests for ``get_frame_indices`` in both target_frames and target_fps modes."""

    def _assert_valid_indices(self, indices, vlen):
        """Assert that every sampled index lies in ``[0, vlen)``."""
        for index in indices:
            self.assertTrue(
                0 <= index < vlen,
                msg=f"frame index {index!r} is outside [0, {vlen})",
            )

    def _check_target_fps(self, frames_sample):
        """Sample a 12-frame, 6 fps clip at 2 fps and check count and bounds."""
        vlen, in_fps = 12, 6
        out = get_frame_indices(
            vlen=vlen,
            target_frames=-1,
            target_fps=2,
            frames_sample=frames_sample,
            input_fps=in_fps,
        )
        # Roughly 4 frames expected; the test tolerates 3 to 5.
        self.assertGreaterEqual(len(out), 3)
        self.assertLessEqual(len(out), 5)
        self._assert_valid_indices(out, vlen)

    def test_by_target_frames_middle(self):
        """Test target_frames mode with 'middle' sampling strategy."""
        vlen = 12
        out = get_frame_indices(
            vlen=vlen,
            target_frames=4,
            target_fps=-1,
            frames_sample="middle",
            input_fps=-1,
        )
        # 12 frames split into 4 segments -> midpoints [1, 4, 7, 10]
        self.assertEqual(out, [1, 4, 7, 10])

    def test_by_target_frames_leading(self):
        """Test target_frames mode with 'leading' sampling strategy."""
        vlen = 10
        out = get_frame_indices(
            vlen=vlen,
            target_frames=5,
            target_fps=-1,
            frames_sample="leading",
            input_fps=-1,
        )
        # 10 frames split into 5 segments -> segment starts [0, 2, 4, 6, 8]
        self.assertEqual(out, [0, 2, 4, 6, 8])

    def test_by_target_frames_rand(self):
        """Test target_frames mode with 'rand' sampling strategy."""
        vlen = 10
        out = get_frame_indices(
            vlen=vlen,
            target_frames=4,
            target_fps=-1,
            frames_sample="rand",
            input_fps=-1,
        )
        # Random sampling: only the count and the bounds are deterministic.
        self.assertEqual(len(out), 4)
        self._assert_valid_indices(out, vlen)

    def test_by_target_frames_fix_start(self):
        """Test target_frames mode with a fixed start offset."""
        vlen = 10
        out = get_frame_indices(
            vlen=vlen,
            target_frames=5,
            target_fps=-1,
            frames_sample="middle",  # overridden by fix_start
            fix_start=1,
            input_fps=-1,
        )
        # Segment starts [0, 2, 4, 6, 8] -> +1 => [1, 3, 5, 7, 9]
        self.assertEqual(out, [1, 3, 5, 7, 9])

    def test_target_frames_greater_than_vlen(self):
        """Test that target_frames > vlen falls back to using vlen samples."""
        vlen = 5
        out = get_frame_indices(
            vlen=vlen,
            target_frames=10,
            target_fps=-1,
            frames_sample="middle",
            input_fps=-1,
        )
        self.assertEqual(len(out), vlen)
        self._assert_valid_indices(out, vlen)

    def test_by_target_fps_middle(self):
        """Test target_fps mode with 'middle' sampling strategy."""
        self._check_target_fps("middle")

    def test_by_target_fps_leading(self):
        """Test target_fps mode with 'leading' sampling strategy."""
        self._check_target_fps("leading")

    def test_by_target_fps_rand(self):
        """Test target_fps mode with 'rand' sampling strategy."""
        self._check_target_fps("rand")

    def test_invalid_both_negative(self):
        """Test that both target_frames and target_fps being negative raises ValueError."""
        with self.assertRaises(ValueError):
            get_frame_indices(
                vlen=10,
                target_frames=-1,
                target_fps=-1,
                frames_sample="middle",
            )

    def test_invalid_both_specified(self):
        """Test that specifying both target_frames and target_fps raises AssertionError."""
        with self.assertRaises(AssertionError):
            get_frame_indices(
                vlen=10,
                target_frames=4,
                target_fps=2,
                frames_sample="middle",
                input_fps=6,
            )

    def test_invalid_target_fps_missing_input(self):
        """Test that target_fps > 0 with invalid input_fps raises AssertionError."""
        with self.assertRaises(AssertionError):
            get_frame_indices(
                vlen=10,
                target_frames=-1,
                target_fps=2,
                frames_sample="middle",
                input_fps=-1,
            )


class TestReadFramesDecord(unittest.TestCase):
    """Tests for ``read_frames_decord``: plain reads, disk caching and fallbacks."""

    def test_basic_read_no_save(self):
        """Test normal frame reading without saving to disk."""
        reader = MockVideoReaderWrapper("dummy", vlen=8, fps=4)
        meta = {"fps": 4, "duration": 8 / 4, "num_of_frame": 8}

        ret, idxs, ts = read_frames_decord(
            video_path="dummy",
            video_reader=reader,
            video_meta=meta,
            target_frames=4,
            frames_sample="middle",
            save_to_disk=False,
        )

        # Should return 4 PIL.Image instances.
        self.assertEqual(len(ret), 4)
        for img in ret:
            self.assertIsInstance(img, PILImage.Image)

        self.assertEqual(idxs, [0, 2, 4, 6])
        dur = meta["duration"]
        n = meta["num_of_frame"]
        # Each timestamp is expected to be index * duration / frame count.
        for i, t in zip(idxs, ts):
            expected = i * dur / n
            self.assertTrue(
                math.isclose(t, expected, rel_tol=1e-6),
                msg=f"timestamp {t!r} for frame {i} is not close to {expected!r}",
            )

    def test_read_and_save_to_disk(self):
        """Test reading frames and saving them as PNG files on disk."""
        reader = MockVideoReaderWrapper("dummy", vlen=4, fps=2)
        meta = {"fps": 2, "duration": 4 / 2, "num_of_frame": 4}

        # Nested ``with`` statements instead of the parenthesized multi-item
        # form, which is only officially supported from Python 3.10.
        with tempfile.TemporaryDirectory() as tmpdir:
            with patch.object(
                process_video_module,
                "get_filename",
                return_value="det_id",
            ):
                ret, _, _ = read_frames_decord(
                    video_path="dummy",
                    video_reader=reader,
                    video_meta=meta,
                    target_frames=2,
                    frames_sample="leading",
                    save_to_disk=True,
                    cache_dir=tmpdir,
                )

                # Checked while the temporary directory still exists: the
                # returned paths are expected to live under ``cache_dir``.
                self.assertEqual(len(ret), 2)
                for i, pth in enumerate(ret):
                    self.assertIsInstance(pth, str)
                    self.assertTrue(os.path.exists(pth), msg=f"missing file: {pth}")
                    self.assertEqual(os.path.basename(pth), f"{i}.png")

    def test_fallback_previous_success(self):
        """Test that a failed frame read falls back to a previous valid frame when possible."""
        reader = MockVideoReaderWrapper("dummy", vlen=10, fps=5, fail_indices={3})
        meta = {"fps": 5, "duration": 10 / 5, "num_of_frame": 10}
        idxs = [1, 2, 3, 6]

        # Pass a copy so ``idxs`` stays untouched if the callee mutates its argument.
        ret, new_idxs, _ = read_frames_decord(
            video_path="dummy",
            video_reader=reader,
            video_meta=meta,
            frame_indices=idxs.copy(),
            save_to_disk=False,
            tol=5,
        )

        # Index 3 fails and should be replaced by 2 or 4 (previous/next search).
        self.assertIn(new_idxs[2], (2, 4))
        self.assertEqual(len(ret), 4)

    def test_fallback_next_when_prev_fails(self):
        """Test that when current and previous frames fail, a later frame is used as fallback."""
        reader = MockVideoReaderWrapper("dummy", vlen=10, fps=5, fail_indices={2, 3})
        meta = {"fps": 5, "duration": 10 / 5, "num_of_frame": 10}
        idxs = [1, 2, 3, 6]

        ret, new_idxs, _ = read_frames_decord(
            video_path="dummy",
            video_reader=reader,
            video_meta=meta,
            frame_indices=idxs.copy(),
            save_to_disk=False,
            tol=5,
        )

        # Frame 3 should eventually be replaced by 4.
        self.assertEqual(new_idxs[2], 4)
        self.assertEqual(len(ret), 4)

    def test_len_assert_when_no_fallback(self):
        """Test that assertion is triggered when no valid fallback frame can be found."""

        class FailAllAroundReader(MockVideoReaderWrapper):
            """Reader that fails on index 1 and has too small length to find fallback."""

            def __init__(self, *a, **kw):
                super().__init__(*a, **kw)
                # Override the base configuration: two frames, the second unreadable.
                self._vlen = 2
                self._fps = 2
                self._fail = {1}

            def __getitem__(self, idx):
                if idx in self._fail:
                    raise ValueError("fail hard")
                return super().__getitem__(idx)

        reader = FailAllAroundReader("dummy")
        meta = {"fps": 2, "duration": 2 / 2, "num_of_frame": 2}

        # Request 2 frames: index 0 succeeds, index 1 always fails, and tol=0
        # disallows searching neighbors, so fewer frames than requested are
        # collected and the function is expected to raise AssertionError.
        with self.assertRaises(AssertionError):
            read_frames_decord(
                video_path="dummy",
                video_reader=reader,
                video_meta=meta,
                target_frames=2,
                frames_sample="leading",
                save_to_disk=False,
                tol=0,
            )


if __name__ == "__main__":
    unittest.main()