Files
FastDeploy/tests/input/test_process_video.py
Echo-Nie 1b1bfab341 [CI] Add unittest (#5328)
* add test_worker_eplb

* remove tesnsor_wise_fp8

* add copyright
2025-12-09 19:19:42 +08:00

387 lines
13 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import io
import math
import os
import tempfile
import unittest
from unittest.mock import patch
import numpy as np
from PIL import Image as PILImage
import fastdeploy.input.ernie4_5_vl_processor.process_video as process_video_module
from fastdeploy.input.ernie4_5_vl_processor.process_video import (
get_frame_indices,
read_frames_decord,
read_video_decord,
)
class _MockFrame:
"""Lightweight frame wrapper that mimics the real frame object."""
def __init__(self, arr):
self._arr = arr
def asnumpy(self):
"""Return the underlying numpy array."""
return self._arr
class MockVideoReaderWrapper:
"""
Simple mock implementation of a video reader:
- __len__ returns the total number of frames
- __getitem__ returns a _MockFrame(arr)
- get_avg_fps() returns fps
- Specific indices can be configured to raise errors in __getitem__
"""
def __init__(
self,
src,
num_threads=1,
vlen=12,
fps=6,
fail_indices=None,
h=4,
w=5,
c=3,
):
self.src = src
self._vlen = vlen
self._fps = fps
self._fail = set(fail_indices or [])
self._h, self._w, self._c = h, w, c
def __len__(self):
return self._vlen
def get_avg_fps(self):
return self._fps
def __getitem__(self, idx):
if idx < 0 or idx >= self._vlen:
raise IndexError("index out of range")
if idx in self._fail:
raise ValueError(f"forced fail at {idx}")
# Create a frame whose pixel value encodes the index (for easy debugging)
arr = np.zeros((self._h, self._w, self._c), dtype=np.uint8)
arr[:] = idx % 255
return _MockFrame(arr)
class TestReadVideoDecord(unittest.TestCase):
def test_read_video_decord_with_wrapper(self):
"""Test passing an existing VideoReaderWrapper instance directly."""
# Patch VideoReaderWrapper in the target module so isinstance checks use our mock class
with patch.object(process_video_module, "VideoReaderWrapper", MockVideoReaderWrapper):
mock_reader = MockVideoReaderWrapper("dummy", vlen=10, fps=5)
reader, meta, path = read_video_decord(mock_reader, save_to_disk=False)
self.assertIs(reader, mock_reader)
self.assertEqual(meta["fps"], 5)
self.assertEqual(meta["num_of_frame"], 10)
self.assertTrue(math.isclose(meta["duration"], 10 / 5, rel_tol=1e-6))
# The original reader object should be returned unchanged
self.assertIs(path, mock_reader)
def test_read_video_decord_with_bytes(self):
"""Test that bytes input is wrapped into BytesIO and passed to VideoReaderWrapper."""
with patch.object(process_video_module, "VideoReaderWrapper", MockVideoReaderWrapper):
data = b"\x00\x01\x02\x03"
reader, meta, path = read_video_decord(data, save_to_disk=False)
self.assertIsInstance(reader, MockVideoReaderWrapper)
self.assertEqual(meta["fps"], 6)
self.assertEqual(meta["num_of_frame"], 12)
self.assertTrue(math.isclose(meta["duration"], 12 / 6, rel_tol=1e-6))
self.assertIsInstance(path, io.BytesIO)
class TestGetFrameIndices(unittest.TestCase):
def test_by_target_frames_middle(self):
"""Test target_frames mode with 'middle' sampling strategy."""
vlen = 12
out = get_frame_indices(
vlen=vlen,
target_frames=4,
target_fps=-1,
frames_sample="middle",
input_fps=-1,
)
# 12 frames split into 4 segments -> midpoints [1, 4, 7, 10]
self.assertEqual(out, [1, 4, 7, 10])
def test_by_target_frames_leading(self):
"""Test target_frames mode with 'leading' sampling strategy."""
vlen = 10
out = get_frame_indices(
vlen=vlen,
target_frames=5,
target_fps=-1,
frames_sample="leading",
input_fps=-1,
)
# 10 frames split into 5 segments -> segment starts [0, 2, 4, 6, 8]
self.assertEqual(out, [0, 2, 4, 6, 8])
def test_by_target_frames_rand(self):
"""Test target_frames mode with 'rand' sampling strategy."""
vlen = 10
out = get_frame_indices(
vlen=vlen,
target_frames=4,
target_fps=-1,
frames_sample="rand",
input_fps=-1,
)
self.assertEqual(len(out), 4)
self.assertTrue(all(0 <= i < vlen for i in out))
def test_by_target_frames_fix_start(self):
"""Test target_frames mode with a fixed start offset."""
vlen = 10
out = get_frame_indices(
vlen=vlen,
target_frames=5,
target_fps=-1,
frames_sample="middle", # overridden by fix_start
fix_start=1,
input_fps=-1,
)
# Segment starts [0, 2, 4, 6, 8] -> +1 => [1, 3, 5, 7, 9]
self.assertEqual(out, [1, 3, 5, 7, 9])
def test_target_frames_greater_than_vlen(self):
"""Test that target_frames > vlen falls back to using vlen samples."""
vlen = 5
out = get_frame_indices(
vlen=vlen,
target_frames=10,
target_fps=-1,
frames_sample="middle",
input_fps=-1,
)
self.assertEqual(len(out), vlen)
self.assertTrue(all(0 <= i < vlen for i in out))
def test_by_target_fps_middle(self):
"""Test target_fps mode with 'middle' sampling strategy."""
vlen, in_fps = 12, 6
out = get_frame_indices(
vlen=vlen,
target_frames=-1,
target_fps=2,
frames_sample="middle",
input_fps=in_fps,
)
# Roughly 4 frames expected
self.assertTrue(3 <= len(out) <= 5)
self.assertTrue(all(0 <= i < vlen for i in out))
def test_by_target_fps_leading(self):
"""Test target_fps mode with 'leading' sampling strategy."""
vlen, in_fps = 12, 6
out = get_frame_indices(
vlen=vlen,
target_frames=-1,
target_fps=2,
frames_sample="leading",
input_fps=in_fps,
)
self.assertTrue(3 <= len(out) <= 5)
self.assertTrue(all(0 <= i < vlen for i in out))
def test_by_target_fps_rand(self):
"""Test target_fps mode with 'rand' sampling strategy."""
vlen, in_fps = 12, 6
out = get_frame_indices(
vlen=vlen,
target_frames=-1,
target_fps=2,
frames_sample="rand",
input_fps=in_fps,
)
self.assertTrue(3 <= len(out) <= 5)
self.assertTrue(all(0 <= i < vlen for i in out))
def test_invalid_both_negative(self):
"""Test that both target_frames and target_fps being negative raises ValueError."""
with self.assertRaises(ValueError):
get_frame_indices(
vlen=10,
target_frames=-1,
target_fps=-1,
frames_sample="middle",
)
def test_invalid_both_specified(self):
"""Test that specifying both target_frames and target_fps raises AssertionError."""
with self.assertRaises(AssertionError):
get_frame_indices(
vlen=10,
target_frames=4,
target_fps=2,
frames_sample="middle",
input_fps=6,
)
def test_invalid_target_fps_missing_input(self):
"""Test that target_fps > 0 with invalid input_fps raises AssertionError."""
with self.assertRaises(AssertionError):
get_frame_indices(
vlen=10,
target_frames=-1,
target_fps=2,
frames_sample="middle",
input_fps=-1,
)
class TestReadFramesDecord(unittest.TestCase):
def test_basic_read_no_save(self):
"""Test normal frame reading without saving to disk."""
reader = MockVideoReaderWrapper("dummy", vlen=8, fps=4)
meta = {"fps": 4, "duration": 8 / 4, "num_of_frame": 8}
ret, idxs, ts = read_frames_decord(
video_path="dummy",
video_reader=reader,
video_meta=meta,
target_frames=4,
frames_sample="middle",
save_to_disk=False,
)
# Should return 4 PIL.Image instances
self.assertEqual(len(ret), 4)
for img in ret:
self.assertIsInstance(img, PILImage.Image)
self.assertEqual(idxs, [0, 2, 4, 6])
dur = meta["duration"]
n = meta["num_of_frame"]
for i, t in zip(idxs, ts):
self.assertTrue(math.isclose(t, i * dur / n, rel_tol=1e-6))
def test_read_and_save_to_disk(self):
"""Test reading frames and saving them as PNG files on disk."""
reader = MockVideoReaderWrapper("dummy", vlen=4, fps=2)
meta = {"fps": 2, "duration": 4 / 2, "num_of_frame": 4}
with (
tempfile.TemporaryDirectory() as tmpdir,
patch.object(
process_video_module,
"get_filename",
return_value="det_id",
),
):
ret, idxs, ts = read_frames_decord(
video_path="dummy",
video_reader=reader,
video_meta=meta,
target_frames=2,
frames_sample="leading",
save_to_disk=True,
cache_dir=tmpdir,
)
self.assertEqual(len(ret), 2)
for i, pth in enumerate(ret):
self.assertIsInstance(pth, str)
self.assertTrue(os.path.exists(pth))
self.assertEqual(os.path.basename(pth), f"{i}.png")
def test_fallback_previous_success(self):
"""Test that a failed frame read falls back to a previous valid frame when possible."""
reader = MockVideoReaderWrapper("dummy", vlen=10, fps=5, fail_indices={3})
meta = {"fps": 5, "duration": 10 / 5, "num_of_frame": 10}
idxs = [1, 2, 3, 6]
ret, new_idxs, ts = read_frames_decord(
video_path="dummy",
video_reader=reader,
video_meta=meta,
frame_indices=idxs.copy(),
save_to_disk=False,
tol=5,
)
# Index 3 fails and should be replaced by 2 or 4 (previous/next search)
self.assertIn(new_idxs[2], (2, 4))
self.assertEqual(len(ret), 4)
def test_fallback_next_when_prev_fails(self):
"""Test that when current and previous frames fail, a later frame is used as fallback."""
reader = MockVideoReaderWrapper("dummy", vlen=10, fps=5, fail_indices={2, 3})
meta = {"fps": 5, "duration": 10 / 5, "num_of_frame": 10}
idxs = [1, 2, 3, 6]
ret, new_idxs, ts = read_frames_decord(
video_path="dummy",
video_reader=reader,
video_meta=meta,
frame_indices=idxs.copy(),
save_to_disk=False,
tol=5,
)
# Frame 3 should eventually be replaced by 4
self.assertEqual(new_idxs[2], 4)
self.assertEqual(len(ret), 4)
def test_len_assert_when_no_fallback(self):
"""Test that assertion is triggered when no valid fallback frame can be found."""
class FailAllAroundReader(MockVideoReaderWrapper):
"""Reader that fails on index 1 and has too small length to find fallback."""
def __init__(self, *a, **kw):
super().__init__(*a, **kw)
self._vlen = 2
self._fps = 2
self._fail = {1}
def __getitem__(self, idx):
if idx in self._fail:
raise ValueError("fail hard")
return super().__getitem__(idx)
reader = FailAllAroundReader("dummy")
meta = {"fps": 2, "duration": 2 / 2, "num_of_frame": 2}
# Request 2 frames: index 0 succeeds, index 1 always fails,
# and tol=0 disallows searching neighbors -> stack and length assertion should fail
with self.assertRaises(AssertionError):
read_frames_decord(
video_path="dummy",
video_reader=reader,
video_meta=meta,
target_frames=2,
frames_sample="leading",
save_to_disk=False,
tol=0,
)
if __name__ == "__main__":
unittest.main()