[Feature] Add Qwen25-VL Processor (#3501)

* add qwen-2.5-vl processor

* add qwen25-vl processor

* add qwen25-vl processor

* add qwen25-vl processor

* add qwen25-vl processor position_ids

* add qwen25-vl processor

* add qwen25-vl processor

* position_ids

* add test for qwen25-vl

* organize comments

* formatted

* qwen_vl_processor

* add qwen_vl_processor unittest

* update model path

* update model path

* update qwen_vl_processor unittest

* add unittest and bug fix

* add unittest and bug fix

* Update fastdeploy/input/qwen_mm_processor/image_processor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/input/qwen_vl_processor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
lddfym
2025-08-22 16:49:42 +08:00
committed by GitHub
parent 5b66462f0e
commit 27666ee586
8 changed files with 1657 additions and 4 deletions

View File

@@ -15,9 +15,13 @@
"""
from .process import IDS_TYPE_FLAG, DataProcessor, fancy_print
from .process_video import read_video_decord
from .utils.video_utils import VideoReaderWrapper
__all__ = [
"DataProcessor",
"fancy_print",
"IDS_TYPE_FLAG",
"VideoReaderWrapper",
"read_video_decord",
]

View File

@@ -75,7 +75,10 @@ class InputPreprocessor:
reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
if self.tool_parser:
tool_parser_obj = ToolParserManager.get_tool_parser(self.tool_parser)
architectures = ModelConfig({"model": self.model_name_or_path}).architectures[0]
config = ModelConfig({"model": self.model_name_or_path})
architectures = config.architectures[0]
if not self.enable_mm:
if not ErnieArchitectures.contains_ernie_arch(architectures):
from fastdeploy.input.text_processor import DataProcessor
@@ -94,9 +97,7 @@ class InputPreprocessor:
tool_parser_obj=tool_parser_obj,
)
else:
if not ErnieArchitectures.contains_ernie_arch(architectures):
raise ValueError(f"Model {self.model_name_or_path} is not a valid Ernie4_5_VL model.")
else:
if ErnieArchitectures.contains_ernie_arch(architectures):
from fastdeploy.input.ernie_vl_processor import ErnieMoEVLProcessor
self.processor = ErnieMoEVLProcessor(
@@ -106,4 +107,14 @@ class InputPreprocessor:
reasoning_parser_obj=reasoning_parser_obj,
tool_parser_obj=tool_parser_obj,
)
else:
from fastdeploy.input.qwen_vl_processor import QwenVLProcessor
self.processor = QwenVLProcessor(
config=config,
model_name_or_path=self.model_name_or_path,
limit_mm_per_prompt=self.limit_mm_per_prompt,
mm_processor_kwargs=self.mm_processor_kwargs,
reasoning_parser_obj=reasoning_parser_obj,
)
return self.processor

View File

@@ -0,0 +1,22 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .process import IDS_TYPE_FLAG, DataProcessor
__all__ = [
"DataProcessor",
"IDS_TYPE_FLAG",
]

View File

@@ -0,0 +1,442 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import math
from typing import List, Optional, Union
import numpy as np
import paddle
import PIL
from paddleformers.transformers.feature_extraction_utils import BatchFeature
from paddleformers.transformers.image_processing_utils import BaseImageProcessor
from paddleformers.transformers.image_transforms import (
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from paddleformers.transformers.image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
)
from paddleformers.transformers.tokenizer_utils_base import TensorType
from PIL import Image
from fastdeploy.utils import data_processor_logger
OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
VideoInput = Union[
List["PIL.Image.Image"],
"np.ndarray",
"paddle.Tensor",
List["np.ndarray"],
List["paddle.Tensor"],
List[List["PIL.Image.Image"]],
List[List["np.ndarray"]],
List[List["paddle.Tensor"]],
]
def round_by_factor(number: int, factor: int) -> int:
"""
Round number to nearest multiple of factor.
Args:
number: Input number to round
factor: Rounding factor
Returns:
int: Rounded number
"""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""
Round number up to nearest multiple of factor.
Args:
number: Input number to round
factor: Rounding factor
Returns:
int: Rounded number
"""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""
Round number down to nearest multiple of factor.
Args:
number: Input number to round
factor: Rounding factor
Returns:
int: Rounded number
"""
return math.floor(number / factor) * factor
def smart_resize(height: int, width: int, factor: int, min_pixels: int, max_pixels: int, max_ratio: int = 200):
"""
Smart image resizing that maintains aspect ratio and respects constraints.
Args:
height: Original image height
width: Original image width
factor: Patch size factor
min_pixels: Minimum allowed pixels
max_pixels: Maximum allowed pixels
max_ratio: Maximum allowed aspect ratio
Returns:
tuple: (new_height, new_width)
Raises:
ValueError: If calculated dimensions are invalid
"""
if max(height, width) / min(height, width) > max_ratio:
if height > width:
new_width = max(factor, round_by_factor(width, factor))
new_height = floor_by_factor(new_width * max_ratio, factor)
else:
new_height = max(factor, round_by_factor(height, factor))
new_width = floor_by_factor(new_height * max_ratio, factor)
data_processor_logger.info(
f"absolute aspect ratio must be smaller than {max_ratio}, got {max(height, width) / min(height, width)},\
resize to {max(new_height, new_width) / min(new_height, new_width)}"
)
height = new_height
width = new_width
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels:
raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}")
return h_bar, w_bar
def is_scaled_image(image: np.ndarray) -> bool:
"""
Check if image pixel values are already normalized to [0, 1] range.
Args:
image: Input image array
Returns:
bool: True if image is already scaled
"""
if image.dtype == np.uint8:
return False
# It's possible the image has pixel values in [0, 255] but is of floating type
return np.min(image) >= 0 and np.max(image) <= 1
class ImageProcessor(BaseImageProcessor):
"""
Adaptive image processor for dynamic image resizing and preprocessing.
This processor handles image resizing, rescaling, normalization and format conversion.
It dynamically adjusts image dimensions based on original size and specified constraints.
"""
def __init__(
self,
patch_size: int = 14,
merge_size: int = 2,
temporal_patch_size: int = 2,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS,
image_mean: Union[float, List[float]] = OPENAI_CLIP_MEAN,
image_std: Union[float, List[float]] = OPENAI_CLIP_STD,
rescale_factor: float = 1 / 255,
do_rescale: bool = True,
do_normalize: bool = True,
resample: PILImageResampling = PILImageResampling.BICUBIC,
**kwargs,
) -> None:
"""
Initialize image processor with configuration parameters.
Args:
patch_size (int): Spatial patch size for vision encoder
merge_size (int): Merge size between vision and LLM encoders
temporal_patch_size (int): Temporal patch size for video processing
min_pixels (int): Minimum allowed pixels in resized image
max_pixels (int): Maximum allowed pixels in resized image
image_mean (float/list): Mean values for normalization per channel
image_std (float/list): Std values for normalization per channel
rescale_factor (float): Scaling factor for pixel values (default 1/255)
do_rescale (bool): Whether to rescale images
do_normalize (bool): Whether to normalize images
resample: Resampling method for image resizing
**kwargs: Additional base class arguments
"""
super().__init__(**kwargs)
self.patch_size = patch_size
self.merge_size = merge_size
self.temporal_patch_size = temporal_patch_size
self.min_pixels = min_pixels
self.max_pixels = max_pixels
self.image_mean = image_mean
self.image_std = image_std
self.rescale_factor = rescale_factor
self.do_rescale = do_rescale
self.do_normalize = do_normalize
self.resample = resample
def _preprocess(
self,
images: Union[ImageInput, VideoInput],
min_pixels: int,
max_pixels: int,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
rescale_factor: float,
do_rescale: bool,
do_normalize: bool,
resample: PILImageResampling,
data_format: Optional[ChannelDimension],
input_data_format: Optional[Union[str, ChannelDimension]],
):
"""
Internal method for image preprocessing pipeline.
Args:
images: Input image or batch of images
min_pixels: Minimum allowed pixels in output
max_pixels: Maximum allowed pixels in output
image_mean: Normalization mean values
image_std: Normalization std values
rescale_factor: Pixel value scaling factor
do_rescale: Whether to rescale pixel values
do_normalize: Whether to normalize pixel values
resample: Resampling method
data_format: Output channel format
input_data_format: Input channel format
Returns:
tuple: (flatten_patches, grid_dimensions)
- flatten_patches: Flattened image patches
- grid_dimensions: Grid dimensions [t, h, w]
"""
images = make_list_of_images(images)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if is_scaled_image(images[0]) and do_rescale:
data_processor_logger.warning(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
# Get original dimensions and calculate optimal resize dimensions
height, width = get_image_size(images[0], channel_dim=input_data_format)
resized_height, resized_width = smart_resize(
height,
width,
factor=self.patch_size * self.merge_size, # Combine patch and merge factors
min_pixels=min_pixels,
max_pixels=max_pixels,
)
processed_images = []
for image in images:
if height != resized_height or width != resized_width:
# Convert to uint8 before resizing to avoid double scaling
image = image.astype("uint8")
# Convert to PIL Image and resize
image = Image.fromarray(image)
image = resize(
image,
size=(resized_height, resized_width),
resample=resample,
data_format=input_data_format,
)
if do_rescale and do_normalize:
# Adjust mean and std for combined rescale+normalize
image_mean = np.array(image_mean, dtype=np.float32) * (1.0 / rescale_factor)
image_std = np.array(image_std, dtype=np.float32) * (1.0 / rescale_factor)
do_rescale = False # Skip separate rescale step
if do_rescale:
image = image.astype(np.float32)
image = rescale(image, scale=rescale_factor, data_format=input_data_format)
if do_normalize:
image = image.astype(np.float32)
image = normalize(
image=image,
mean=image_mean,
std=image_std,
data_format=input_data_format,
)
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) # [C, H, W]
processed_images.append(image)
# Convert processed images to numpy array
patches = np.array(processed_images)
# Pad temporal dimension if needed
if patches.shape[0] % self.temporal_patch_size != 0:
repeats = np.repeat(
patches[-1][np.newaxis],
self.temporal_patch_size - (patches.shape[0] % self.temporal_patch_size),
axis=0,
)
patches = np.concatenate([patches, repeats], axis=0)
# Convert to channels-first format if needed
if data_format == ChannelDimension.LAST:
patches = patches.transpose([0, 3, 1, 2]) # [N, H, W, C] -> [N, C, H, W]
grid_t, channel = patches.shape[:2]
grid_t = grid_t // self.temporal_patch_size
grid_h, grid_w = (
resized_height // self.patch_size,
resized_width // self.patch_size,
)
# Reshape into hierarchical patch structure
patches = patches.reshape(
[
grid_t,
self.temporal_patch_size,
channel,
grid_h // self.merge_size,
self.merge_size,
self.patch_size,
grid_w // self.merge_size,
self.merge_size,
self.patch_size,
]
)
# Reorder dimensions for better memory access pattern
# [grid_t, grid_h/merge_size, grid_w/merge_size, merge_size, merge_size, C, temporal_patch_size, psz, psz]
patches = patches.transpose([0, 3, 6, 4, 7, 2, 1, 5, 8])
flatten_patches = patches.reshape(
[
grid_t * grid_h * grid_w,
channel * self.temporal_patch_size * self.patch_size * self.patch_size,
]
)
return flatten_patches, np.array([grid_t, grid_h, grid_w])
def preprocess(
self,
images: Union[ImageInput, VideoInput],
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
rescale_factor: Optional[float] = None,
do_rescale: Optional[bool] = None,
do_normalize: Optional[bool] = None,
resample: Optional[PILImageResampling] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.LAST,
):
"""
Main preprocessing method for images/videos.
Args:
images: Input image/video data
min_pixels: Override for minimum pixels
max_pixels: Override for maximum pixels
image_mean: Override for normalization mean
image_std: Override for normalization std
rescale_factor: Override for rescaling factor
do_rescale: Override for rescaling flag
do_normalize: Override for normalization flag
resample: Override for resampling method
return_tensors: Desired output tensor format
data_format: Output channel dimension format
input_data_format: Input channel dimension format
Returns:
BatchFeature: Processed features containing:
- pixel_values: Preprocessed pixel data
- grid_thw: Grid dimensions [temporal, height, width]
Raises:
ValueError: For invalid image types or dimensions
"""
min_pixels = min_pixels if min_pixels is not None else self.min_pixels
max_pixels = max_pixels if max_pixels is not None else self.max_pixels
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
resample = resample if resample is not None else self.resample
if images is not None and not valid_images(images):
raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "paddle.Tensor.")
pixel_values, grid_thw = self._preprocess(
images,
min_pixels=min_pixels,
max_pixels=max_pixels,
image_mean=image_mean,
image_std=image_std,
rescale_factor=rescale_factor,
do_rescale=do_rescale,
do_normalize=do_normalize,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
)
data = {"pixel_values": pixel_values, "grid_thw": grid_thw}
return BatchFeature(data=data, tensor_type=return_tensors)

View File

@@ -0,0 +1,505 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Any, Dict, List, Tuple, Union
import numpy as np
from paddleformers.transformers import AutoTokenizer
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.mm_processor import IDS_TYPE_FLAG
from fastdeploy.utils import data_processor_logger
from .image_processor import ImageProcessor
from .process_video import read_frames, sample_frames
class DataProcessor:
"""
Processes multimodal inputs (text, images, videos) into model-ready formats.
Handles:
- Tokenization of text with special tokens for visual content
- Image and video preprocessing
- Generation of 3D positional embeddings
- Conversion of chat messages to model inputs
Attributes:
tokenizer: Text tokenizer instance
image_processor: Image/video preprocessor
image_token: Special token for image placeholders
video_token: Special token for video placeholders
vision_start: Token marking start of visual content
"""
def __init__(
self,
model_path: str,
video_min_frames: int = 4,
video_max_frames: int = 768,
tokens_per_second: int = 2,
tokenizer=None,
**kwargs,
) -> None:
"""
Initialize the data processor.
Args:
model_path: Path to pretrained model
video_min_frames: Minimum frames to sample from videos
video_max_frames: Maximum frames to sample from videos
tokens_per_second: Temporal resolution for positional embeddings
**kwargs: Additional configuration
"""
self.min_frames = video_min_frames
self.max_frames = video_max_frames
# Initialize tokenizer with left padding and fast tokenizer
if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left", use_fast=True)
self.tokenizer.ignored_index = -100 # Set ignored index for loss calculation
else:
self.tokenizer = tokenizer
self.image_processor = ImageProcessor.from_pretrained(model_path) # Initialize image processor
# Convolution sizes for patch aggregation
self.spatial_conv_size = self.image_processor.merge_size
self.temporal_conv_size = self.image_processor.temporal_patch_size
# Special tokens and IDs
self.image_token = "<|image_pad|>"
self.video_token = "<|video_pad|>"
self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token)
self.vision_start = "<|vision_start|>"
self.vision_start_id = self.tokenizer.convert_tokens_to_ids(self.vision_start)
self.tokens_per_second = tokens_per_second
self.role_prefixes = {
"system": "",
"user": "User: ",
"bot": "Assistant: ",
"assistant": "Assistant: ",
}
def _pack_outputs(self, outputs):
"""
Pack and convert all output data into numpy arrays with appropriate types.
Args:
outputs (dict): Dictionary containing model outputs with keys:
- images: List of visual features
- grid_thw: List of spatial dimensions
- image_type_ids: List of content type indicators
- input_ids: List of token IDs
- token_type_ids: List of type identifiers
- position_ids: List of position embeddings
Returns:
dict: Processed outputs with all values converted to numpy arrays
"""
# Process visual outputs - stack if exists or set to None if empty
if not outputs["images"]:
outputs["images"] = None # No images case
outputs["grid_thw"] = None # No spatial dimensions
outputs["image_type_ids"] = None # No type IDs
else:
outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically
outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions
outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array
# Convert all outputs to numpy arrays with appropriate types
outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64
outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64
outputs["position_ids"] = np.concatenate(
outputs["position_ids"], axis=1, dtype=np.int64
) # Concatenate position IDs
return outputs
def text2ids(self, text, images=None, videos=None):
"""
Convert text with image/video placeholders into model inputs.
Args:
text: Input text with <|image@placeholder|> and <|video@placeholder|> markers
images: List of PIL Images corresponding to image placeholders
videos: List of video data corresponding to video placeholders
Returns:
Dict containing:
- input_ids: Token IDs
- token_type_ids: Type identifiers (text/image/video)
- position_ids: 3D positional embeddings
- images: Preprocessed visual features
- grid_thw: Spatial/temporal dimensions
- image_type_ids: Visual content type (0=image, 1=video)
"""
outputs = {
"input_ids": [],
"token_type_ids": [],
"position_ids": [],
"images": [],
"grid_thw": [],
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
}
# Define placeholders and their lengths
IMAGE_PLACEHOLDER = "<|image@placeholder|>"
VIDEO_PLACEHOLDER = "<|video@placeholder|>"
IMAGE_PLACEHOLDER_LEN = len(IMAGE_PLACEHOLDER)
VIDEO_PLACEHOLDER_LEN = len(VIDEO_PLACEHOLDER)
# Initialize tracking variables for text parsing
st, image_idx, video_idx = 0, 0, 0 # Start position, image counter, video counter
while st < len(text):
# Find next image or video placeholder in text
image_pos = text.find(IMAGE_PLACEHOLDER, st)
image_pos = len(text) if image_pos == -1 else image_pos # Set to end if not found
video_pos = text.find(VIDEO_PLACEHOLDER, st)
video_pos = len(text) if video_pos == -1 else video_pos # Set to end if not found
ed = min(image_pos, video_pos) # End position is first placeholder found
self._add_text(text[st:ed], outputs)
if ed == len(text):
break
if ed == image_pos:
outputs["pic_cnt"] += 1
self._add_image(images[image_idx], outputs)
image_idx += 1
st = ed + IMAGE_PLACEHOLDER_LEN
else:
item = videos[video_idx]
if isinstance(item, dict):
frames, meta = self._load_and_process_video(item["video"], item)
else:
frames, meta = self._load_and_process_video(item, {})
outputs["video_cnt"] += 1
self._add_video(frames, meta, outputs)
video_idx += 1
st = ed + VIDEO_PLACEHOLDER_LEN
return self._pack_outputs(outputs)
def request2ids(
self, request: Dict[str, Any], tgts: List[str] = None
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
"""
Convert chat request with multimodal messages into model inputs.
Args:
request: Dictionary containing:
- messages: List of chat messages with text/image/video content
- request_id: Unique identifier for logging
tgts: Optional target sequences
Returns:
Dict with same structure as text2ids() output
"""
outputs = {
"input_ids": [],
"token_type_ids": [],
"position_ids": [],
"images": [],
"grid_thw": [],
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
}
# Parse and validate chat messages
messages = parse_chat_messages(request.get("messages"))
image_message_list = [] # Store visual content messages
for msg in messages:
role = msg.get("role")
assert role in self.role_prefixes, f"Unsupported role: {role}"
# Normalize content to list format
content_items = msg.get("content")
if not isinstance(content_items, list):
content_items = [content_items]
# Collect all visual content items
for item in content_items:
if isinstance(item, dict) and item.get("type") in ["image", "video"]:
image_message_list.append(item)
raw_messages = request["messages"]
request["messages"] = messages
prompt_token_ids = self.apply_chat_template(request)
if len(prompt_token_ids) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
request["messages"] = raw_messages
vision_start_index = 0
vision_message_index = 0
for i in range(len(prompt_token_ids)):
if prompt_token_ids[i] == self.vision_start_id:
self._add_text(prompt_token_ids[vision_start_index : i + 1], outputs)
vision_start_index = i + 1
image_message = image_message_list[vision_message_index]
if image_message["type"] == "image":
img = image_message.get("image")
if img is None:
continue
outputs["pic_cnt"] += 1
self._add_image(img, outputs)
elif image_message["type"] == "video":
video_bytes = image_message.get("video")
if video_bytes is None:
continue
frames, meta = self._load_and_process_video(video_bytes, image_message)
outputs["video_cnt"] += 1
self._add_video(frames, meta, outputs)
vision_message_index += 1
self._add_text(prompt_token_ids[vision_start_index:], outputs)
return self._pack_outputs(outputs)
def _add_text(self, tokens, outputs: Dict) -> None:
"""
Add text tokens to model inputs dictionary.
Args:
tokens: Text string or already tokenized IDs
outputs: Dictionary accumulating model inputs
Note:
- Handles both raw text and pre-tokenized inputs
- Updates position IDs for 3D embeddings
"""
if not tokens:
return None
if isinstance(tokens, str):
tokens_str = self.tokenizer.tokenize(tokens)
tokens = self.tokenizer.convert_tokens_to_ids(tokens_str)
num_tokens = len(tokens)
outputs["input_ids"].extend(tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)
position_ids = self._compute_text_positions(outputs["cur_position"], num_tokens)
outputs["position_ids"].append(position_ids)
outputs["cur_position"] = position_ids.max() + 1
def _compute_text_positions(self, start_pos: int, num_tokens: int) -> np.ndarray:
"""
Generate 3D positional embeddings for text tokens.
Args:
start_pos: Starting position index
num_tokens: Number of tokens to generate positions for
Returns:
numpy.ndarray: 3D position IDs shaped (3, num_tokens)
"""
text_array = np.arange(num_tokens).reshape(1, -1)
text_index = np.broadcast_to(text_array, (3, num_tokens))
position = text_index + start_pos
return position
def _add_image(self, img, outputs: Dict) -> None:
"""
Add image data to model inputs dictionary.
Args:
img: PIL Image to process
outputs: Dictionary accumulating model inputs
Note:
- Preprocesses image and calculates spatial dimensions
- Adds image token IDs and type markers
- Generates appropriate position embeddings
"""
ret = self.image_processor.preprocess(images=[img.convert("RGB")])
num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
grid_thw = ret["grid_thw"].tolist()
outputs["input_ids"].extend([self.image_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
outputs["images"].append(ret["pixel_values"])
outputs["grid_thw"].append(grid_thw)
outputs["image_type_ids"].append(0)
t, h, w = grid_thw
position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)
outputs["position_ids"].append(position_ids)
outputs["cur_position"] = position_ids.max() + 1
def _add_video(self, frames, meta: Dict, outputs: Dict) -> None:
"""
Add video data to model inputs dictionary.
Args:
frames: Video frames as numpy array
meta: Video metadata containing fps/duration
outputs: Dictionary accumulating model inputs
Note:
- Handles temporal dimension in position embeddings
- Uses video-specific token IDs and type markers
"""
ret = self.image_processor.preprocess(images=frames)
num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
grid_thw = ret["grid_thw"].tolist()
outputs["input_ids"].extend([self.video_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["images"].append(ret["pixel_values"])
outputs["grid_thw"].append(grid_thw)
outputs["image_type_ids"].extend([1] * grid_thw[0])
fps = meta["fps"]
second_per_grid_t = self.temporal_conv_size / fps
t, h, w = grid_thw
position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
outputs["position_ids"].append(position_ids)
outputs["cur_position"] = position_ids.max() + 1
def _compute_vision_positions(
self, start_pos: int, t: int, h: int, w: int, second_per_grid_t: float
) -> np.ndarray:
"""
Generate 3D position IDs for visual inputs.
Args:
start_pos: Base position in sequence
t: Temporal patches (1 for images)
h: Height in patches
w: Width in patches
second_per_grid_t: Time per temporal patch
Returns:
np.ndarray: Position IDs for [t,h,w] dimensions
"""
h //= self.spatial_conv_size
w //= self.spatial_conv_size
tn = np.arange(t).reshape(-1, 1)
tn = np.broadcast_to(tn, (t, h * w))
tn = tn * int(second_per_grid_t) * self.tokens_per_second
t_index = tn.flatten()
hn = np.arange(h).reshape(1, -1, 1)
h_index = np.broadcast_to(hn, (t, h, w)).flatten()
wn = np.arange(w).reshape(1, 1, -1)
w_index = np.broadcast_to(wn, (t, h, w)).flatten()
position = np.stack([t_index, h_index, w_index]) + start_pos
return position
def _load_and_process_video(self, url: str, item: Dict) -> Tuple[np.ndarray, Dict]:
"""
Load and preprocess video into frames.
Args:
url: Video file path or bytes
item: Dictionary containing processing parameters
Returns:
tuple: (frames, metadata) where:
- frames: Processed video frames as numpy array
- metadata: Updated video metadata dictionary
"""
frames, meta = read_frames(url)
# Apply frame sampling if fps or target_frames specified
fps = item.get("fps", None)
num_frames = item.get("target_frames", None)
if fps is not None or num_frames is not None:
# Get frame sampling constraints
min_frames = item.get("min_frames", self.min_frames)
max_frames = item.get("max_frames", self.max_frames)
# Sample frames according to specifications
frames = sample_frames(
video=frames,
frame_factor=self.temporal_conv_size, # Ensure divisible by temporal patch size
min_frames=min_frames,
max_frames=max_frames,
metadata=meta,
fps=fps,
num_frames=num_frames,
)
# Update metadata with new frame count and fps
meta["num_of_frame"] = frames.shape[0]
if fps is not None:
meta["fps"] = fps # Use specified fps
meta["duration"] = frames.shape[0] / fps
else:
meta["fps"] = frames.shape[0] / meta["duration"] # Calculate fps from sampled frames
return frames, meta
def apply_chat_template(self, request):
"""
Apply chat template to convert messages into token sequence.
Args:
request: Dictionary containing chat messages
Returns:
List of token IDs
Raises:
ValueError: If model doesn't support chat templates
"""
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
raw_prompt = self.tokenizer.apply_chat_template(
request["messages"],
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
)
prompt_token_str = raw_prompt.replace(self.image_token, "").replace(self.video_token, "")
request["text_after_process"] = raw_prompt
tokens = self.tokenizer.tokenize(prompt_token_str)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
data_processor_logger.info(
f"req_id:{request.get('request_id', ''), } prompt: {raw_prompt} tokens: {tokens}, token_ids: {token_ids}"
)
return token_ids

View File

@@ -0,0 +1,131 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import math
from typing import Optional, Union
import numpy as np
from PIL import Image
from fastdeploy.input.mm_processor import read_video_decord
def read_frames(video_path):
"""
Read and decode video frames from the given path
This function reads a video file and decodes it into individual RGB frames
using decord video reader. It also extracts video metadata including fps,
duration and frame count.
Args:
video_path (str): Path to the video file or bytes object containing video data
Returns:
tuple: A tuple containing:
frames (numpy.ndarray): Array of shape (num_frames, height, width, 3)
containing decoded RGB video frames
meta (dict): Dictionary containing video metadata:
- fps (float): Frames per second
- duration (float): Video duration in seconds
- num_of_frame (int): Total number of frames
- width (int): Frame width in pixels
- height (int): Frame height in pixels
Note:
- The function uses decord library for efficient video reading
- All frames are converted to RGB format regardless of input format
"""
reader, meta, _ = read_video_decord(video_path, save_to_disk=False)
frames = []
for i in range(meta["num_of_frame"]):
frame = reader[i].asnumpy()
image = Image.fromarray(frame, "RGB")
frames.append(image)
frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
return frames, meta
def sample_frames(
video: np.ndarray,
frame_factor: int,
min_frames: int,
max_frames: int,
metadata: Optional[dict] = None,
fps: Optional[Union[int, float]] = None,
num_frames: Optional[int] = None,
):
"""
Sample frames from video according to specified criteria.
Args:
video: Input video frames as numpy array
frame_factor: Ensure sampled frames are multiples of this factor
min_frames: Minimum number of frames to sample
max_frames: Maximum number of frames to sample
metadata: Video metadata containing fps information
fps: Target frames per second for sampling
num_frames: Exact number of frames to sample
Returns:
np.ndarray: Sampled video frames
Raises:
ValueError: If both fps and num_frames are specified,
or if required metadata is missing,
or if requested frames exceed available frames
"""
if fps is not None and num_frames is not None:
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
if fps is None and num_frames is None:
return video
total_num_frames = video.shape[0]
# If num_frames is not given but fps is, calculate num_frames from fps
if num_frames is not None:
num_frames = round(num_frames / frame_factor) * frame_factor
elif fps is not None:
if metadata is None:
raise ValueError(
"Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
"Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
)
max_frames = math.floor(min(max_frames, total_num_frames) / frame_factor) * frame_factor
num_frames = total_num_frames / metadata["fps"] * fps
num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames)
num_frames = math.floor(num_frames / frame_factor) * frame_factor
if num_frames > total_num_frames:
raise ValueError(
f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
"Decrease `num_frames` or `fps` for sampling."
)
# Calculate frame indices based on sampling strategy
if num_frames is not None:
# Evenly spaced sampling for target frame count
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(np.int32)
else:
# Keep all frames if no sampling requested
indices = np.arange(0, total_num_frames).astype(np.int32)
# Apply frame selection
video = video[indices]
return video

View File

@@ -0,0 +1,290 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import numpy as np
from fastdeploy.engine.request import Request
from fastdeploy.input.qwen_mm_processor import DataProcessor
from fastdeploy.input.text_processor import DataProcessor as TextProcessor
from fastdeploy.utils import data_processor_logger
class QwenVLProcessor(TextProcessor):
"""
Qwen Vision-Language processor for handling multimodal inputs.
This processor extends TextProcessor to support:
- Image and video processing
- Multimodal feature extraction
- Tokenization and position encoding
- Request processing and model input generation
Attributes:
processor (DataProcessor): Underlying data processor instance
tokenizer: Text tokenizer instance
limit_mm_per_prompt (dict): Limits for multimodal inputs per prompt
"""
def __init__(
self,
config,
model_name_or_path,
limit_mm_per_prompt=None,
mm_processor_kwargs=None,
reasoning_parser_obj=None,
tool_parser_obj=None,
):
"""
Initialize QwenVLProcessor instance.
Args:
config: Model configuration object
model_name_or_path (str): Pretrained model name or path
limit_mm_per_prompt (dict, optional): Limits for multimodal inputs
mm_processor_kwargs (dict, optional): Multimodal processor arguments
reasoning_parser_obj: Reasoning parser instance
tool_parser_obj: Tool parser instance
"""
super().__init__(model_name_or_path, reasoning_parser_obj, tool_parser_obj)
data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs)
self.processor = DataProcessor(
model_path=model_name_or_path,
tokens_per_second=config.vision_config.tokens_per_second,
tokenizer=self.tokenizer,
**processor_kwargs,
)
self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
def process_request(self, request, max_model_len=None, **kwargs):
"""
Process incoming request and generate model inputs.
Args:
request: Input request object
max_model_len (int, optional): Maximum context length
**kwargs: Additional processing parameters
Returns:
Request: Processed request with model inputs
"""
task = request.to_dict()
task["enable_thinking"] = kwargs.get("enable_thinking", False)
self.process_request_dict(task, max_model_len)
request = Request.from_dict(task)
request = self._apply_default_parameters(request)
return request
def _parse_processor_kwargs(self, kwargs):
"""
Parse and validate multimodal processor arguments.
Args:
kwargs (dict): Processor configuration arguments
Returns:
dict: Validated processor arguments
Raises:
ValueError: If arguments format is invalid
"""
if not kwargs:
return {}
try:
if not isinstance(kwargs, dict):
raise ValueError("mm-processor-kwargs must be a dictionary")
# Validate kwargs types against expected schema
data_processor_logger.info(f"Processing kwargs: {kwargs}")
expected_types = {
"video_max_frames": int, # Maximum video frames parameter
"video_min_frames": int, # Minimum video frames parameter
}
for key, value in kwargs.items():
if key in expected_types and not isinstance(value, expected_types[key]):
raise ValueError(
f"Invalid type for {key}: expected {expected_types[key].__name__}, got {type(value).__name__}"
)
return kwargs
except Exception as e:
data_processor_logger.warning(f"Invalid mm-processor-kwargs format: {e}")
return {}
def _parse_limits(self, limits):
"""
Parse and validate multimodal input limits.
Args:
limits (dict): Input limits configuration
Returns:
dict: Validated limits with defaults
Raises:
ValueError: If limits format is invalid
"""
DEFAULT_LIMITS = {"image": 1, "video": 1, "audio": 1}
if not limits:
return DEFAULT_LIMITS
try:
if not isinstance(limits, dict):
raise ValueError("limit-mm-per-prompt must be a dictionary")
data_processor_logger.info(f"_parse_limits:{limits}")
return {**DEFAULT_LIMITS, **limits}
except Exception as e:
data_processor_logger.warning(f"Invalid limit-mm-per-prompt format: {e}, using default limits")
return DEFAULT_LIMITS
def _check_mm_limits(self, item):
"""
Validate multimodal inputs against configured limits.
Args:
item: Input request item to validate
Raises:
ValueError: If input exceeds configured limits
"""
if isinstance(item, dict):
# 请求包含prompt和multi_modal_data
mm_data = item
else:
# 请求包含messages
mm_data = {"image": [], "video": []}
for message in item:
if isinstance(message.get("content"), list):
for part in message["content"]:
if part.get("type") in ["image_url", "image"]:
mm_data["image"].append(part)
elif part.get("type") in ["video_url", "video"]:
mm_data["video"].append(part)
for modality, data in mm_data.items():
if modality in self.limit_mm_per_prompt:
limit = self.limit_mm_per_prompt[modality]
if len(data) > limit:
raise ValueError(f"Too many {modality} items in prompt, " f"got {len(data)} but limit is {limit}")
def process_request_dict(self, request, max_model_len=None):
"""
Process request dictionary into model inputs.
Args:
request (dict): Input request dictionary
max_model_len (int, optional): Maximum context length
Returns:
dict: Processed request with model inputs
Raises:
ValueError: If request format is invalid
"""
request = self._apply_default_parameters(request)
if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len
if request.get("prompt"):
multimodal_data = request.get("multimodal_data")
if multimodal_data is None:
multimodal_data = {}
self._check_mm_limits(multimodal_data)
images = multimodal_data.get("image", None)
videos = multimodal_data.get("video", None)
outputs = self.processor.text2ids(request["prompt"], images, videos)
elif request.get("messages"):
messages = request["messages"]
self._check_mm_limits(messages)
outputs = self.processor.request2ids(request)
else:
raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")
metadata = request.get("metadata")
# Handle continuation of previous generation by appending existing tokens
if metadata and metadata.get("generated_token_ids"):
self.append_generated_tokens(outputs, metadata["generated_token_ids"])
outputs = self.pack_outputs(outputs)
request["prompt_token_ids"] = outputs["input_ids"].tolist()
request["prompt_token_ids_len"] = len(request["prompt_token_ids"])
request["multimodal_inputs"] = outputs
# Handle prompt truncation if exceeds model context length
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
request["prompt_token_ids"] = request["prompt_token_ids"][
: max_model_len - 1
] # Leave space for at least 1 new token
# Set default max_tokens if not specified
if request.get("max_tokens") is None:
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"])) # Ensure at least 1 token
data_processor_logger.info(f"Processed request {request}")
return request
def append_generated_tokens(self, outputs, generated_token_ids):
"""
Append generated tokens to existing outputs.
Args:
outputs: Current model outputs
generated_token_ids: Generated tokens to append
"""
out = {"input_ids": [], "token_type_ids": [], "position_ids": [], "cur_position": outputs["cur_position"]}
self.processor._add_text(generated_token_ids, out)
outputs["input_ids"] = np.concatenate(
[outputs["input_ids"], np.array(out["input_ids"], dtype=np.int64)], axis=0
)
outputs["token_type_ids"] = np.concatenate(
[outputs["token_type_ids"], np.array(out["token_type_ids"], dtype=np.int64)], axis=0
)
outputs["position_ids"] = np.concatenate(
[outputs["position_ids"], out["position_ids"][0]], axis=1, dtype=np.int64
)
outputs["cur_position"] = out["cur_position"]
def pack_outputs(self, outputs):
"""
Prepare final output dictionary for model.
Args:
outputs: Intermediate processing outputs
Returns:
dict: Packed output dictionary with all required fields
"""
outputs["image_patch_id"] = self.processor.image_token_id
outputs["video_patch_id"] = self.processor.video_token_id
outputs["position_ids"] = outputs["position_ids"].transpose(1, 0)
return outputs

View File

@@ -0,0 +1,248 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from unittest.mock import MagicMock, patch
import numpy as np
from PIL import Image
from fastdeploy.engine.request import Request
from fastdeploy.input.qwen_vl_processor import QwenVLProcessor
def mock_pil_image(height, width):
"""
Generate mock random RGB image
Args:
height: Image height in pixels
width: Image width in pixels
Returns:
PIL.Image object with random RGB data
"""
rgb_image = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
return Image.fromarray(rgb_image)
def mock_read_frames(height: int, width: int, nums_frame: int, fps: int):
"""
Generate mock video frames with metadata for testing purposes
Creates synthetic video data by generating random RGB frames and constructing
corresponding metadata to simulate real video processing.
Args:
height (int): Height of video frames in pixels
width (int): Width of video frames in pixels
nums_frame (int): Number of frames to generate
fps (int): Frames per second for the mock video
Returns:
tuple: A tuple containing:
frames (numpy.ndarray): Array of shape (nums_frame, height, width, 3)
containing randomly generated RGB frames
meta (dict): Dictionary with video metadata:
- fps (int): Frames per second (same as input)
- duration (float): Calculated duration in seconds (nums_frame/fps)
- num_of_frame (int): Number of frames (same as nums_frame input)
"""
frames = []
for _ in range(nums_frame):
frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
frames.append(frame)
frames = np.stack(frames, axis=0)
meta = {
"fps": fps,
"duration": nums_frame / fps,
"num_of_frame": nums_frame,
}
return frames, meta
class TestQwenVLProcessor(unittest.TestCase):
"""
Unit tests for Qwen Vision-Language Processor functionality
"""
def setUp(self):
"""
Initialize test case with:
- Mock configuration
- Patched message parsing and video processing methods
- QwenVLProcessor instance with test parameters
"""
config = MagicMock()
config.vision_config.tokens_per_second = 2
self.patcher_parse_image = patch(
"fastdeploy.entrypoints.chat_utils.MultiModalPartParser.parse_image", return_value=mock_pil_image(480, 640)
)
self.patcher_parse_image.start()
self.patcher_parse_video = patch(
"fastdeploy.entrypoints.chat_utils.MultiModalPartParser.parse_video", return_value=b"123"
)
self.patcher_parse_video.start()
self.patcher_read_frames = patch(
"fastdeploy.input.qwen_mm_processor.process.read_frames", return_value=mock_read_frames(480, 640, 5, 2)
)
self.patcher_read_frames.start()
mm_processor_kwargs = {
"video_max_frames": 10,
"video_min_frames": 1,
}
limit_mm_per_prompt = {"image": 1, "video": 1, "audio": 1}
model_name_or_path = "/ModelData/Qwen2.5-VL-7B-Instruct"
self.processor = QwenVLProcessor(
config=config,
model_name_or_path=model_name_or_path,
limit_mm_per_prompt=limit_mm_per_prompt,
mm_processor_kwargs=mm_processor_kwargs,
reasoning_parser_obj=None,
tool_parser_obj=None,
)
def tearDown(self) -> None:
"""Clean up test case by stopping all mock patches"""
self.patcher_read_frames.stop()
self.patcher_parse_image.stop()
self.patcher_parse_video.stop()
def test_process_request(self):
"""
Test processing of Request object with multimodal input
Validates:
1. Token ID lengths match position_ids and token_type_ids shapes
2. Image processing produces expected output dimensions
3. Video processing produces expected output dimensions
4. Correct counts for images (1) and videos (1)
"""
prompt = {
"request_id": "12345",
"messages": [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": "file://demo.jpeg"}},
{"type": "video_url", "video_url": {"url": "file://3_frame_video.mp4"}},
{"type": "text", "text": "Describe image and video."},
],
}
],
}
request = Request.from_dict(prompt)
result = self.processor.process_request(request, 1024 * 100)
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0])
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0])
self.assertEqual(
result.multimodal_inputs["images"].shape[0],
sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])),
)
self.assertEqual(
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
)
self.assertEqual(result.multimodal_inputs["pic_cnt"], 1)
self.assertEqual(result.multimodal_inputs["video_cnt"], 1)
def test_process_request_dict(self):
"""
Test processing of dictionary-format request with multimodal input
Validates:
1. Token ID lengths match position_ids and token_type_ids shapes
2. Image processing produces expected output dimensions
3. Video processing produces expected output dimensions
4. Correct counts for images (1) and videos (1)
"""
num_generated_token_ids = 10
request = {
"request_id": "12345",
"metadata": {
"generated_token_ids": [1] * num_generated_token_ids,
},
"stop": ["stop", "eof"],
"messages": [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": "file://demo.jpeg"}},
{"type": "video_url", "video_url": {"url": "file://3_frame_video.mp4"}},
{"type": "text", "text": "Describe image and video."},
],
}
],
}
result = self.processor.process_request_dict(request, 1024 * 100)
self.assertEqual(result["prompt_token_ids_len"], result["multimodal_inputs"]["position_ids"].shape[0])
self.assertEqual(result["prompt_token_ids_len"], result["multimodal_inputs"]["token_type_ids"].shape[0])
self.assertEqual(
result["multimodal_inputs"]["images"].shape[0],
sum(map(lambda x: x.prod(), result["multimodal_inputs"]["grid_thw"])),
)
self.assertEqual(
result["multimodal_inputs"]["image_type_ids"].shape[0], result["multimodal_inputs"]["grid_thw"][:, 0].sum()
)
self.assertEqual(result["multimodal_inputs"]["pic_cnt"], 1)
self.assertEqual(result["multimodal_inputs"]["video_cnt"], 1)
def test_prompt(self):
"""
Test processing of prompt with image and video placeholders
Validates:
1. Token ID lengths match position_ids and token_type_ids shapes
2. Image processing produces expected output dimensions
3. Video processing produces expected output dimensions
4. Correct counts for images (1) and videos (1)
"""
prompt = {
"request_id": "12345",
"prompt": "<|image@placeholder|><|video@placeholder|>Describe image and video.",
"multimodal_data": {
"image": [mock_pil_image(10, 2100)],
"video": [{"video": b"123", "fps": 5}],
},
}
request = Request.from_dict(prompt)
result = self.processor.process_request(request, 1024 * 100)
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0])
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0])
self.assertEqual(
result.multimodal_inputs["images"].shape[0],
sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])),
)
self.assertEqual(
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
)
self.assertEqual(result.multimodal_inputs["pic_cnt"], 1)
self.assertEqual(result.multimodal_inputs["video_cnt"], 1)
if __name__ == "__main__":
unittest.main()