Files
FastDeploy/fastdeploy/multimodal/video.py
ApplEOFDiscord b71cbb466d [Feature] remove dependency on enable_mm and refine multimodal's code (#3014)
* remove dependency on enable_mm

* fix codestyle check error

* fix codestyle check error

* update docs

* resolve conflicts on model config

* fix unit test error

* fix code style check error

---------

Co-authored-by: shige <1021937542@qq.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-08-01 20:01:18 +08:00

165 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import base64
import numpy as np
import numpy.typing as npt
from .base import MediaIO
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
"""
对视频帧进行缩放,将每一帧的大小调整为指定的高度和宽度。
Args:
frames (npt.NDArray, shape=(N, H, W, C)): 包含N个帧的三维数组其中H是高度W是宽度C是通道数。
所有帧都应该具有相同的通道数。
size (tuple[int, int], required): 一个元组,包含两个整数,分别表示目标高度和宽度。
Returns:
npt.NDArray, shape=(N, new_height, new_width, C): 返回一个新的三维数组,其中每一帧已经被缩放到指定的高度和宽度。
新数组的通道数与输入数组相同。
Raises:
None
"""
num_frames, _, _, channels = frames.shape
new_height, new_width = size
resized_frames = np.empty((num_frames, new_height, new_width, channels), dtype=frames.dtype)
# lazy import cv2 to avoid bothering users who only use text models
import cv2
for i, frame in enumerate(frames):
resized_frame = cv2.resize(frame, (new_width, new_height))
resized_frames[i] = resized_frame
return resized_frames
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
"""
对视频帧进行缩放,将每个帧的高度和宽度都乘以一个因子。
Args:
frames (npt.NDArray): 形状为THWC的四维numpy数组表示T个帧高度为H宽度为W通道数为C。
size_factor (float): 用于缩放视频帧的因子新的高度和宽度将分别是原来的高度和宽度的size_factor倍。
Returns:
npt.NDArray: 形状为Tnew_Hnew_WC的四维numpy数组表示T个帧高度为new_H宽度为new_W通道数为C。
其中new_H和new_W是根据size_factor计算出来的。
Raises:
None
"""
_, height, width, _ = frames.shape
new_height = int(height * size_factor)
new_width = int(width * size_factor)
return resize_video(frames, (new_height, new_width))
def sample_frames_from_video(frames: npt.NDArray, num_frames: int) -> npt.NDArray:
"""
从视频中随机选取指定数量的帧并返回一个包含这些帧的numpy数组。
Args:
frames (npt.NDArray): 形状为THWC的ndarray表示视频的所有帧其中T是帧的总数H、W是每个帧的高度和宽度C是通道数。
num_frames (int, optional): 要从视频中选取的帧数。如果设置为-1则将返回所有帧。默认为-1。
Returns:
npt.NDArray: 形状为num_framesHWC的ndarray表示选取的帧。如果num_frames=-1则返回原始的frames。
"""
total_frames = frames.shape[0]
if num_frames == -1:
return frames
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
sampled_frames = frames[frame_indices, ...]
return sampled_frames
class VideoMediaIO(MediaIO[bytes]):
def __init__(self) -> None:
"""
初始化一个 VideoMediaIO 对象。
Args:
无。
Raises:
无。
Returns:
无。
"""
super().__init__()
def load_bytes(self, data: bytes) -> bytes:
"""
ERNIE-45-VL模型的前处理中包含抽帧操作如果将视频帧加载为npt.NDArray格式会丢失FPS信息因此目前
不对字节数据做任何操作。
Args:
data (bytes): 包含视频帧数据的字节对象。
Returns:
bytes字节数据原样返回。
Raises:
无。
"""
return data
def load_base64(self, media_type: str, data: str) -> bytes:
"""
加载 base64 编码的数据并返回bytes。
Args:
media_type (str): 媒体类型,目前不支持 "video/jpeg"
data (str): base64 编码的字符串数据。
Returns:
bytes, optional: 如果 media_type 不为 "video/jpeg",则返回字节数据。
Raises:
ValueError: 如果media_type是"video/jpeg"
"""
if media_type.lower() == "video/jpeg":
raise ValueError("Video in JPEG format is not supported")
return base64.b64decode(data)
def load_file(self, filepath: str) -> bytes:
"""
读取文件内容并返回bytes。
Args:
filepath (str): 文件路径,表示要读取的文件。
Returns:
bytes, optional: 返回字节数据,包含了文件内容。
Raises:
无。
"""
with open(filepath, "rb") as f:
data = f.read()
return data