mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Feature] Add Qwen25-VL Processor (#3501)
* add qwen-2.5-vl processor * add qwen25-vl processor * add qwen25-vl processor * add qwen25-vl processor * add qwen25-vl processor position_ids * add qwen25-vl processor * add qwen25-vl processor * position_ids * add test for qwen25-vl * organize comments * formatted * qwen_vl_processor * add qwen_vl_processor unittest * update model path * update model path * update qwen_vl_processor unittest * add unittest and bug fix * add unittest and bug fix * Update fastdeploy/input/qwen_mm_processor/image_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/input/qwen_vl_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -15,9 +15,13 @@
|
||||
"""
|
||||
|
||||
from .process import IDS_TYPE_FLAG, DataProcessor, fancy_print
|
||||
from .process_video import read_video_decord
|
||||
from .utils.video_utils import VideoReaderWrapper
|
||||
|
||||
__all__ = [
|
||||
"DataProcessor",
|
||||
"fancy_print",
|
||||
"IDS_TYPE_FLAG",
|
||||
"VideoReaderWrapper",
|
||||
"read_video_decord",
|
||||
]
|
||||
|
@@ -75,7 +75,10 @@ class InputPreprocessor:
|
||||
reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
|
||||
if self.tool_parser:
|
||||
tool_parser_obj = ToolParserManager.get_tool_parser(self.tool_parser)
|
||||
architectures = ModelConfig({"model": self.model_name_or_path}).architectures[0]
|
||||
|
||||
config = ModelConfig({"model": self.model_name_or_path})
|
||||
architectures = config.architectures[0]
|
||||
|
||||
if not self.enable_mm:
|
||||
if not ErnieArchitectures.contains_ernie_arch(architectures):
|
||||
from fastdeploy.input.text_processor import DataProcessor
|
||||
@@ -94,9 +97,7 @@ class InputPreprocessor:
|
||||
tool_parser_obj=tool_parser_obj,
|
||||
)
|
||||
else:
|
||||
if not ErnieArchitectures.contains_ernie_arch(architectures):
|
||||
raise ValueError(f"Model {self.model_name_or_path} is not a valid Ernie4_5_VL model.")
|
||||
else:
|
||||
if ErnieArchitectures.contains_ernie_arch(architectures):
|
||||
from fastdeploy.input.ernie_vl_processor import ErnieMoEVLProcessor
|
||||
|
||||
self.processor = ErnieMoEVLProcessor(
|
||||
@@ -106,4 +107,14 @@ class InputPreprocessor:
|
||||
reasoning_parser_obj=reasoning_parser_obj,
|
||||
tool_parser_obj=tool_parser_obj,
|
||||
)
|
||||
else:
|
||||
from fastdeploy.input.qwen_vl_processor import QwenVLProcessor
|
||||
|
||||
self.processor = QwenVLProcessor(
|
||||
config=config,
|
||||
model_name_or_path=self.model_name_or_path,
|
||||
limit_mm_per_prompt=self.limit_mm_per_prompt,
|
||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||
reasoning_parser_obj=reasoning_parser_obj,
|
||||
)
|
||||
return self.processor
|
||||
|
22
fastdeploy/input/qwen_mm_processor/__init__.py
Normal file
22
fastdeploy/input/qwen_mm_processor/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from .process import IDS_TYPE_FLAG, DataProcessor
|
||||
|
||||
__all__ = [
|
||||
"DataProcessor",
|
||||
"IDS_TYPE_FLAG",
|
||||
]
|
442
fastdeploy/input/qwen_mm_processor/image_processor.py
Normal file
442
fastdeploy/input/qwen_mm_processor/image_processor.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import PIL
|
||||
from paddleformers.transformers.feature_extraction_utils import BatchFeature
|
||||
from paddleformers.transformers.image_processing_utils import BaseImageProcessor
|
||||
from paddleformers.transformers.image_transforms import (
|
||||
normalize,
|
||||
rescale,
|
||||
resize,
|
||||
to_channel_dimension_format,
|
||||
)
|
||||
from paddleformers.transformers.image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
)
|
||||
from paddleformers.transformers.tokenizer_utils_base import TensorType
|
||||
from PIL import Image
|
||||
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
|
||||
OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
|
||||
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
|
||||
|
||||
MIN_PIXELS = 4 * 28 * 28
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
|
||||
|
||||
VideoInput = Union[
|
||||
List["PIL.Image.Image"],
|
||||
"np.ndarray",
|
||||
"paddle.Tensor",
|
||||
List["np.ndarray"],
|
||||
List["paddle.Tensor"],
|
||||
List[List["PIL.Image.Image"]],
|
||||
List[List["np.ndarray"]],
|
||||
List[List["paddle.Tensor"]],
|
||||
]
|
||||
|
||||
|
||||
def round_by_factor(number: int, factor: int) -> int:
|
||||
"""
|
||||
Round number to nearest multiple of factor.
|
||||
|
||||
Args:
|
||||
number: Input number to round
|
||||
factor: Rounding factor
|
||||
|
||||
Returns:
|
||||
int: Rounded number
|
||||
"""
|
||||
return round(number / factor) * factor
|
||||
|
||||
|
||||
def ceil_by_factor(number: int, factor: int) -> int:
|
||||
"""
|
||||
Round number up to nearest multiple of factor.
|
||||
|
||||
Args:
|
||||
number: Input number to round
|
||||
factor: Rounding factor
|
||||
|
||||
Returns:
|
||||
int: Rounded number
|
||||
"""
|
||||
return math.ceil(number / factor) * factor
|
||||
|
||||
|
||||
def floor_by_factor(number: int, factor: int) -> int:
|
||||
"""
|
||||
Round number down to nearest multiple of factor.
|
||||
|
||||
Args:
|
||||
number: Input number to round
|
||||
factor: Rounding factor
|
||||
|
||||
Returns:
|
||||
int: Rounded number
|
||||
"""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
|
||||
def smart_resize(height: int, width: int, factor: int, min_pixels: int, max_pixels: int, max_ratio: int = 200):
|
||||
"""
|
||||
Smart image resizing that maintains aspect ratio and respects constraints.
|
||||
|
||||
Args:
|
||||
height: Original image height
|
||||
width: Original image width
|
||||
factor: Patch size factor
|
||||
min_pixels: Minimum allowed pixels
|
||||
max_pixels: Maximum allowed pixels
|
||||
max_ratio: Maximum allowed aspect ratio
|
||||
|
||||
Returns:
|
||||
tuple: (new_height, new_width)
|
||||
|
||||
Raises:
|
||||
ValueError: If calculated dimensions are invalid
|
||||
"""
|
||||
if max(height, width) / min(height, width) > max_ratio:
|
||||
if height > width:
|
||||
new_width = max(factor, round_by_factor(width, factor))
|
||||
new_height = floor_by_factor(new_width * max_ratio, factor)
|
||||
else:
|
||||
new_height = max(factor, round_by_factor(height, factor))
|
||||
new_width = floor_by_factor(new_height * max_ratio, factor)
|
||||
|
||||
data_processor_logger.info(
|
||||
f"absolute aspect ratio must be smaller than {max_ratio}, got {max(height, width) / min(height, width)},\
|
||||
resize to {max(new_height, new_width) / min(new_height, new_width)}"
|
||||
)
|
||||
|
||||
height = new_height
|
||||
width = new_width
|
||||
|
||||
h_bar = max(factor, round_by_factor(height, factor))
|
||||
w_bar = max(factor, round_by_factor(width, factor))
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = floor_by_factor(height / beta, factor)
|
||||
w_bar = floor_by_factor(width / beta, factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = ceil_by_factor(height * beta, factor)
|
||||
w_bar = ceil_by_factor(width * beta, factor)
|
||||
|
||||
if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels:
|
||||
raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}")
|
||||
|
||||
return h_bar, w_bar
|
||||
|
||||
|
||||
def is_scaled_image(image: np.ndarray) -> bool:
|
||||
"""
|
||||
Check if image pixel values are already normalized to [0, 1] range.
|
||||
|
||||
Args:
|
||||
image: Input image array
|
||||
|
||||
Returns:
|
||||
bool: True if image is already scaled
|
||||
"""
|
||||
if image.dtype == np.uint8:
|
||||
return False
|
||||
|
||||
# It's possible the image has pixel values in [0, 255] but is of floating type
|
||||
return np.min(image) >= 0 and np.max(image) <= 1
|
||||
|
||||
|
||||
class ImageProcessor(BaseImageProcessor):
|
||||
"""
|
||||
Adaptive image processor for dynamic image resizing and preprocessing.
|
||||
|
||||
This processor handles image resizing, rescaling, normalization and format conversion.
|
||||
It dynamically adjusts image dimensions based on original size and specified constraints.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 14,
|
||||
merge_size: int = 2,
|
||||
temporal_patch_size: int = 2,
|
||||
min_pixels: int = MIN_PIXELS,
|
||||
max_pixels: int = MAX_PIXELS,
|
||||
image_mean: Union[float, List[float]] = OPENAI_CLIP_MEAN,
|
||||
image_std: Union[float, List[float]] = OPENAI_CLIP_STD,
|
||||
rescale_factor: float = 1 / 255,
|
||||
do_rescale: bool = True,
|
||||
do_normalize: bool = True,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize image processor with configuration parameters.
|
||||
|
||||
Args:
|
||||
patch_size (int): Spatial patch size for vision encoder
|
||||
merge_size (int): Merge size between vision and LLM encoders
|
||||
temporal_patch_size (int): Temporal patch size for video processing
|
||||
min_pixels (int): Minimum allowed pixels in resized image
|
||||
max_pixels (int): Maximum allowed pixels in resized image
|
||||
image_mean (float/list): Mean values for normalization per channel
|
||||
image_std (float/list): Std values for normalization per channel
|
||||
rescale_factor (float): Scaling factor for pixel values (default 1/255)
|
||||
do_rescale (bool): Whether to rescale images
|
||||
do_normalize (bool): Whether to normalize images
|
||||
resample: Resampling method for image resizing
|
||||
**kwargs: Additional base class arguments
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.patch_size = patch_size
|
||||
self.merge_size = merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
|
||||
self.min_pixels = min_pixels
|
||||
self.max_pixels = max_pixels
|
||||
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_rescale = do_rescale
|
||||
self.do_normalize = do_normalize
|
||||
|
||||
self.resample = resample
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: Union[ImageInput, VideoInput],
|
||||
min_pixels: int,
|
||||
max_pixels: int,
|
||||
image_mean: Optional[Union[float, List[float]]],
|
||||
image_std: Optional[Union[float, List[float]]],
|
||||
rescale_factor: float,
|
||||
do_rescale: bool,
|
||||
do_normalize: bool,
|
||||
resample: PILImageResampling,
|
||||
data_format: Optional[ChannelDimension],
|
||||
input_data_format: Optional[Union[str, ChannelDimension]],
|
||||
):
|
||||
"""
|
||||
Internal method for image preprocessing pipeline.
|
||||
|
||||
Args:
|
||||
images: Input image or batch of images
|
||||
min_pixels: Minimum allowed pixels in output
|
||||
max_pixels: Maximum allowed pixels in output
|
||||
image_mean: Normalization mean values
|
||||
image_std: Normalization std values
|
||||
rescale_factor: Pixel value scaling factor
|
||||
do_rescale: Whether to rescale pixel values
|
||||
do_normalize: Whether to normalize pixel values
|
||||
resample: Resampling method
|
||||
data_format: Output channel format
|
||||
input_data_format: Input channel format
|
||||
|
||||
Returns:
|
||||
tuple: (flatten_patches, grid_dimensions)
|
||||
- flatten_patches: Flattened image patches
|
||||
- grid_dimensions: Grid dimensions [t, h, w]
|
||||
"""
|
||||
images = make_list_of_images(images)
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if is_scaled_image(images[0]) and do_rescale:
|
||||
data_processor_logger.warning(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
# Get original dimensions and calculate optimal resize dimensions
|
||||
height, width = get_image_size(images[0], channel_dim=input_data_format)
|
||||
resized_height, resized_width = smart_resize(
|
||||
height,
|
||||
width,
|
||||
factor=self.patch_size * self.merge_size, # Combine patch and merge factors
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
|
||||
processed_images = []
|
||||
for image in images:
|
||||
if height != resized_height or width != resized_width:
|
||||
# Convert to uint8 before resizing to avoid double scaling
|
||||
image = image.astype("uint8")
|
||||
# Convert to PIL Image and resize
|
||||
image = Image.fromarray(image)
|
||||
image = resize(
|
||||
image,
|
||||
size=(resized_height, resized_width),
|
||||
resample=resample,
|
||||
data_format=input_data_format,
|
||||
)
|
||||
|
||||
if do_rescale and do_normalize:
|
||||
# Adjust mean and std for combined rescale+normalize
|
||||
image_mean = np.array(image_mean, dtype=np.float32) * (1.0 / rescale_factor)
|
||||
image_std = np.array(image_std, dtype=np.float32) * (1.0 / rescale_factor)
|
||||
do_rescale = False # Skip separate rescale step
|
||||
|
||||
if do_rescale:
|
||||
image = image.astype(np.float32)
|
||||
image = rescale(image, scale=rescale_factor, data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = image.astype(np.float32)
|
||||
image = normalize(
|
||||
image=image,
|
||||
mean=image_mean,
|
||||
std=image_std,
|
||||
data_format=input_data_format,
|
||||
)
|
||||
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) # [C, H, W]
|
||||
processed_images.append(image)
|
||||
|
||||
# Convert processed images to numpy array
|
||||
patches = np.array(processed_images)
|
||||
|
||||
# Pad temporal dimension if needed
|
||||
if patches.shape[0] % self.temporal_patch_size != 0:
|
||||
repeats = np.repeat(
|
||||
patches[-1][np.newaxis],
|
||||
self.temporal_patch_size - (patches.shape[0] % self.temporal_patch_size),
|
||||
axis=0,
|
||||
)
|
||||
patches = np.concatenate([patches, repeats], axis=0)
|
||||
|
||||
# Convert to channels-first format if needed
|
||||
if data_format == ChannelDimension.LAST:
|
||||
patches = patches.transpose([0, 3, 1, 2]) # [N, H, W, C] -> [N, C, H, W]
|
||||
|
||||
grid_t, channel = patches.shape[:2]
|
||||
grid_t = grid_t // self.temporal_patch_size
|
||||
|
||||
grid_h, grid_w = (
|
||||
resized_height // self.patch_size,
|
||||
resized_width // self.patch_size,
|
||||
)
|
||||
# Reshape into hierarchical patch structure
|
||||
patches = patches.reshape(
|
||||
[
|
||||
grid_t,
|
||||
self.temporal_patch_size,
|
||||
channel,
|
||||
grid_h // self.merge_size,
|
||||
self.merge_size,
|
||||
self.patch_size,
|
||||
grid_w // self.merge_size,
|
||||
self.merge_size,
|
||||
self.patch_size,
|
||||
]
|
||||
)
|
||||
# Reorder dimensions for better memory access pattern
|
||||
# [grid_t, grid_h/merge_size, grid_w/merge_size, merge_size, merge_size, C, temporal_patch_size, psz, psz]
|
||||
patches = patches.transpose([0, 3, 6, 4, 7, 2, 1, 5, 8])
|
||||
|
||||
flatten_patches = patches.reshape(
|
||||
[
|
||||
grid_t * grid_h * grid_w,
|
||||
channel * self.temporal_patch_size * self.patch_size * self.patch_size,
|
||||
]
|
||||
)
|
||||
|
||||
return flatten_patches, np.array([grid_t, grid_h, grid_w])
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: Union[ImageInput, VideoInput],
|
||||
min_pixels: Optional[int] = None,
|
||||
max_pixels: Optional[int] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
resample: Optional[PILImageResampling] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.LAST,
|
||||
):
|
||||
"""
|
||||
Main preprocessing method for images/videos.
|
||||
|
||||
Args:
|
||||
images: Input image/video data
|
||||
min_pixels: Override for minimum pixels
|
||||
max_pixels: Override for maximum pixels
|
||||
image_mean: Override for normalization mean
|
||||
image_std: Override for normalization std
|
||||
rescale_factor: Override for rescaling factor
|
||||
do_rescale: Override for rescaling flag
|
||||
do_normalize: Override for normalization flag
|
||||
resample: Override for resampling method
|
||||
return_tensors: Desired output tensor format
|
||||
data_format: Output channel dimension format
|
||||
input_data_format: Input channel dimension format
|
||||
|
||||
Returns:
|
||||
BatchFeature: Processed features containing:
|
||||
- pixel_values: Preprocessed pixel data
|
||||
- grid_thw: Grid dimensions [temporal, height, width]
|
||||
|
||||
Raises:
|
||||
ValueError: For invalid image types or dimensions
|
||||
"""
|
||||
min_pixels = min_pixels if min_pixels is not None else self.min_pixels
|
||||
max_pixels = max_pixels if max_pixels is not None else self.max_pixels
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
resample = resample if resample is not None else self.resample
|
||||
|
||||
if images is not None and not valid_images(images):
|
||||
raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "paddle.Tensor.")
|
||||
|
||||
pixel_values, grid_thw = self._preprocess(
|
||||
images,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
rescale_factor=rescale_factor,
|
||||
do_rescale=do_rescale,
|
||||
do_normalize=do_normalize,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
data = {"pixel_values": pixel_values, "grid_thw": grid_thw}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
505
fastdeploy/input/qwen_mm_processor/process.py
Normal file
505
fastdeploy/input/qwen_mm_processor/process.py
Normal file
@@ -0,0 +1,505 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from paddleformers.transformers import AutoTokenizer
|
||||
|
||||
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
|
||||
from fastdeploy.input.mm_processor import IDS_TYPE_FLAG
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
|
||||
from .image_processor import ImageProcessor
|
||||
from .process_video import read_frames, sample_frames
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
"""
|
||||
Processes multimodal inputs (text, images, videos) into model-ready formats.
|
||||
|
||||
Handles:
|
||||
- Tokenization of text with special tokens for visual content
|
||||
- Image and video preprocessing
|
||||
- Generation of 3D positional embeddings
|
||||
- Conversion of chat messages to model inputs
|
||||
|
||||
Attributes:
|
||||
tokenizer: Text tokenizer instance
|
||||
image_processor: Image/video preprocessor
|
||||
image_token: Special token for image placeholders
|
||||
video_token: Special token for video placeholders
|
||||
vision_start: Token marking start of visual content
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
video_min_frames: int = 4,
|
||||
video_max_frames: int = 768,
|
||||
tokens_per_second: int = 2,
|
||||
tokenizer=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the data processor.
|
||||
|
||||
Args:
|
||||
model_path: Path to pretrained model
|
||||
video_min_frames: Minimum frames to sample from videos
|
||||
video_max_frames: Maximum frames to sample from videos
|
||||
tokens_per_second: Temporal resolution for positional embeddings
|
||||
**kwargs: Additional configuration
|
||||
"""
|
||||
self.min_frames = video_min_frames
|
||||
self.max_frames = video_max_frames
|
||||
|
||||
# Initialize tokenizer with left padding and fast tokenizer
|
||||
if tokenizer is None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left", use_fast=True)
|
||||
self.tokenizer.ignored_index = -100 # Set ignored index for loss calculation
|
||||
else:
|
||||
self.tokenizer = tokenizer
|
||||
self.image_processor = ImageProcessor.from_pretrained(model_path) # Initialize image processor
|
||||
|
||||
# Convolution sizes for patch aggregation
|
||||
self.spatial_conv_size = self.image_processor.merge_size
|
||||
self.temporal_conv_size = self.image_processor.temporal_patch_size
|
||||
|
||||
# Special tokens and IDs
|
||||
self.image_token = "<|image_pad|>"
|
||||
self.video_token = "<|video_pad|>"
|
||||
|
||||
self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token)
|
||||
|
||||
self.vision_start = "<|vision_start|>"
|
||||
self.vision_start_id = self.tokenizer.convert_tokens_to_ids(self.vision_start)
|
||||
|
||||
self.tokens_per_second = tokens_per_second
|
||||
|
||||
self.role_prefixes = {
|
||||
"system": "",
|
||||
"user": "User: ",
|
||||
"bot": "Assistant: ",
|
||||
"assistant": "Assistant: ",
|
||||
}
|
||||
|
||||
def _pack_outputs(self, outputs):
|
||||
"""
|
||||
Pack and convert all output data into numpy arrays with appropriate types.
|
||||
|
||||
Args:
|
||||
outputs (dict): Dictionary containing model outputs with keys:
|
||||
- images: List of visual features
|
||||
- grid_thw: List of spatial dimensions
|
||||
- image_type_ids: List of content type indicators
|
||||
- input_ids: List of token IDs
|
||||
- token_type_ids: List of type identifiers
|
||||
- position_ids: List of position embeddings
|
||||
|
||||
Returns:
|
||||
dict: Processed outputs with all values converted to numpy arrays
|
||||
"""
|
||||
# Process visual outputs - stack if exists or set to None if empty
|
||||
if not outputs["images"]:
|
||||
outputs["images"] = None # No images case
|
||||
outputs["grid_thw"] = None # No spatial dimensions
|
||||
outputs["image_type_ids"] = None # No type IDs
|
||||
else:
|
||||
outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically
|
||||
outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions
|
||||
outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array
|
||||
|
||||
# Convert all outputs to numpy arrays with appropriate types
|
||||
outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64
|
||||
outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64
|
||||
outputs["position_ids"] = np.concatenate(
|
||||
outputs["position_ids"], axis=1, dtype=np.int64
|
||||
) # Concatenate position IDs
|
||||
return outputs
|
||||
|
||||
def text2ids(self, text, images=None, videos=None):
|
||||
"""
|
||||
Convert text with image/video placeholders into model inputs.
|
||||
|
||||
Args:
|
||||
text: Input text with <|image@placeholder|> and <|video@placeholder|> markers
|
||||
images: List of PIL Images corresponding to image placeholders
|
||||
videos: List of video data corresponding to video placeholders
|
||||
|
||||
Returns:
|
||||
Dict containing:
|
||||
- input_ids: Token IDs
|
||||
- token_type_ids: Type identifiers (text/image/video)
|
||||
- position_ids: 3D positional embeddings
|
||||
- images: Preprocessed visual features
|
||||
- grid_thw: Spatial/temporal dimensions
|
||||
- image_type_ids: Visual content type (0=image, 1=video)
|
||||
"""
|
||||
|
||||
outputs = {
|
||||
"input_ids": [],
|
||||
"token_type_ids": [],
|
||||
"position_ids": [],
|
||||
"images": [],
|
||||
"grid_thw": [],
|
||||
"image_type_ids": [],
|
||||
"labels": [],
|
||||
"cur_position": 0,
|
||||
"pic_cnt": 0,
|
||||
"video_cnt": 0,
|
||||
}
|
||||
|
||||
# Define placeholders and their lengths
|
||||
IMAGE_PLACEHOLDER = "<|image@placeholder|>"
|
||||
VIDEO_PLACEHOLDER = "<|video@placeholder|>"
|
||||
IMAGE_PLACEHOLDER_LEN = len(IMAGE_PLACEHOLDER)
|
||||
VIDEO_PLACEHOLDER_LEN = len(VIDEO_PLACEHOLDER)
|
||||
|
||||
# Initialize tracking variables for text parsing
|
||||
st, image_idx, video_idx = 0, 0, 0 # Start position, image counter, video counter
|
||||
while st < len(text):
|
||||
# Find next image or video placeholder in text
|
||||
image_pos = text.find(IMAGE_PLACEHOLDER, st)
|
||||
image_pos = len(text) if image_pos == -1 else image_pos # Set to end if not found
|
||||
video_pos = text.find(VIDEO_PLACEHOLDER, st)
|
||||
video_pos = len(text) if video_pos == -1 else video_pos # Set to end if not found
|
||||
ed = min(image_pos, video_pos) # End position is first placeholder found
|
||||
|
||||
self._add_text(text[st:ed], outputs)
|
||||
if ed == len(text):
|
||||
break
|
||||
|
||||
if ed == image_pos:
|
||||
outputs["pic_cnt"] += 1
|
||||
self._add_image(images[image_idx], outputs)
|
||||
image_idx += 1
|
||||
st = ed + IMAGE_PLACEHOLDER_LEN
|
||||
else:
|
||||
item = videos[video_idx]
|
||||
if isinstance(item, dict):
|
||||
frames, meta = self._load_and_process_video(item["video"], item)
|
||||
else:
|
||||
frames, meta = self._load_and_process_video(item, {})
|
||||
|
||||
outputs["video_cnt"] += 1
|
||||
self._add_video(frames, meta, outputs)
|
||||
video_idx += 1
|
||||
st = ed + VIDEO_PLACEHOLDER_LEN
|
||||
|
||||
return self._pack_outputs(outputs)
|
||||
|
||||
def request2ids(
|
||||
self, request: Dict[str, Any], tgts: List[str] = None
|
||||
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
|
||||
"""
|
||||
Convert chat request with multimodal messages into model inputs.
|
||||
|
||||
Args:
|
||||
request: Dictionary containing:
|
||||
- messages: List of chat messages with text/image/video content
|
||||
- request_id: Unique identifier for logging
|
||||
tgts: Optional target sequences
|
||||
|
||||
Returns:
|
||||
Dict with same structure as text2ids() output
|
||||
"""
|
||||
|
||||
outputs = {
|
||||
"input_ids": [],
|
||||
"token_type_ids": [],
|
||||
"position_ids": [],
|
||||
"images": [],
|
||||
"grid_thw": [],
|
||||
"image_type_ids": [],
|
||||
"labels": [],
|
||||
"cur_position": 0,
|
||||
"pic_cnt": 0,
|
||||
"video_cnt": 0,
|
||||
}
|
||||
|
||||
# Parse and validate chat messages
|
||||
messages = parse_chat_messages(request.get("messages"))
|
||||
image_message_list = [] # Store visual content messages
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
assert role in self.role_prefixes, f"Unsupported role: {role}"
|
||||
|
||||
# Normalize content to list format
|
||||
content_items = msg.get("content")
|
||||
if not isinstance(content_items, list):
|
||||
content_items = [content_items]
|
||||
|
||||
# Collect all visual content items
|
||||
for item in content_items:
|
||||
if isinstance(item, dict) and item.get("type") in ["image", "video"]:
|
||||
image_message_list.append(item)
|
||||
|
||||
raw_messages = request["messages"]
|
||||
request["messages"] = messages
|
||||
|
||||
prompt_token_ids = self.apply_chat_template(request)
|
||||
if len(prompt_token_ids) == 0:
|
||||
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
|
||||
request["messages"] = raw_messages
|
||||
|
||||
vision_start_index = 0
|
||||
vision_message_index = 0
|
||||
for i in range(len(prompt_token_ids)):
|
||||
if prompt_token_ids[i] == self.vision_start_id:
|
||||
self._add_text(prompt_token_ids[vision_start_index : i + 1], outputs)
|
||||
|
||||
vision_start_index = i + 1
|
||||
image_message = image_message_list[vision_message_index]
|
||||
|
||||
if image_message["type"] == "image":
|
||||
img = image_message.get("image")
|
||||
if img is None:
|
||||
continue
|
||||
outputs["pic_cnt"] += 1
|
||||
self._add_image(img, outputs)
|
||||
|
||||
elif image_message["type"] == "video":
|
||||
video_bytes = image_message.get("video")
|
||||
if video_bytes is None:
|
||||
continue
|
||||
frames, meta = self._load_and_process_video(video_bytes, image_message)
|
||||
|
||||
outputs["video_cnt"] += 1
|
||||
self._add_video(frames, meta, outputs)
|
||||
|
||||
vision_message_index += 1
|
||||
|
||||
self._add_text(prompt_token_ids[vision_start_index:], outputs)
|
||||
return self._pack_outputs(outputs)
|
||||
|
||||
def _add_text(self, tokens, outputs: Dict) -> None:
|
||||
"""
|
||||
Add text tokens to model inputs dictionary.
|
||||
|
||||
Args:
|
||||
tokens: Text string or already tokenized IDs
|
||||
outputs: Dictionary accumulating model inputs
|
||||
|
||||
Note:
|
||||
- Handles both raw text and pre-tokenized inputs
|
||||
- Updates position IDs for 3D embeddings
|
||||
"""
|
||||
if not tokens:
|
||||
return None
|
||||
|
||||
if isinstance(tokens, str):
|
||||
tokens_str = self.tokenizer.tokenize(tokens)
|
||||
tokens = self.tokenizer.convert_tokens_to_ids(tokens_str)
|
||||
|
||||
num_tokens = len(tokens)
|
||||
outputs["input_ids"].extend(tokens)
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)
|
||||
|
||||
position_ids = self._compute_text_positions(outputs["cur_position"], num_tokens)
|
||||
outputs["position_ids"].append(position_ids)
|
||||
outputs["cur_position"] = position_ids.max() + 1
|
||||
|
||||
def _compute_text_positions(self, start_pos: int, num_tokens: int) -> np.ndarray:
|
||||
"""
|
||||
Generate 3D positional embeddings for text tokens.
|
||||
|
||||
Args:
|
||||
start_pos: Starting position index
|
||||
num_tokens: Number of tokens to generate positions for
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: 3D position IDs shaped (3, num_tokens)
|
||||
"""
|
||||
text_array = np.arange(num_tokens).reshape(1, -1)
|
||||
text_index = np.broadcast_to(text_array, (3, num_tokens))
|
||||
position = text_index + start_pos
|
||||
return position
|
||||
|
||||
def _add_image(self, img, outputs: Dict) -> None:
|
||||
"""
|
||||
Add image data to model inputs dictionary.
|
||||
|
||||
Args:
|
||||
img: PIL Image to process
|
||||
outputs: Dictionary accumulating model inputs
|
||||
|
||||
Note:
|
||||
- Preprocesses image and calculates spatial dimensions
|
||||
- Adds image token IDs and type markers
|
||||
- Generates appropriate position embeddings
|
||||
"""
|
||||
ret = self.image_processor.preprocess(images=[img.convert("RGB")])
|
||||
num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
|
||||
grid_thw = ret["grid_thw"].tolist()
|
||||
|
||||
outputs["input_ids"].extend([self.image_token_id] * num_tokens)
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
|
||||
|
||||
outputs["images"].append(ret["pixel_values"])
|
||||
outputs["grid_thw"].append(grid_thw)
|
||||
outputs["image_type_ids"].append(0)
|
||||
|
||||
t, h, w = grid_thw
|
||||
position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)
|
||||
|
||||
outputs["position_ids"].append(position_ids)
|
||||
outputs["cur_position"] = position_ids.max() + 1
|
||||
|
||||
def _add_video(self, frames, meta: Dict, outputs: Dict) -> None:
|
||||
"""
|
||||
Add video data to model inputs dictionary.
|
||||
|
||||
Args:
|
||||
frames: Video frames as numpy array
|
||||
meta: Video metadata containing fps/duration
|
||||
outputs: Dictionary accumulating model inputs
|
||||
|
||||
Note:
|
||||
- Handles temporal dimension in position embeddings
|
||||
- Uses video-specific token IDs and type markers
|
||||
"""
|
||||
ret = self.image_processor.preprocess(images=frames)
|
||||
|
||||
num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
|
||||
grid_thw = ret["grid_thw"].tolist()
|
||||
|
||||
outputs["input_ids"].extend([self.video_token_id] * num_tokens)
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
|
||||
|
||||
outputs["images"].append(ret["pixel_values"])
|
||||
outputs["grid_thw"].append(grid_thw)
|
||||
outputs["image_type_ids"].extend([1] * grid_thw[0])
|
||||
|
||||
fps = meta["fps"]
|
||||
second_per_grid_t = self.temporal_conv_size / fps
|
||||
t, h, w = grid_thw
|
||||
position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
|
||||
|
||||
outputs["position_ids"].append(position_ids)
|
||||
outputs["cur_position"] = position_ids.max() + 1
|
||||
|
||||
def _compute_vision_positions(
|
||||
self, start_pos: int, t: int, h: int, w: int, second_per_grid_t: float
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate 3D position IDs for visual inputs.
|
||||
|
||||
Args:
|
||||
start_pos: Base position in sequence
|
||||
t: Temporal patches (1 for images)
|
||||
h: Height in patches
|
||||
w: Width in patches
|
||||
second_per_grid_t: Time per temporal patch
|
||||
|
||||
Returns:
|
||||
np.ndarray: Position IDs for [t,h,w] dimensions
|
||||
"""
|
||||
h //= self.spatial_conv_size
|
||||
w //= self.spatial_conv_size
|
||||
|
||||
tn = np.arange(t).reshape(-1, 1)
|
||||
tn = np.broadcast_to(tn, (t, h * w))
|
||||
tn = tn * int(second_per_grid_t) * self.tokens_per_second
|
||||
t_index = tn.flatten()
|
||||
|
||||
hn = np.arange(h).reshape(1, -1, 1)
|
||||
h_index = np.broadcast_to(hn, (t, h, w)).flatten()
|
||||
|
||||
wn = np.arange(w).reshape(1, 1, -1)
|
||||
w_index = np.broadcast_to(wn, (t, h, w)).flatten()
|
||||
|
||||
position = np.stack([t_index, h_index, w_index]) + start_pos
|
||||
return position
|
||||
|
||||
def _load_and_process_video(self, url: str, item: Dict) -> Tuple[np.ndarray, Dict]:
|
||||
"""
|
||||
Load and preprocess video into frames.
|
||||
|
||||
Args:
|
||||
url: Video file path or bytes
|
||||
item: Dictionary containing processing parameters
|
||||
|
||||
Returns:
|
||||
tuple: (frames, metadata) where:
|
||||
- frames: Processed video frames as numpy array
|
||||
- metadata: Updated video metadata dictionary
|
||||
"""
|
||||
frames, meta = read_frames(url)
|
||||
|
||||
# Apply frame sampling if fps or target_frames specified
|
||||
fps = item.get("fps", None)
|
||||
num_frames = item.get("target_frames", None)
|
||||
|
||||
if fps is not None or num_frames is not None:
|
||||
# Get frame sampling constraints
|
||||
min_frames = item.get("min_frames", self.min_frames)
|
||||
max_frames = item.get("max_frames", self.max_frames)
|
||||
|
||||
# Sample frames according to specifications
|
||||
frames = sample_frames(
|
||||
video=frames,
|
||||
frame_factor=self.temporal_conv_size, # Ensure divisible by temporal patch size
|
||||
min_frames=min_frames,
|
||||
max_frames=max_frames,
|
||||
metadata=meta,
|
||||
fps=fps,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
|
||||
# Update metadata with new frame count and fps
|
||||
meta["num_of_frame"] = frames.shape[0]
|
||||
if fps is not None:
|
||||
meta["fps"] = fps # Use specified fps
|
||||
meta["duration"] = frames.shape[0] / fps
|
||||
else:
|
||||
meta["fps"] = frames.shape[0] / meta["duration"] # Calculate fps from sampled frames
|
||||
|
||||
return frames, meta
|
||||
|
||||
def apply_chat_template(self, request):
|
||||
"""
|
||||
Apply chat template to convert messages into token sequence.
|
||||
|
||||
Args:
|
||||
request: Dictionary containing chat messages
|
||||
|
||||
Returns:
|
||||
List of token IDs
|
||||
|
||||
Raises:
|
||||
ValueError: If model doesn't support chat templates
|
||||
"""
|
||||
if self.tokenizer.chat_template is None:
|
||||
raise ValueError("This model does not support chat_template.")
|
||||
|
||||
raw_prompt = self.tokenizer.apply_chat_template(
|
||||
request["messages"],
|
||||
tokenize=False,
|
||||
add_generation_prompt=request.get("add_generation_prompt", True),
|
||||
)
|
||||
prompt_token_str = raw_prompt.replace(self.image_token, "").replace(self.video_token, "")
|
||||
request["text_after_process"] = raw_prompt
|
||||
|
||||
tokens = self.tokenizer.tokenize(prompt_token_str)
|
||||
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
data_processor_logger.info(
|
||||
f"req_id:{request.get('request_id', ''), } prompt: {raw_prompt} tokens: {tokens}, token_ids: {token_ids}"
|
||||
)
|
||||
return token_ids
|
131
fastdeploy/input/qwen_mm_processor/process_video.py
Normal file
131
fastdeploy/input/qwen_mm_processor/process_video.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from fastdeploy.input.mm_processor import read_video_decord
|
||||
|
||||
|
||||
def read_frames(video_path):
|
||||
"""
|
||||
Read and decode video frames from the given path
|
||||
|
||||
This function reads a video file and decodes it into individual RGB frames
|
||||
using decord video reader. It also extracts video metadata including fps,
|
||||
duration and frame count.
|
||||
|
||||
Args:
|
||||
video_path (str): Path to the video file or bytes object containing video data
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing:
|
||||
frames (numpy.ndarray): Array of shape (num_frames, height, width, 3)
|
||||
containing decoded RGB video frames
|
||||
meta (dict): Dictionary containing video metadata:
|
||||
- fps (float): Frames per second
|
||||
- duration (float): Video duration in seconds
|
||||
- num_of_frame (int): Total number of frames
|
||||
- width (int): Frame width in pixels
|
||||
- height (int): Frame height in pixels
|
||||
|
||||
Note:
|
||||
- The function uses decord library for efficient video reading
|
||||
- All frames are converted to RGB format regardless of input format
|
||||
"""
|
||||
reader, meta, _ = read_video_decord(video_path, save_to_disk=False)
|
||||
|
||||
frames = []
|
||||
for i in range(meta["num_of_frame"]):
|
||||
frame = reader[i].asnumpy()
|
||||
image = Image.fromarray(frame, "RGB")
|
||||
frames.append(image)
|
||||
frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
|
||||
return frames, meta
|
||||
|
||||
|
||||
def sample_frames(
|
||||
video: np.ndarray,
|
||||
frame_factor: int,
|
||||
min_frames: int,
|
||||
max_frames: int,
|
||||
metadata: Optional[dict] = None,
|
||||
fps: Optional[Union[int, float]] = None,
|
||||
num_frames: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Sample frames from video according to specified criteria.
|
||||
|
||||
Args:
|
||||
video: Input video frames as numpy array
|
||||
frame_factor: Ensure sampled frames are multiples of this factor
|
||||
min_frames: Minimum number of frames to sample
|
||||
max_frames: Maximum number of frames to sample
|
||||
metadata: Video metadata containing fps information
|
||||
fps: Target frames per second for sampling
|
||||
num_frames: Exact number of frames to sample
|
||||
|
||||
Returns:
|
||||
np.ndarray: Sampled video frames
|
||||
|
||||
Raises:
|
||||
ValueError: If both fps and num_frames are specified,
|
||||
or if required metadata is missing,
|
||||
or if requested frames exceed available frames
|
||||
"""
|
||||
if fps is not None and num_frames is not None:
|
||||
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
|
||||
|
||||
if fps is None and num_frames is None:
|
||||
return video
|
||||
|
||||
total_num_frames = video.shape[0]
|
||||
|
||||
# If num_frames is not given but fps is, calculate num_frames from fps
|
||||
if num_frames is not None:
|
||||
num_frames = round(num_frames / frame_factor) * frame_factor
|
||||
elif fps is not None:
|
||||
if metadata is None:
|
||||
raise ValueError(
|
||||
"Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
|
||||
"Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
|
||||
)
|
||||
max_frames = math.floor(min(max_frames, total_num_frames) / frame_factor) * frame_factor
|
||||
num_frames = total_num_frames / metadata["fps"] * fps
|
||||
num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames)
|
||||
num_frames = math.floor(num_frames / frame_factor) * frame_factor
|
||||
|
||||
if num_frames > total_num_frames:
|
||||
raise ValueError(
|
||||
f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
|
||||
"Decrease `num_frames` or `fps` for sampling."
|
||||
)
|
||||
|
||||
# Calculate frame indices based on sampling strategy
|
||||
if num_frames is not None:
|
||||
# Evenly spaced sampling for target frame count
|
||||
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(np.int32)
|
||||
else:
|
||||
# Keep all frames if no sampling requested
|
||||
indices = np.arange(0, total_num_frames).astype(np.int32)
|
||||
|
||||
# Apply frame selection
|
||||
video = video[indices]
|
||||
|
||||
return video
|
290
fastdeploy/input/qwen_vl_processor.py
Normal file
290
fastdeploy/input/qwen_vl_processor.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.input.qwen_mm_processor import DataProcessor
|
||||
from fastdeploy.input.text_processor import DataProcessor as TextProcessor
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
|
||||
|
||||
class QwenVLProcessor(TextProcessor):
|
||||
"""
|
||||
Qwen Vision-Language processor for handling multimodal inputs.
|
||||
|
||||
This processor extends TextProcessor to support:
|
||||
- Image and video processing
|
||||
- Multimodal feature extraction
|
||||
- Tokenization and position encoding
|
||||
- Request processing and model input generation
|
||||
|
||||
Attributes:
|
||||
processor (DataProcessor): Underlying data processor instance
|
||||
tokenizer: Text tokenizer instance
|
||||
limit_mm_per_prompt (dict): Limits for multimodal inputs per prompt
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
model_name_or_path,
|
||||
limit_mm_per_prompt=None,
|
||||
mm_processor_kwargs=None,
|
||||
reasoning_parser_obj=None,
|
||||
tool_parser_obj=None,
|
||||
):
|
||||
"""
|
||||
Initialize QwenVLProcessor instance.
|
||||
|
||||
Args:
|
||||
config: Model configuration object
|
||||
model_name_or_path (str): Pretrained model name or path
|
||||
limit_mm_per_prompt (dict, optional): Limits for multimodal inputs
|
||||
mm_processor_kwargs (dict, optional): Multimodal processor arguments
|
||||
reasoning_parser_obj: Reasoning parser instance
|
||||
tool_parser_obj: Tool parser instance
|
||||
"""
|
||||
super().__init__(model_name_or_path, reasoning_parser_obj, tool_parser_obj)
|
||||
|
||||
data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
|
||||
processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs)
|
||||
self.processor = DataProcessor(
|
||||
model_path=model_name_or_path,
|
||||
tokens_per_second=config.vision_config.tokens_per_second,
|
||||
tokenizer=self.tokenizer,
|
||||
**processor_kwargs,
|
||||
)
|
||||
|
||||
self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
|
||||
|
||||
def process_request(self, request, max_model_len=None, **kwargs):
|
||||
"""
|
||||
Process incoming request and generate model inputs.
|
||||
|
||||
Args:
|
||||
request: Input request object
|
||||
max_model_len (int, optional): Maximum context length
|
||||
**kwargs: Additional processing parameters
|
||||
|
||||
Returns:
|
||||
Request: Processed request with model inputs
|
||||
"""
|
||||
task = request.to_dict()
|
||||
task["enable_thinking"] = kwargs.get("enable_thinking", False)
|
||||
self.process_request_dict(task, max_model_len)
|
||||
request = Request.from_dict(task)
|
||||
request = self._apply_default_parameters(request)
|
||||
return request
|
||||
|
||||
def _parse_processor_kwargs(self, kwargs):
|
||||
"""
|
||||
Parse and validate multimodal processor arguments.
|
||||
|
||||
Args:
|
||||
kwargs (dict): Processor configuration arguments
|
||||
|
||||
Returns:
|
||||
dict: Validated processor arguments
|
||||
|
||||
Raises:
|
||||
ValueError: If arguments format is invalid
|
||||
"""
|
||||
if not kwargs:
|
||||
return {}
|
||||
|
||||
try:
|
||||
if not isinstance(kwargs, dict):
|
||||
raise ValueError("mm-processor-kwargs must be a dictionary")
|
||||
|
||||
# Validate kwargs types against expected schema
|
||||
data_processor_logger.info(f"Processing kwargs: {kwargs}")
|
||||
expected_types = {
|
||||
"video_max_frames": int, # Maximum video frames parameter
|
||||
"video_min_frames": int, # Minimum video frames parameter
|
||||
}
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if key in expected_types and not isinstance(value, expected_types[key]):
|
||||
raise ValueError(
|
||||
f"Invalid type for {key}: expected {expected_types[key].__name__}, got {type(value).__name__}"
|
||||
)
|
||||
|
||||
return kwargs
|
||||
|
||||
except Exception as e:
|
||||
data_processor_logger.warning(f"Invalid mm-processor-kwargs format: {e}")
|
||||
return {}
|
||||
|
||||
def _parse_limits(self, limits):
|
||||
"""
|
||||
Parse and validate multimodal input limits.
|
||||
|
||||
Args:
|
||||
limits (dict): Input limits configuration
|
||||
|
||||
Returns:
|
||||
dict: Validated limits with defaults
|
||||
|
||||
Raises:
|
||||
ValueError: If limits format is invalid
|
||||
"""
|
||||
DEFAULT_LIMITS = {"image": 1, "video": 1, "audio": 1}
|
||||
|
||||
if not limits:
|
||||
return DEFAULT_LIMITS
|
||||
|
||||
try:
|
||||
if not isinstance(limits, dict):
|
||||
raise ValueError("limit-mm-per-prompt must be a dictionary")
|
||||
data_processor_logger.info(f"_parse_limits:{limits}")
|
||||
return {**DEFAULT_LIMITS, **limits}
|
||||
except Exception as e:
|
||||
data_processor_logger.warning(f"Invalid limit-mm-per-prompt format: {e}, using default limits")
|
||||
return DEFAULT_LIMITS
|
||||
|
||||
def _check_mm_limits(self, item):
|
||||
"""
|
||||
Validate multimodal inputs against configured limits.
|
||||
|
||||
Args:
|
||||
item: Input request item to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If input exceeds configured limits
|
||||
"""
|
||||
if isinstance(item, dict):
|
||||
# 请求包含prompt和multi_modal_data
|
||||
mm_data = item
|
||||
else:
|
||||
# 请求包含messages
|
||||
mm_data = {"image": [], "video": []}
|
||||
|
||||
for message in item:
|
||||
if isinstance(message.get("content"), list):
|
||||
for part in message["content"]:
|
||||
if part.get("type") in ["image_url", "image"]:
|
||||
mm_data["image"].append(part)
|
||||
elif part.get("type") in ["video_url", "video"]:
|
||||
mm_data["video"].append(part)
|
||||
|
||||
for modality, data in mm_data.items():
|
||||
if modality in self.limit_mm_per_prompt:
|
||||
limit = self.limit_mm_per_prompt[modality]
|
||||
if len(data) > limit:
|
||||
raise ValueError(f"Too many {modality} items in prompt, " f"got {len(data)} but limit is {limit}")
|
||||
|
||||
def process_request_dict(self, request, max_model_len=None):
|
||||
"""
|
||||
Process request dictionary into model inputs.
|
||||
|
||||
Args:
|
||||
request (dict): Input request dictionary
|
||||
max_model_len (int, optional): Maximum context length
|
||||
|
||||
Returns:
|
||||
dict: Processed request with model inputs
|
||||
|
||||
Raises:
|
||||
ValueError: If request format is invalid
|
||||
"""
|
||||
|
||||
request = self._apply_default_parameters(request)
|
||||
if not request.get("eos_token_ids"):
|
||||
request["eos_token_ids"] = self.eos_token_ids
|
||||
|
||||
stop_sequences = request.get("stop", [])
|
||||
if stop_sequences:
|
||||
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
|
||||
request["stop_token_ids"] = stop_seqs
|
||||
request["stop_seqs_len"] = stop_seqs_len
|
||||
|
||||
if request.get("prompt"):
|
||||
multimodal_data = request.get("multimodal_data")
|
||||
if multimodal_data is None:
|
||||
multimodal_data = {}
|
||||
self._check_mm_limits(multimodal_data)
|
||||
images = multimodal_data.get("image", None)
|
||||
videos = multimodal_data.get("video", None)
|
||||
outputs = self.processor.text2ids(request["prompt"], images, videos)
|
||||
|
||||
elif request.get("messages"):
|
||||
messages = request["messages"]
|
||||
self._check_mm_limits(messages)
|
||||
outputs = self.processor.request2ids(request)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")
|
||||
|
||||
metadata = request.get("metadata")
|
||||
# Handle continuation of previous generation by appending existing tokens
|
||||
if metadata and metadata.get("generated_token_ids"):
|
||||
self.append_generated_tokens(outputs, metadata["generated_token_ids"])
|
||||
outputs = self.pack_outputs(outputs)
|
||||
|
||||
request["prompt_token_ids"] = outputs["input_ids"].tolist()
|
||||
request["prompt_token_ids_len"] = len(request["prompt_token_ids"])
|
||||
request["multimodal_inputs"] = outputs
|
||||
|
||||
# Handle prompt truncation if exceeds model context length
|
||||
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
|
||||
request["prompt_token_ids"] = request["prompt_token_ids"][
|
||||
: max_model_len - 1
|
||||
] # Leave space for at least 1 new token
|
||||
|
||||
# Set default max_tokens if not specified
|
||||
if request.get("max_tokens") is None:
|
||||
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"])) # Ensure at least 1 token
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
|
||||
return request
|
||||
|
||||
def append_generated_tokens(self, outputs, generated_token_ids):
|
||||
"""
|
||||
Append generated tokens to existing outputs.
|
||||
|
||||
Args:
|
||||
outputs: Current model outputs
|
||||
generated_token_ids: Generated tokens to append
|
||||
"""
|
||||
out = {"input_ids": [], "token_type_ids": [], "position_ids": [], "cur_position": outputs["cur_position"]}
|
||||
self.processor._add_text(generated_token_ids, out)
|
||||
|
||||
outputs["input_ids"] = np.concatenate(
|
||||
[outputs["input_ids"], np.array(out["input_ids"], dtype=np.int64)], axis=0
|
||||
)
|
||||
outputs["token_type_ids"] = np.concatenate(
|
||||
[outputs["token_type_ids"], np.array(out["token_type_ids"], dtype=np.int64)], axis=0
|
||||
)
|
||||
outputs["position_ids"] = np.concatenate(
|
||||
[outputs["position_ids"], out["position_ids"][0]], axis=1, dtype=np.int64
|
||||
)
|
||||
outputs["cur_position"] = out["cur_position"]
|
||||
|
||||
def pack_outputs(self, outputs):
|
||||
"""
|
||||
Prepare final output dictionary for model.
|
||||
|
||||
Args:
|
||||
outputs: Intermediate processing outputs
|
||||
|
||||
Returns:
|
||||
dict: Packed output dictionary with all required fields
|
||||
"""
|
||||
outputs["image_patch_id"] = self.processor.image_token_id
|
||||
outputs["video_patch_id"] = self.processor.video_token_id
|
||||
outputs["position_ids"] = outputs["position_ids"].transpose(1, 0)
|
||||
return outputs
|
248
tests/input/test_qwen_vl_processor.py
Normal file
248
tests/input/test_qwen_vl_processor.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.input.qwen_vl_processor import QwenVLProcessor
|
||||
|
||||
|
||||
def mock_pil_image(height, width):
|
||||
"""
|
||||
Generate mock random RGB image
|
||||
|
||||
Args:
|
||||
height: Image height in pixels
|
||||
width: Image width in pixels
|
||||
|
||||
Returns:
|
||||
PIL.Image object with random RGB data
|
||||
"""
|
||||
rgb_image = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||
return Image.fromarray(rgb_image)
|
||||
|
||||
|
||||
def mock_read_frames(height: int, width: int, nums_frame: int, fps: int):
|
||||
"""
|
||||
Generate mock video frames with metadata for testing purposes
|
||||
|
||||
Creates synthetic video data by generating random RGB frames and constructing
|
||||
corresponding metadata to simulate real video processing.
|
||||
|
||||
Args:
|
||||
height (int): Height of video frames in pixels
|
||||
width (int): Width of video frames in pixels
|
||||
nums_frame (int): Number of frames to generate
|
||||
fps (int): Frames per second for the mock video
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing:
|
||||
frames (numpy.ndarray): Array of shape (nums_frame, height, width, 3)
|
||||
containing randomly generated RGB frames
|
||||
meta (dict): Dictionary with video metadata:
|
||||
- fps (int): Frames per second (same as input)
|
||||
- duration (float): Calculated duration in seconds (nums_frame/fps)
|
||||
- num_of_frame (int): Number of frames (same as nums_frame input)
|
||||
"""
|
||||
frames = []
|
||||
for _ in range(nums_frame):
|
||||
frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||
frames.append(frame)
|
||||
frames = np.stack(frames, axis=0)
|
||||
|
||||
meta = {
|
||||
"fps": fps,
|
||||
"duration": nums_frame / fps,
|
||||
"num_of_frame": nums_frame,
|
||||
}
|
||||
return frames, meta
|
||||
|
||||
|
||||
class TestQwenVLProcessor(unittest.TestCase):
|
||||
"""
|
||||
Unit tests for Qwen Vision-Language Processor functionality
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Initialize test case with:
|
||||
- Mock configuration
|
||||
- Patched message parsing and video processing methods
|
||||
- QwenVLProcessor instance with test parameters
|
||||
"""
|
||||
config = MagicMock()
|
||||
config.vision_config.tokens_per_second = 2
|
||||
|
||||
self.patcher_parse_image = patch(
|
||||
"fastdeploy.entrypoints.chat_utils.MultiModalPartParser.parse_image", return_value=mock_pil_image(480, 640)
|
||||
)
|
||||
self.patcher_parse_image.start()
|
||||
|
||||
self.patcher_parse_video = patch(
|
||||
"fastdeploy.entrypoints.chat_utils.MultiModalPartParser.parse_video", return_value=b"123"
|
||||
)
|
||||
self.patcher_parse_video.start()
|
||||
|
||||
self.patcher_read_frames = patch(
|
||||
"fastdeploy.input.qwen_mm_processor.process.read_frames", return_value=mock_read_frames(480, 640, 5, 2)
|
||||
)
|
||||
self.patcher_read_frames.start()
|
||||
|
||||
mm_processor_kwargs = {
|
||||
"video_max_frames": 10,
|
||||
"video_min_frames": 1,
|
||||
}
|
||||
limit_mm_per_prompt = {"image": 1, "video": 1, "audio": 1}
|
||||
|
||||
model_name_or_path = "/ModelData/Qwen2.5-VL-7B-Instruct"
|
||||
self.processor = QwenVLProcessor(
|
||||
config=config,
|
||||
model_name_or_path=model_name_or_path,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
reasoning_parser_obj=None,
|
||||
tool_parser_obj=None,
|
||||
)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""Clean up test case by stopping all mock patches"""
|
||||
self.patcher_read_frames.stop()
|
||||
self.patcher_parse_image.stop()
|
||||
self.patcher_parse_video.stop()
|
||||
|
||||
def test_process_request(self):
|
||||
"""
|
||||
Test processing of Request object with multimodal input
|
||||
|
||||
Validates:
|
||||
1. Token ID lengths match position_ids and token_type_ids shapes
|
||||
2. Image processing produces expected output dimensions
|
||||
3. Video processing produces expected output dimensions
|
||||
4. Correct counts for images (1) and videos (1)
|
||||
"""
|
||||
prompt = {
|
||||
"request_id": "12345",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": "file://demo.jpeg"}},
|
||||
{"type": "video_url", "video_url": {"url": "file://3_frame_video.mp4"}},
|
||||
{"type": "text", "text": "Describe image and video."},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
request = Request.from_dict(prompt)
|
||||
result = self.processor.process_request(request, 1024 * 100)
|
||||
|
||||
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0])
|
||||
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0])
|
||||
self.assertEqual(
|
||||
result.multimodal_inputs["images"].shape[0],
|
||||
sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])),
|
||||
)
|
||||
self.assertEqual(
|
||||
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
|
||||
)
|
||||
self.assertEqual(result.multimodal_inputs["pic_cnt"], 1)
|
||||
self.assertEqual(result.multimodal_inputs["video_cnt"], 1)
|
||||
|
||||
def test_process_request_dict(self):
|
||||
"""
|
||||
Test processing of dictionary-format request with multimodal input
|
||||
|
||||
Validates:
|
||||
1. Token ID lengths match position_ids and token_type_ids shapes
|
||||
2. Image processing produces expected output dimensions
|
||||
3. Video processing produces expected output dimensions
|
||||
4. Correct counts for images (1) and videos (1)
|
||||
"""
|
||||
num_generated_token_ids = 10
|
||||
request = {
|
||||
"request_id": "12345",
|
||||
"metadata": {
|
||||
"generated_token_ids": [1] * num_generated_token_ids,
|
||||
},
|
||||
"stop": ["stop", "eof"],
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": "file://demo.jpeg"}},
|
||||
{"type": "video_url", "video_url": {"url": "file://3_frame_video.mp4"}},
|
||||
{"type": "text", "text": "Describe image and video."},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
result = self.processor.process_request_dict(request, 1024 * 100)
|
||||
|
||||
self.assertEqual(result["prompt_token_ids_len"], result["multimodal_inputs"]["position_ids"].shape[0])
|
||||
self.assertEqual(result["prompt_token_ids_len"], result["multimodal_inputs"]["token_type_ids"].shape[0])
|
||||
self.assertEqual(
|
||||
result["multimodal_inputs"]["images"].shape[0],
|
||||
sum(map(lambda x: x.prod(), result["multimodal_inputs"]["grid_thw"])),
|
||||
)
|
||||
self.assertEqual(
|
||||
result["multimodal_inputs"]["image_type_ids"].shape[0], result["multimodal_inputs"]["grid_thw"][:, 0].sum()
|
||||
)
|
||||
self.assertEqual(result["multimodal_inputs"]["pic_cnt"], 1)
|
||||
self.assertEqual(result["multimodal_inputs"]["video_cnt"], 1)
|
||||
|
||||
def test_prompt(self):
|
||||
"""
|
||||
Test processing of prompt with image and video placeholders
|
||||
|
||||
Validates:
|
||||
1. Token ID lengths match position_ids and token_type_ids shapes
|
||||
2. Image processing produces expected output dimensions
|
||||
3. Video processing produces expected output dimensions
|
||||
4. Correct counts for images (1) and videos (1)
|
||||
"""
|
||||
prompt = {
|
||||
"request_id": "12345",
|
||||
"prompt": "<|image@placeholder|><|video@placeholder|>Describe image and video.",
|
||||
"multimodal_data": {
|
||||
"image": [mock_pil_image(10, 2100)],
|
||||
"video": [{"video": b"123", "fps": 5}],
|
||||
},
|
||||
}
|
||||
|
||||
request = Request.from_dict(prompt)
|
||||
result = self.processor.process_request(request, 1024 * 100)
|
||||
|
||||
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0])
|
||||
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0])
|
||||
self.assertEqual(
|
||||
result.multimodal_inputs["images"].shape[0],
|
||||
sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])),
|
||||
)
|
||||
self.assertEqual(
|
||||
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
|
||||
)
|
||||
self.assertEqual(result.multimodal_inputs["pic_cnt"], 1)
|
||||
self.assertEqual(result.multimodal_inputs["video_cnt"], 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user