add test_process_video.py (#5011)

This commit is contained in:
Winters Montagne
2025-11-14 17:23:30 +08:00
committed by GitHub
parent 544ea9cbc2
commit b925533051

View File

@@ -0,0 +1,370 @@
import io
import math
import os
import tempfile
import unittest
from unittest.mock import patch
import numpy as np
from PIL import Image as PILImage
import fastdeploy.input.ernie4_5_vl_processor.process_video as process_video_module
from fastdeploy.input.ernie4_5_vl_processor.process_video import (
get_frame_indices,
read_frames_decord,
read_video_decord,
)
class _MockFrame:
"""Lightweight frame wrapper that mimics the real frame object."""
def __init__(self, arr):
self._arr = arr
def asnumpy(self):
"""Return the underlying numpy array."""
return self._arr
class MockVideoReaderWrapper:
"""
Simple mock implementation of a video reader:
- __len__ returns the total number of frames
- __getitem__ returns a _MockFrame(arr)
- get_avg_fps() returns fps
- Specific indices can be configured to raise errors in __getitem__
"""
def __init__(
self,
src,
num_threads=1,
vlen=12,
fps=6,
fail_indices=None,
h=4,
w=5,
c=3,
):
self.src = src
self._vlen = vlen
self._fps = fps
self._fail = set(fail_indices or [])
self._h, self._w, self._c = h, w, c
def __len__(self):
return self._vlen
def get_avg_fps(self):
return self._fps
def __getitem__(self, idx):
if idx < 0 or idx >= self._vlen:
raise IndexError("index out of range")
if idx in self._fail:
raise ValueError(f"forced fail at {idx}")
# Create a frame whose pixel value encodes the index (for easy debugging)
arr = np.zeros((self._h, self._w, self._c), dtype=np.uint8)
arr[:] = idx % 255
return _MockFrame(arr)
class TestReadVideoDecord(unittest.TestCase):
def test_read_video_decord_with_wrapper(self):
"""Test passing an existing VideoReaderWrapper instance directly."""
# Patch VideoReaderWrapper in the target module so isinstance checks use our mock class
with patch.object(process_video_module, "VideoReaderWrapper", MockVideoReaderWrapper):
mock_reader = MockVideoReaderWrapper("dummy", vlen=10, fps=5)
reader, meta, path = read_video_decord(mock_reader, save_to_disk=False)
self.assertIs(reader, mock_reader)
self.assertEqual(meta["fps"], 5)
self.assertEqual(meta["num_of_frame"], 10)
self.assertTrue(math.isclose(meta["duration"], 10 / 5, rel_tol=1e-6))
# The original reader object should be returned unchanged
self.assertIs(path, mock_reader)
def test_read_video_decord_with_bytes(self):
"""Test that bytes input is wrapped into BytesIO and passed to VideoReaderWrapper."""
with patch.object(process_video_module, "VideoReaderWrapper", MockVideoReaderWrapper):
data = b"\x00\x01\x02\x03"
reader, meta, path = read_video_decord(data, save_to_disk=False)
self.assertIsInstance(reader, MockVideoReaderWrapper)
self.assertEqual(meta["fps"], 6)
self.assertEqual(meta["num_of_frame"], 12)
self.assertTrue(math.isclose(meta["duration"], 12 / 6, rel_tol=1e-6))
self.assertIsInstance(path, io.BytesIO)
class TestGetFrameIndices(unittest.TestCase):
def test_by_target_frames_middle(self):
"""Test target_frames mode with 'middle' sampling strategy."""
vlen = 12
out = get_frame_indices(
vlen=vlen,
target_frames=4,
target_fps=-1,
frames_sample="middle",
input_fps=-1,
)
# 12 frames split into 4 segments -> midpoints [1, 4, 7, 10]
self.assertEqual(out, [1, 4, 7, 10])
def test_by_target_frames_leading(self):
"""Test target_frames mode with 'leading' sampling strategy."""
vlen = 10
out = get_frame_indices(
vlen=vlen,
target_frames=5,
target_fps=-1,
frames_sample="leading",
input_fps=-1,
)
# 10 frames split into 5 segments -> segment starts [0, 2, 4, 6, 8]
self.assertEqual(out, [0, 2, 4, 6, 8])
def test_by_target_frames_rand(self):
"""Test target_frames mode with 'rand' sampling strategy."""
vlen = 10
out = get_frame_indices(
vlen=vlen,
target_frames=4,
target_fps=-1,
frames_sample="rand",
input_fps=-1,
)
self.assertEqual(len(out), 4)
self.assertTrue(all(0 <= i < vlen for i in out))
def test_by_target_frames_fix_start(self):
"""Test target_frames mode with a fixed start offset."""
vlen = 10
out = get_frame_indices(
vlen=vlen,
target_frames=5,
target_fps=-1,
frames_sample="middle", # overridden by fix_start
fix_start=1,
input_fps=-1,
)
# Segment starts [0, 2, 4, 6, 8] -> +1 => [1, 3, 5, 7, 9]
self.assertEqual(out, [1, 3, 5, 7, 9])
def test_target_frames_greater_than_vlen(self):
"""Test that target_frames > vlen falls back to using vlen samples."""
vlen = 5
out = get_frame_indices(
vlen=vlen,
target_frames=10,
target_fps=-1,
frames_sample="middle",
input_fps=-1,
)
self.assertEqual(len(out), vlen)
self.assertTrue(all(0 <= i < vlen for i in out))
def test_by_target_fps_middle(self):
"""Test target_fps mode with 'middle' sampling strategy."""
vlen, in_fps = 12, 6
out = get_frame_indices(
vlen=vlen,
target_frames=-1,
target_fps=2,
frames_sample="middle",
input_fps=in_fps,
)
# Roughly 4 frames expected
self.assertTrue(3 <= len(out) <= 5)
self.assertTrue(all(0 <= i < vlen for i in out))
def test_by_target_fps_leading(self):
"""Test target_fps mode with 'leading' sampling strategy."""
vlen, in_fps = 12, 6
out = get_frame_indices(
vlen=vlen,
target_frames=-1,
target_fps=2,
frames_sample="leading",
input_fps=in_fps,
)
self.assertTrue(3 <= len(out) <= 5)
self.assertTrue(all(0 <= i < vlen for i in out))
def test_by_target_fps_rand(self):
"""Test target_fps mode with 'rand' sampling strategy."""
vlen, in_fps = 12, 6
out = get_frame_indices(
vlen=vlen,
target_frames=-1,
target_fps=2,
frames_sample="rand",
input_fps=in_fps,
)
self.assertTrue(3 <= len(out) <= 5)
self.assertTrue(all(0 <= i < vlen for i in out))
def test_invalid_both_negative(self):
"""Test that both target_frames and target_fps being negative raises ValueError."""
with self.assertRaises(ValueError):
get_frame_indices(
vlen=10,
target_frames=-1,
target_fps=-1,
frames_sample="middle",
)
def test_invalid_both_specified(self):
"""Test that specifying both target_frames and target_fps raises AssertionError."""
with self.assertRaises(AssertionError):
get_frame_indices(
vlen=10,
target_frames=4,
target_fps=2,
frames_sample="middle",
input_fps=6,
)
def test_invalid_target_fps_missing_input(self):
"""Test that target_fps > 0 with invalid input_fps raises AssertionError."""
with self.assertRaises(AssertionError):
get_frame_indices(
vlen=10,
target_frames=-1,
target_fps=2,
frames_sample="middle",
input_fps=-1,
)
class TestReadFramesDecord(unittest.TestCase):
def test_basic_read_no_save(self):
"""Test normal frame reading without saving to disk."""
reader = MockVideoReaderWrapper("dummy", vlen=8, fps=4)
meta = {"fps": 4, "duration": 8 / 4, "num_of_frame": 8}
ret, idxs, ts = read_frames_decord(
video_path="dummy",
video_reader=reader,
video_meta=meta,
target_frames=4,
frames_sample="middle",
save_to_disk=False,
)
# Should return 4 PIL.Image instances
self.assertEqual(len(ret), 4)
for img in ret:
self.assertIsInstance(img, PILImage.Image)
self.assertEqual(idxs, [0, 2, 4, 6])
dur = meta["duration"]
n = meta["num_of_frame"]
for i, t in zip(idxs, ts):
self.assertTrue(math.isclose(t, i * dur / n, rel_tol=1e-6))
def test_read_and_save_to_disk(self):
"""Test reading frames and saving them as PNG files on disk."""
reader = MockVideoReaderWrapper("dummy", vlen=4, fps=2)
meta = {"fps": 2, "duration": 4 / 2, "num_of_frame": 4}
with (
tempfile.TemporaryDirectory() as tmpdir,
patch.object(
process_video_module,
"get_filename",
return_value="det_id",
),
):
ret, idxs, ts = read_frames_decord(
video_path="dummy",
video_reader=reader,
video_meta=meta,
target_frames=2,
frames_sample="leading",
save_to_disk=True,
cache_dir=tmpdir,
)
self.assertEqual(len(ret), 2)
for i, pth in enumerate(ret):
self.assertIsInstance(pth, str)
self.assertTrue(os.path.exists(pth))
self.assertEqual(os.path.basename(pth), f"{i}.png")
def test_fallback_previous_success(self):
"""Test that a failed frame read falls back to a previous valid frame when possible."""
reader = MockVideoReaderWrapper("dummy", vlen=10, fps=5, fail_indices={3})
meta = {"fps": 5, "duration": 10 / 5, "num_of_frame": 10}
idxs = [1, 2, 3, 6]
ret, new_idxs, ts = read_frames_decord(
video_path="dummy",
video_reader=reader,
video_meta=meta,
frame_indices=idxs.copy(),
save_to_disk=False,
tol=5,
)
# Index 3 fails and should be replaced by 2 or 4 (previous/next search)
self.assertIn(new_idxs[2], (2, 4))
self.assertEqual(len(ret), 4)
def test_fallback_next_when_prev_fails(self):
"""Test that when current and previous frames fail, a later frame is used as fallback."""
reader = MockVideoReaderWrapper("dummy", vlen=10, fps=5, fail_indices={2, 3})
meta = {"fps": 5, "duration": 10 / 5, "num_of_frame": 10}
idxs = [1, 2, 3, 6]
ret, new_idxs, ts = read_frames_decord(
video_path="dummy",
video_reader=reader,
video_meta=meta,
frame_indices=idxs.copy(),
save_to_disk=False,
tol=5,
)
# Frame 3 should eventually be replaced by 4
self.assertEqual(new_idxs[2], 4)
self.assertEqual(len(ret), 4)
def test_len_assert_when_no_fallback(self):
"""Test that assertion is triggered when no valid fallback frame can be found."""
class FailAllAroundReader(MockVideoReaderWrapper):
"""Reader that fails on index 1 and has too small length to find fallback."""
def __init__(self, *a, **kw):
super().__init__(*a, **kw)
self._vlen = 2
self._fps = 2
self._fail = {1}
def __getitem__(self, idx):
if idx in self._fail:
raise ValueError("fail hard")
return super().__getitem__(idx)
reader = FailAllAroundReader("dummy")
meta = {"fps": 2, "duration": 2 / 2, "num_of_frame": 2}
# Request 2 frames: index 0 succeeds, index 1 always fails,
# and tol=0 disallows searching neighbors -> stack and length assertion should fail
with self.assertRaises(AssertionError):
read_frames_decord(
video_path="dummy",
video_reader=reader,
video_meta=meta,
target_frames=2,
frames_sample="leading",
save_to_disk=False,
tol=0,
)
if __name__ == "__main__":
unittest.main()