diff --git a/fastdeploy/input/mm_processor/__init__.py b/fastdeploy/input/mm_processor/__init__.py index ba59bc165..95475194f 100644 --- a/fastdeploy/input/mm_processor/__init__.py +++ b/fastdeploy/input/mm_processor/__init__.py @@ -15,9 +15,13 @@ """ from .process import IDS_TYPE_FLAG, DataProcessor, fancy_print +from .process_video import read_video_decord +from .utils.video_utils import VideoReaderWrapper __all__ = [ "DataProcessor", "fancy_print", "IDS_TYPE_FLAG", + "VideoReaderWrapper", + "read_video_decord", ] diff --git a/fastdeploy/input/preprocess.py b/fastdeploy/input/preprocess.py index 62ed8d62d..55a052a03 100644 --- a/fastdeploy/input/preprocess.py +++ b/fastdeploy/input/preprocess.py @@ -75,7 +75,10 @@ class InputPreprocessor: reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser) if self.tool_parser: tool_parser_obj = ToolParserManager.get_tool_parser(self.tool_parser) - architectures = ModelConfig({"model": self.model_name_or_path}).architectures[0] + + config = ModelConfig({"model": self.model_name_or_path}) + architectures = config.architectures[0] + if not self.enable_mm: if not ErnieArchitectures.contains_ernie_arch(architectures): from fastdeploy.input.text_processor import DataProcessor @@ -94,9 +97,7 @@ class InputPreprocessor: tool_parser_obj=tool_parser_obj, ) else: - if not ErnieArchitectures.contains_ernie_arch(architectures): - raise ValueError(f"Model {self.model_name_or_path} is not a valid Ernie4_5_VL model.") - else: + if ErnieArchitectures.contains_ernie_arch(architectures): from fastdeploy.input.ernie_vl_processor import ErnieMoEVLProcessor self.processor = ErnieMoEVLProcessor( @@ -106,4 +107,14 @@ class InputPreprocessor: reasoning_parser_obj=reasoning_parser_obj, tool_parser_obj=tool_parser_obj, ) + else: + from fastdeploy.input.qwen_vl_processor import QwenVLProcessor + + self.processor = QwenVLProcessor( + config=config, + model_name_or_path=self.model_name_or_path, + limit_mm_per_prompt=self.limit_mm_per_prompt, + mm_processor_kwargs=self.mm_processor_kwargs, + reasoning_parser_obj=reasoning_parser_obj, + ) return self.processor diff --git a/fastdeploy/input/qwen_mm_processor/__init__.py b/fastdeploy/input/qwen_mm_processor/__init__.py new file mode 100644 index 000000000..5a97e4186 --- /dev/null +++ b/fastdeploy/input/qwen_mm_processor/__init__.py @@ -0,0 +1,22 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from .process import IDS_TYPE_FLAG, DataProcessor + +__all__ = [ + "DataProcessor", + "IDS_TYPE_FLAG", +] diff --git a/fastdeploy/input/qwen_mm_processor/image_processor.py b/fastdeploy/input/qwen_mm_processor/image_processor.py new file mode 100644 index 000000000..c72a6abb7 --- /dev/null +++ b/fastdeploy/input/qwen_mm_processor/image_processor.py @@ -0,0 +1,442 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import math +from typing import List, Optional, Union + +import numpy as np +import paddle +import PIL +from paddleformers.transformers.feature_extraction_utils import BatchFeature +from paddleformers.transformers.image_processing_utils import BaseImageProcessor +from paddleformers.transformers.image_transforms import ( + normalize, + rescale, + resize, + to_channel_dimension_format, +) +from paddleformers.transformers.image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + make_list_of_images, + to_numpy_array, + valid_images, +) +from paddleformers.transformers.tokenizer_utils_base import TensorType +from PIL import Image + +from fastdeploy.utils import data_processor_logger + +OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] +OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711] + +MIN_PIXELS = 4 * 28 * 28 +MAX_PIXELS = 16384 * 28 * 28 + + +VideoInput = Union[ + List["PIL.Image.Image"], + "np.ndarray", + "paddle.Tensor", + List["np.ndarray"], + List["paddle.Tensor"], + List[List["PIL.Image.Image"]], + List[List["np.ndarray"]], + List[List["paddle.Tensor"]], +] + + +def round_by_factor(number: int, factor: int) -> int: + """ + Round number to nearest multiple of factor. + + Args: + number: Input number to round + factor: Rounding factor + + Returns: + int: Rounded number + """ + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """ + Round number up to nearest multiple of factor. + + Args: + number: Input number to round + factor: Rounding factor + + Returns: + int: Rounded number + """ + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """ + Round number down to nearest multiple of factor. + + Args: + number: Input number to round + factor: Rounding factor + + Returns: + int: Rounded number + """ + return math.floor(number / factor) * factor + + +def smart_resize(height: int, width: int, factor: int, min_pixels: int, max_pixels: int, max_ratio: int = 200): + """ + Smart image resizing that maintains aspect ratio and respects constraints. 
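+
+    Illustrative example (a sketch; the constants are the module defaults defined above):
+
+        smart_resize(480, 640, factor=28, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS)
+        # -> (476, 644): both sides rounded to the nearest multiple of 28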
+ + Args: + height: Original image height + width: Original image width + factor: Patch size factor + min_pixels: Minimum allowed pixels + max_pixels: Maximum allowed pixels + max_ratio: Maximum allowed aspect ratio + + Returns: + tuple: (new_height, new_width) + + Raises: + ValueError: If calculated dimensions are invalid + """ + if max(height, width) / min(height, width) > max_ratio: + if height > width: + new_width = max(factor, round_by_factor(width, factor)) + new_height = floor_by_factor(new_width * max_ratio, factor) + else: + new_height = max(factor, round_by_factor(height, factor)) + new_width = floor_by_factor(new_height * max_ratio, factor) + + data_processor_logger.info( + f"absolute aspect ratio must be smaller than {max_ratio}, got {max(height, width) / min(height, width)},\ + resize to {max(new_height, new_width) / min(new_height, new_width)}" + ) + + height = new_height + width = new_width + + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + + if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels: + raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}") + + return h_bar, w_bar + + +def is_scaled_image(image: np.ndarray) -> bool: + """ + Check if image pixel values are already normalized to [0, 1] range. + + Args: + image: Input image array + + Returns: + bool: True if image is already scaled + """ + if image.dtype == np.uint8: + return False + + # It's possible the image has pixel values in [0, 255] but is of floating type + return np.min(image) >= 0 and np.max(image) <= 1 + + +class ImageProcessor(BaseImageProcessor): + """ + Adaptive image processor for dynamic image resizing and preprocessing. + + This processor handles image resizing, rescaling, normalization and format conversion. + It dynamically adjusts image dimensions based on original size and specified constraints. + """ + + def __init__( + self, + patch_size: int = 14, + merge_size: int = 2, + temporal_patch_size: int = 2, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, + image_mean: Union[float, List[float]] = OPENAI_CLIP_MEAN, + image_std: Union[float, List[float]] = OPENAI_CLIP_STD, + rescale_factor: float = 1 / 255, + do_rescale: bool = True, + do_normalize: bool = True, + resample: PILImageResampling = PILImageResampling.BICUBIC, + **kwargs, + ) -> None: + """ + Initialize image processor with configuration parameters. 
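+
+        Illustrative sketch (default arguments shown; not the only valid configuration):
+
+            processor = ImageProcessor(patch_size=14, merge_size=2)
+            # resized sides become multiples of patch_size * merge_size = 28 pixels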
+ + Args: + patch_size (int): Spatial patch size for vision encoder + merge_size (int): Merge size between vision and LLM encoders + temporal_patch_size (int): Temporal patch size for video processing + min_pixels (int): Minimum allowed pixels in resized image + max_pixels (int): Maximum allowed pixels in resized image + image_mean (float/list): Mean values for normalization per channel + image_std (float/list): Std values for normalization per channel + rescale_factor (float): Scaling factor for pixel values (default 1/255) + do_rescale (bool): Whether to rescale images + do_normalize (bool): Whether to normalize images + resample: Resampling method for image resizing + **kwargs: Additional base class arguments + """ + super().__init__(**kwargs) + self.patch_size = patch_size + self.merge_size = merge_size + self.temporal_patch_size = temporal_patch_size + + self.min_pixels = min_pixels + self.max_pixels = max_pixels + + self.image_mean = image_mean + self.image_std = image_std + self.rescale_factor = rescale_factor + self.do_rescale = do_rescale + self.do_normalize = do_normalize + + self.resample = resample + + def _preprocess( + self, + images: Union[ImageInput, VideoInput], + min_pixels: int, + max_pixels: int, + image_mean: Optional[Union[float, List[float]]], + image_std: Optional[Union[float, List[float]]], + rescale_factor: float, + do_rescale: bool, + do_normalize: bool, + resample: PILImageResampling, + data_format: Optional[ChannelDimension], + input_data_format: Optional[Union[str, ChannelDimension]], + ): + """ + Internal method for image preprocessing pipeline. + + Args: + images: Input image or batch of images + min_pixels: Minimum allowed pixels in output + max_pixels: Maximum allowed pixels in output + image_mean: Normalization mean values + image_std: Normalization std values + rescale_factor: Pixel value scaling factor + do_rescale: Whether to rescale pixel values + do_normalize: Whether to normalize pixel values + resample: Resampling method + data_format: Output channel format + input_data_format: Input channel format + + Returns: + tuple: (flatten_patches, grid_dimensions) + - flatten_patches: Flattened image patches + - grid_dimensions: Grid dimensions [t, h, w] + """ + images = make_list_of_images(images) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + data_processor_logger.warning( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + # Get original dimensions and calculate optimal resize dimensions + height, width = get_image_size(images[0], channel_dim=input_data_format) + resized_height, resized_width = smart_resize( + height, + width, + factor=self.patch_size * self.merge_size, # Combine patch and merge factors + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + + processed_images = [] + for image in images: + if height != resized_height or width != resized_width: + # Convert to uint8 before resizing to avoid double scaling + image = image.astype("uint8") + # Convert to PIL Image and resize + image = Image.fromarray(image) + image = resize( + image, + size=(resized_height, resized_width), + resample=resample, + data_format=input_data_format, + ) + + if do_rescale and do_normalize: + # Adjust mean and std for combined rescale+normalize + image_mean = np.array(image_mean, dtype=np.float32) * (1.0 / rescale_factor) + image_std = np.array(image_std, dtype=np.float32) * (1.0 / rescale_factor) + do_rescale = False # Skip separate rescale step + + if do_rescale: + image = image.astype(np.float32) + image = rescale(image, scale=rescale_factor, data_format=input_data_format) + + if do_normalize: + image = image.astype(np.float32) + image = normalize( + image=image, + mean=image_mean, + std=image_std, + data_format=input_data_format, + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) # [C, H, W] + processed_images.append(image) + + # Convert processed images to numpy array + patches = np.array(processed_images) + + # Pad temporal dimension if needed + if patches.shape[0] % self.temporal_patch_size != 0: + repeats = np.repeat( + patches[-1][np.newaxis], + self.temporal_patch_size - (patches.shape[0] % self.temporal_patch_size), + axis=0, + ) + patches = np.concatenate([patches, repeats], axis=0) + + # Convert to channels-first format if needed + if data_format == ChannelDimension.LAST: + patches = patches.transpose([0, 3, 1, 2]) # [N, H, W, C] -> [N, C, H, W] + + grid_t, channel = patches.shape[:2] + grid_t = grid_t // self.temporal_patch_size + + grid_h, grid_w = ( + resized_height // self.patch_size, + resized_width // self.patch_size, + ) + # Reshape into hierarchical patch structure + patches = patches.reshape( + [ + grid_t, + self.temporal_patch_size, + channel, + grid_h // self.merge_size, + self.merge_size, + self.patch_size, + grid_w // self.merge_size, + self.merge_size, + self.patch_size, + ] + ) + # Reorder dimensions for better memory access pattern + # [grid_t, grid_h/merge_size, grid_w/merge_size, merge_size, merge_size, C, temporal_patch_size, psz, psz] + patches = patches.transpose([0, 3, 6, 4, 7, 2, 1, 5, 8]) + + flatten_patches = patches.reshape( + [ + grid_t * grid_h * grid_w, + channel * self.temporal_patch_size * self.patch_size * self.patch_size, + ] + ) + + return flatten_patches, np.array([grid_t, grid_h, grid_w]) + + def preprocess( + self, + images: Union[ImageInput, VideoInput], + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + rescale_factor: Optional[float] = None, + do_rescale: Optional[bool] = None, + do_normalize: Optional[bool] = None, + resample: Optional[PILImageResampling] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: 
Optional[Union[str, ChannelDimension]] = ChannelDimension.LAST, + ): + """ + Main preprocessing method for images/videos. + + Args: + images: Input image/video data + min_pixels: Override for minimum pixels + max_pixels: Override for maximum pixels + image_mean: Override for normalization mean + image_std: Override for normalization std + rescale_factor: Override for rescaling factor + do_rescale: Override for rescaling flag + do_normalize: Override for normalization flag + resample: Override for resampling method + return_tensors: Desired output tensor format + data_format: Output channel dimension format + input_data_format: Input channel dimension format + + Returns: + BatchFeature: Processed features containing: + - pixel_values: Preprocessed pixel data + - grid_thw: Grid dimensions [temporal, height, width] + + Raises: + ValueError: For invalid image types or dimensions + """ + min_pixels = min_pixels if min_pixels is not None else self.min_pixels + max_pixels = max_pixels if max_pixels is not None else self.max_pixels + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + resample = resample if resample is not None else self.resample + + if images is not None and not valid_images(images): + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "paddle.Tensor.") + + pixel_values, grid_thw = self._preprocess( + images, + min_pixels=min_pixels, + max_pixels=max_pixels, + image_mean=image_mean, + image_std=image_std, + rescale_factor=rescale_factor, + do_rescale=do_rescale, + do_normalize=do_normalize, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + ) + data = {"pixel_values": pixel_values, "grid_thw": grid_thw} + return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/fastdeploy/input/qwen_mm_processor/process.py b/fastdeploy/input/qwen_mm_processor/process.py new file mode 100644 index 000000000..10e84ea7e --- /dev/null +++ b/fastdeploy/input/qwen_mm_processor/process.py @@ -0,0 +1,505 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +from paddleformers.transformers import AutoTokenizer + +from fastdeploy.entrypoints.chat_utils import parse_chat_messages +from fastdeploy.input.mm_processor import IDS_TYPE_FLAG +from fastdeploy.utils import data_processor_logger + +from .image_processor import ImageProcessor +from .process_video import read_frames, sample_frames + + +class DataProcessor: + """ + Processes multimodal inputs (text, images, videos) into model-ready formats. 
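+
+    Illustrative usage sketch (the model path and `pil_image` are placeholders):
+
+        processor = DataProcessor(model_path="Qwen/Qwen2.5-VL-7B-Instruct")
+        inputs = processor.text2ids("<|image@placeholder|>Describe the image.", images=[pil_image])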
+ + Handles: + - Tokenization of text with special tokens for visual content + - Image and video preprocessing + - Generation of 3D positional embeddings + - Conversion of chat messages to model inputs + + Attributes: + tokenizer: Text tokenizer instance + image_processor: Image/video preprocessor + image_token: Special token for image placeholders + video_token: Special token for video placeholders + vision_start: Token marking start of visual content + """ + + def __init__( + self, + model_path: str, + video_min_frames: int = 4, + video_max_frames: int = 768, + tokens_per_second: int = 2, + tokenizer=None, + **kwargs, + ) -> None: + """ + Initialize the data processor. + + Args: + model_path: Path to pretrained model + video_min_frames: Minimum frames to sample from videos + video_max_frames: Maximum frames to sample from videos + tokens_per_second: Temporal resolution for positional embeddings + **kwargs: Additional configuration + """ + self.min_frames = video_min_frames + self.max_frames = video_max_frames + + # Initialize tokenizer with left padding and fast tokenizer + if tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left", use_fast=True) + self.tokenizer.ignored_index = -100 # Set ignored index for loss calculation + else: + self.tokenizer = tokenizer + self.image_processor = ImageProcessor.from_pretrained(model_path) # Initialize image processor + + # Convolution sizes for patch aggregation + self.spatial_conv_size = self.image_processor.merge_size + self.temporal_conv_size = self.image_processor.temporal_patch_size + + # Special tokens and IDs + self.image_token = "<|image_pad|>" + self.video_token = "<|video_pad|>" + + self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token) + self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token) + + self.vision_start = "<|vision_start|>" + self.vision_start_id = self.tokenizer.convert_tokens_to_ids(self.vision_start) + + self.tokens_per_second = tokens_per_second + + self.role_prefixes = { + "system": "", + "user": "User: ", + "bot": "Assistant: ", + "assistant": "Assistant: ", + } + + def _pack_outputs(self, outputs): + """ + Pack and convert all output data into numpy arrays with appropriate types. 
+ + Args: + outputs (dict): Dictionary containing model outputs with keys: + - images: List of visual features + - grid_thw: List of spatial dimensions + - image_type_ids: List of content type indicators + - input_ids: List of token IDs + - token_type_ids: List of type identifiers + - position_ids: List of position embeddings + + Returns: + dict: Processed outputs with all values converted to numpy arrays + """ + # Process visual outputs - stack if exists or set to None if empty + if not outputs["images"]: + outputs["images"] = None # No images case + outputs["grid_thw"] = None # No spatial dimensions + outputs["image_type_ids"] = None # No type IDs + else: + outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically + outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions + outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array + + # Convert all outputs to numpy arrays with appropriate types + outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64 + outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64 + outputs["position_ids"] = np.concatenate( + outputs["position_ids"], axis=1, dtype=np.int64 + ) # Concatenate position IDs + return outputs + + def text2ids(self, text, images=None, videos=None): + """ + Convert text with image/video placeholders into model inputs. + + Args: + text: Input text with <|image@placeholder|> and <|video@placeholder|> markers + images: List of PIL Images corresponding to image placeholders + videos: List of video data corresponding to video placeholders + + Returns: + Dict containing: + - input_ids: Token IDs + - token_type_ids: Type identifiers (text/image/video) + - position_ids: 3D positional embeddings + - images: Preprocessed visual features + - grid_thw: Spatial/temporal dimensions + - image_type_ids: Visual content type (0=image, 1=video) + """ + + outputs = { + "input_ids": [], + "token_type_ids": [], + "position_ids": [], + "images": [], + "grid_thw": [], + "image_type_ids": [], + "labels": [], + "cur_position": 0, + "pic_cnt": 0, + "video_cnt": 0, + } + + # Define placeholders and their lengths + IMAGE_PLACEHOLDER = "<|image@placeholder|>" + VIDEO_PLACEHOLDER = "<|video@placeholder|>" + IMAGE_PLACEHOLDER_LEN = len(IMAGE_PLACEHOLDER) + VIDEO_PLACEHOLDER_LEN = len(VIDEO_PLACEHOLDER) + + # Initialize tracking variables for text parsing + st, image_idx, video_idx = 0, 0, 0 # Start position, image counter, video counter + while st < len(text): + # Find next image or video placeholder in text + image_pos = text.find(IMAGE_PLACEHOLDER, st) + image_pos = len(text) if image_pos == -1 else image_pos # Set to end if not found + video_pos = text.find(VIDEO_PLACEHOLDER, st) + video_pos = len(text) if video_pos == -1 else video_pos # Set to end if not found + ed = min(image_pos, video_pos) # End position is first placeholder found + + self._add_text(text[st:ed], outputs) + if ed == len(text): + break + + if ed == image_pos: + outputs["pic_cnt"] += 1 + self._add_image(images[image_idx], outputs) + image_idx += 1 + st = ed + IMAGE_PLACEHOLDER_LEN + else: + item = videos[video_idx] + if isinstance(item, dict): + frames, meta = self._load_and_process_video(item["video"], item) + else: + frames, meta = self._load_and_process_video(item, {}) + + outputs["video_cnt"] += 1 + self._add_video(frames, meta, outputs) + video_idx += 1 + st = ed + VIDEO_PLACEHOLDER_LEN + + return self._pack_outputs(outputs) + + def 
request2ids( + self, request: Dict[str, Any], tgts: List[str] = None + ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]: + """ + Convert chat request with multimodal messages into model inputs. + + Args: + request: Dictionary containing: + - messages: List of chat messages with text/image/video content + - request_id: Unique identifier for logging + tgts: Optional target sequences + + Returns: + Dict with same structure as text2ids() output + """ + + outputs = { + "input_ids": [], + "token_type_ids": [], + "position_ids": [], + "images": [], + "grid_thw": [], + "image_type_ids": [], + "labels": [], + "cur_position": 0, + "pic_cnt": 0, + "video_cnt": 0, + } + + # Parse and validate chat messages + messages = parse_chat_messages(request.get("messages")) + image_message_list = [] # Store visual content messages + + for msg in messages: + role = msg.get("role") + assert role in self.role_prefixes, f"Unsupported role: {role}" + + # Normalize content to list format + content_items = msg.get("content") + if not isinstance(content_items, list): + content_items = [content_items] + + # Collect all visual content items + for item in content_items: + if isinstance(item, dict) and item.get("type") in ["image", "video"]: + image_message_list.append(item) + + raw_messages = request["messages"] + request["messages"] = messages + + prompt_token_ids = self.apply_chat_template(request) + if len(prompt_token_ids) == 0: + raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs") + request["messages"] = raw_messages + + vision_start_index = 0 + vision_message_index = 0 + for i in range(len(prompt_token_ids)): + if prompt_token_ids[i] == self.vision_start_id: + self._add_text(prompt_token_ids[vision_start_index : i + 1], outputs) + + vision_start_index = i + 1 + image_message = image_message_list[vision_message_index] + + if image_message["type"] == "image": + img = image_message.get("image") + if img is None: + continue + outputs["pic_cnt"] += 1 + self._add_image(img, outputs) + + elif image_message["type"] == "video": + video_bytes = image_message.get("video") + if video_bytes is None: + continue + frames, meta = self._load_and_process_video(video_bytes, image_message) + + outputs["video_cnt"] += 1 + self._add_video(frames, meta, outputs) + + vision_message_index += 1 + + self._add_text(prompt_token_ids[vision_start_index:], outputs) + return self._pack_outputs(outputs) + + def _add_text(self, tokens, outputs: Dict) -> None: + """ + Add text tokens to model inputs dictionary. + + Args: + tokens: Text string or already tokenized IDs + outputs: Dictionary accumulating model inputs + + Note: + - Handles both raw text and pre-tokenized inputs + - Updates position IDs for 3D embeddings + """ + if not tokens: + return None + + if isinstance(tokens, str): + tokens_str = self.tokenizer.tokenize(tokens) + tokens = self.tokenizer.convert_tokens_to_ids(tokens_str) + + num_tokens = len(tokens) + outputs["input_ids"].extend(tokens) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens) + + position_ids = self._compute_text_positions(outputs["cur_position"], num_tokens) + outputs["position_ids"].append(position_ids) + outputs["cur_position"] = position_ids.max() + 1 + + def _compute_text_positions(self, start_pos: int, num_tokens: int) -> np.ndarray: + """ + Generate 3D positional embeddings for text tokens. 
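+
+        Illustrative example:
+
+            self._compute_text_positions(5, 3)
+            # -> array([[5, 6, 7], [5, 6, 7], [5, 6, 7]])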
+ + Args: + start_pos: Starting position index + num_tokens: Number of tokens to generate positions for + + Returns: + numpy.ndarray: 3D position IDs shaped (3, num_tokens) + """ + text_array = np.arange(num_tokens).reshape(1, -1) + text_index = np.broadcast_to(text_array, (3, num_tokens)) + position = text_index + start_pos + return position + + def _add_image(self, img, outputs: Dict) -> None: + """ + Add image data to model inputs dictionary. + + Args: + img: PIL Image to process + outputs: Dictionary accumulating model inputs + + Note: + - Preprocesses image and calculates spatial dimensions + - Adds image token IDs and type markers + - Generates appropriate position embeddings + """ + ret = self.image_processor.preprocess(images=[img.convert("RGB")]) + num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2 + grid_thw = ret["grid_thw"].tolist() + + outputs["input_ids"].extend([self.image_token_id] * num_tokens) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens) + + outputs["images"].append(ret["pixel_values"]) + outputs["grid_thw"].append(grid_thw) + outputs["image_type_ids"].append(0) + + t, h, w = grid_thw + position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0) + + outputs["position_ids"].append(position_ids) + outputs["cur_position"] = position_ids.max() + 1 + + def _add_video(self, frames, meta: Dict, outputs: Dict) -> None: + """ + Add video data to model inputs dictionary. + + Args: + frames: Video frames as numpy array + meta: Video metadata containing fps/duration + outputs: Dictionary accumulating model inputs + + Note: + - Handles temporal dimension in position embeddings + - Uses video-specific token IDs and type markers + """ + ret = self.image_processor.preprocess(images=frames) + + num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2 + grid_thw = ret["grid_thw"].tolist() + + outputs["input_ids"].extend([self.video_token_id] * num_tokens) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens) + + outputs["images"].append(ret["pixel_values"]) + outputs["grid_thw"].append(grid_thw) + outputs["image_type_ids"].extend([1] * grid_thw[0]) + + fps = meta["fps"] + second_per_grid_t = self.temporal_conv_size / fps + t, h, w = grid_thw + position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t) + + outputs["position_ids"].append(position_ids) + outputs["cur_position"] = position_ids.max() + 1 + + def _compute_vision_positions( + self, start_pos: int, t: int, h: int, w: int, second_per_grid_t: float + ) -> np.ndarray: + """ + Generate 3D position IDs for visual inputs. + + Args: + start_pos: Base position in sequence + t: Temporal patches (1 for images) + h: Height in patches + w: Width in patches + second_per_grid_t: Time per temporal patch + + Returns: + np.ndarray: Position IDs for [t,h,w] dimensions + """ + h //= self.spatial_conv_size + w //= self.spatial_conv_size + + tn = np.arange(t).reshape(-1, 1) + tn = np.broadcast_to(tn, (t, h * w)) + tn = tn * int(second_per_grid_t) * self.tokens_per_second + t_index = tn.flatten() + + hn = np.arange(h).reshape(1, -1, 1) + h_index = np.broadcast_to(hn, (t, h, w)).flatten() + + wn = np.arange(w).reshape(1, 1, -1) + w_index = np.broadcast_to(wn, (t, h, w)).flatten() + + position = np.stack([t_index, h_index, w_index]) + start_pos + return position + + def _load_and_process_video(self, url: str, item: Dict) -> Tuple[np.ndarray, Dict]: + """ + Load and preprocess video into frames. 
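+
+        Illustrative sketch (`video_bytes` is a placeholder):
+
+            frames, meta = self._load_and_process_video(video_bytes, {"fps": 2})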
+ + Args: + url: Video file path or bytes + item: Dictionary containing processing parameters + + Returns: + tuple: (frames, metadata) where: + - frames: Processed video frames as numpy array + - metadata: Updated video metadata dictionary + """ + frames, meta = read_frames(url) + + # Apply frame sampling if fps or target_frames specified + fps = item.get("fps", None) + num_frames = item.get("target_frames", None) + + if fps is not None or num_frames is not None: + # Get frame sampling constraints + min_frames = item.get("min_frames", self.min_frames) + max_frames = item.get("max_frames", self.max_frames) + + # Sample frames according to specifications + frames = sample_frames( + video=frames, + frame_factor=self.temporal_conv_size, # Ensure divisible by temporal patch size + min_frames=min_frames, + max_frames=max_frames, + metadata=meta, + fps=fps, + num_frames=num_frames, + ) + + # Update metadata with new frame count and fps + meta["num_of_frame"] = frames.shape[0] + if fps is not None: + meta["fps"] = fps # Use specified fps + meta["duration"] = frames.shape[0] / fps + else: + meta["fps"] = frames.shape[0] / meta["duration"] # Calculate fps from sampled frames + + return frames, meta + + def apply_chat_template(self, request): + """ + Apply chat template to convert messages into token sequence. + + Args: + request: Dictionary containing chat messages + + Returns: + List of token IDs + + Raises: + ValueError: If model doesn't support chat templates + """ + if self.tokenizer.chat_template is None: + raise ValueError("This model does not support chat_template.") + + raw_prompt = self.tokenizer.apply_chat_template( + request["messages"], + tokenize=False, + add_generation_prompt=request.get("add_generation_prompt", True), + ) + prompt_token_str = raw_prompt.replace(self.image_token, "").replace(self.video_token, "") + request["text_after_process"] = raw_prompt + + tokens = self.tokenizer.tokenize(prompt_token_str) + token_ids = self.tokenizer.convert_tokens_to_ids(tokens) + data_processor_logger.info( + f"req_id:{request.get('request_id', ''), } prompt: {raw_prompt} tokens: {tokens}, token_ids: {token_ids}" + ) + return token_ids diff --git a/fastdeploy/input/qwen_mm_processor/process_video.py b/fastdeploy/input/qwen_mm_processor/process_video.py new file mode 100644 index 000000000..808ffd76b --- /dev/null +++ b/fastdeploy/input/qwen_mm_processor/process_video.py @@ -0,0 +1,131 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import math +from typing import Optional, Union + +import numpy as np +from PIL import Image + +from fastdeploy.input.mm_processor import read_video_decord + + +def read_frames(video_path): + """ + Read and decode video frames from the given path + + This function reads a video file and decodes it into individual RGB frames + using decord video reader. It also extracts video metadata including fps, + duration and frame count. 
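+
+    Illustrative usage sketch (the path is a placeholder):
+
+        frames, meta = read_frames("demo.mp4")
+        assert frames.shape[0] == meta["num_of_frame"]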
+ + Args: + video_path (str): Path to the video file or bytes object containing video data + + Returns: + tuple: A tuple containing: + frames (numpy.ndarray): Array of shape (num_frames, height, width, 3) + containing decoded RGB video frames + meta (dict): Dictionary containing video metadata: + - fps (float): Frames per second + - duration (float): Video duration in seconds + - num_of_frame (int): Total number of frames + - width (int): Frame width in pixels + - height (int): Frame height in pixels + + Note: + - The function uses decord library for efficient video reading + - All frames are converted to RGB format regardless of input format + """ + reader, meta, _ = read_video_decord(video_path, save_to_disk=False) + + frames = [] + for i in range(meta["num_of_frame"]): + frame = reader[i].asnumpy() + image = Image.fromarray(frame, "RGB") + frames.append(image) + frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0) + return frames, meta + + +def sample_frames( + video: np.ndarray, + frame_factor: int, + min_frames: int, + max_frames: int, + metadata: Optional[dict] = None, + fps: Optional[Union[int, float]] = None, + num_frames: Optional[int] = None, +): + """ + Sample frames from video according to specified criteria. + + Args: + video: Input video frames as numpy array + frame_factor: Ensure sampled frames are multiples of this factor + min_frames: Minimum number of frames to sample + max_frames: Maximum number of frames to sample + metadata: Video metadata containing fps information + fps: Target frames per second for sampling + num_frames: Exact number of frames to sample + + Returns: + np.ndarray: Sampled video frames + + Raises: + ValueError: If both fps and num_frames are specified, + or if required metadata is missing, + or if requested frames exceed available frames + """ + if fps is not None and num_frames is not None: + raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!") + + if fps is None and num_frames is None: + return video + + total_num_frames = video.shape[0] + + # If num_frames is not given but fps is, calculate num_frames from fps + if num_frames is not None: + num_frames = round(num_frames / frame_factor) * frame_factor + elif fps is not None: + if metadata is None: + raise ValueError( + "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. " + "Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video" + ) + max_frames = math.floor(min(max_frames, total_num_frames) / frame_factor) * frame_factor + num_frames = total_num_frames / metadata["fps"] * fps + num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames) + num_frames = math.floor(num_frames / frame_factor) * frame_factor + + if num_frames > total_num_frames: + raise ValueError( + f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. " + "Decrease `num_frames` or `fps` for sampling." 
+ ) + + # Calculate frame indices based on sampling strategy + if num_frames is not None: + # Evenly spaced sampling for target frame count + indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(np.int32) + else: + # Keep all frames if no sampling requested + indices = np.arange(0, total_num_frames).astype(np.int32) + + # Apply frame selection + video = video[indices] + + return video diff --git a/fastdeploy/input/qwen_vl_processor.py b/fastdeploy/input/qwen_vl_processor.py new file mode 100644 index 000000000..8f6a8a9d7 --- /dev/null +++ b/fastdeploy/input/qwen_vl_processor.py @@ -0,0 +1,290 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import numpy as np + +from fastdeploy.engine.request import Request +from fastdeploy.input.qwen_mm_processor import DataProcessor +from fastdeploy.input.text_processor import DataProcessor as TextProcessor +from fastdeploy.utils import data_processor_logger + + +class QwenVLProcessor(TextProcessor): + """ + Qwen Vision-Language processor for handling multimodal inputs. + + This processor extends TextProcessor to support: + - Image and video processing + - Multimodal feature extraction + - Tokenization and position encoding + - Request processing and model input generation + + Attributes: + processor (DataProcessor): Underlying data processor instance + tokenizer: Text tokenizer instance + limit_mm_per_prompt (dict): Limits for multimodal inputs per prompt + """ + + def __init__( + self, + config, + model_name_or_path, + limit_mm_per_prompt=None, + mm_processor_kwargs=None, + reasoning_parser_obj=None, + tool_parser_obj=None, + ): + """ + Initialize QwenVLProcessor instance. + + Args: + config: Model configuration object + model_name_or_path (str): Pretrained model name or path + limit_mm_per_prompt (dict, optional): Limits for multimodal inputs + mm_processor_kwargs (dict, optional): Multimodal processor arguments + reasoning_parser_obj: Reasoning parser instance + tool_parser_obj: Tool parser instance + """ + super().__init__(model_name_or_path, reasoning_parser_obj, tool_parser_obj) + + data_processor_logger.info(f"model_name_or_path: {model_name_or_path}") + processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs) + self.processor = DataProcessor( + model_path=model_name_or_path, + tokens_per_second=config.vision_config.tokens_per_second, + tokenizer=self.tokenizer, + **processor_kwargs, + ) + + self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt) + + def process_request(self, request, max_model_len=None, **kwargs): + """ + Process incoming request and generate model inputs. 
+
+        Args:
+            request: Input request object
+            max_model_len (int, optional): Maximum context length
+            **kwargs: Additional processing parameters
+
+        Returns:
+            Request: Processed request with model inputs
+        """
+        task = request.to_dict()
+        task["enable_thinking"] = kwargs.get("enable_thinking", False)
+        self.process_request_dict(task, max_model_len)
+        request = Request.from_dict(task)
+        request = self._apply_default_parameters(request)
+        return request
+
+    def _parse_processor_kwargs(self, kwargs):
+        """
+        Parse and validate multimodal processor arguments.
+
+        Args:
+            kwargs (dict): Processor configuration arguments
+
+        Returns:
+            dict: Validated processor arguments
+
+        Raises:
+            ValueError: If arguments format is invalid
+        """
+        if not kwargs:
+            return {}
+
+        try:
+            if not isinstance(kwargs, dict):
+                raise ValueError("mm-processor-kwargs must be a dictionary")
+
+            # Validate kwargs types against expected schema
+            data_processor_logger.info(f"Processing kwargs: {kwargs}")
+            expected_types = {
+                "video_max_frames": int,  # Maximum video frames parameter
+                "video_min_frames": int,  # Minimum video frames parameter
+            }
+
+            for key, value in kwargs.items():
+                if key in expected_types and not isinstance(value, expected_types[key]):
+                    raise ValueError(
+                        f"Invalid type for {key}: expected {expected_types[key].__name__}, got {type(value).__name__}"
+                    )
+
+            return kwargs
+
+        except Exception as e:
+            data_processor_logger.warning(f"Invalid mm-processor-kwargs format: {e}")
+            return {}
+
+    def _parse_limits(self, limits):
+        """
+        Parse and validate multimodal input limits.
+
+        Args:
+            limits (dict): Input limits configuration
+
+        Returns:
+            dict: Validated limits with defaults
+
+        Raises:
+            ValueError: If limits format is invalid
+        """
+        DEFAULT_LIMITS = {"image": 1, "video": 1, "audio": 1}
+
+        if not limits:
+            return DEFAULT_LIMITS
+
+        try:
+            if not isinstance(limits, dict):
+                raise ValueError("limit-mm-per-prompt must be a dictionary")
+            data_processor_logger.info(f"_parse_limits:{limits}")
+            return {**DEFAULT_LIMITS, **limits}
+        except Exception as e:
+            data_processor_logger.warning(f"Invalid limit-mm-per-prompt format: {e}, using default limits")
+            return DEFAULT_LIMITS
+
+    def _check_mm_limits(self, item):
+        """
+        Validate multimodal inputs against configured limits.
+
+        Args:
+            item: Input request item to validate
+
+        Raises:
+            ValueError: If input exceeds configured limits
+        """
+        if isinstance(item, dict):
+            # The request carries a prompt plus multi_modal_data
+            mm_data = item
+        else:
+            # The request carries messages
+            mm_data = {"image": [], "video": []}
+
+            for message in item:
+                if isinstance(message.get("content"), list):
+                    for part in message["content"]:
+                        if part.get("type") in ["image_url", "image"]:
+                            mm_data["image"].append(part)
+                        elif part.get("type") in ["video_url", "video"]:
+                            mm_data["video"].append(part)
+
+        for modality, data in mm_data.items():
+            if modality in self.limit_mm_per_prompt:
+                limit = self.limit_mm_per_prompt[modality]
+                if len(data) > limit:
+                    raise ValueError(f"Too many {modality} items in prompt, " f"got {len(data)} but limit is {limit}")
+
+    def process_request_dict(self, request, max_model_len=None):
+        """
+        Process request dictionary into model inputs.
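+
+        Illustrative request layout (mirrors the unit tests; URLs are placeholders):
+
+            request = {
+                "request_id": "demo",
+                "messages": [{"role": "user", "content": [
+                    {"type": "image_url", "image_url": {"url": "file://demo.jpeg"}},
+                    {"type": "text", "text": "Describe the image."},
+                ]}],
+            }
+            processor.process_request_dict(request, max_model_len=8192)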
+ + Args: + request (dict): Input request dictionary + max_model_len (int, optional): Maximum context length + + Returns: + dict: Processed request with model inputs + + Raises: + ValueError: If request format is invalid + """ + + request = self._apply_default_parameters(request) + if not request.get("eos_token_ids"): + request["eos_token_ids"] = self.eos_token_ids + + stop_sequences = request.get("stop", []) + if stop_sequences: + stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences) + request["stop_token_ids"] = stop_seqs + request["stop_seqs_len"] = stop_seqs_len + + if request.get("prompt"): + multimodal_data = request.get("multimodal_data") + if multimodal_data is None: + multimodal_data = {} + self._check_mm_limits(multimodal_data) + images = multimodal_data.get("image", None) + videos = multimodal_data.get("video", None) + outputs = self.processor.text2ids(request["prompt"], images, videos) + + elif request.get("messages"): + messages = request["messages"] + self._check_mm_limits(messages) + outputs = self.processor.request2ids(request) + + else: + raise ValueError(f"Request must contain 'prompt', or 'messages': {request}") + + metadata = request.get("metadata") + # Handle continuation of previous generation by appending existing tokens + if metadata and metadata.get("generated_token_ids"): + self.append_generated_tokens(outputs, metadata["generated_token_ids"]) + outputs = self.pack_outputs(outputs) + + request["prompt_token_ids"] = outputs["input_ids"].tolist() + request["prompt_token_ids_len"] = len(request["prompt_token_ids"]) + request["multimodal_inputs"] = outputs + + # Handle prompt truncation if exceeds model context length + if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len: + request["prompt_token_ids"] = request["prompt_token_ids"][ + : max_model_len - 1 + ] # Leave space for at least 1 new token + + # Set default max_tokens if not specified + if request.get("max_tokens") is None: + request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"])) # Ensure at least 1 token + data_processor_logger.info(f"Processed request {request}") + + return request + + def append_generated_tokens(self, outputs, generated_token_ids): + """ + Append generated tokens to existing outputs. + + Args: + outputs: Current model outputs + generated_token_ids: Generated tokens to append + """ + out = {"input_ids": [], "token_type_ids": [], "position_ids": [], "cur_position": outputs["cur_position"]} + self.processor._add_text(generated_token_ids, out) + + outputs["input_ids"] = np.concatenate( + [outputs["input_ids"], np.array(out["input_ids"], dtype=np.int64)], axis=0 + ) + outputs["token_type_ids"] = np.concatenate( + [outputs["token_type_ids"], np.array(out["token_type_ids"], dtype=np.int64)], axis=0 + ) + outputs["position_ids"] = np.concatenate( + [outputs["position_ids"], out["position_ids"][0]], axis=1, dtype=np.int64 + ) + outputs["cur_position"] = out["cur_position"] + + def pack_outputs(self, outputs): + """ + Prepare final output dictionary for model. 
+ + Args: + outputs: Intermediate processing outputs + + Returns: + dict: Packed output dictionary with all required fields + """ + outputs["image_patch_id"] = self.processor.image_token_id + outputs["video_patch_id"] = self.processor.video_token_id + outputs["position_ids"] = outputs["position_ids"].transpose(1, 0) + return outputs diff --git a/tests/input/test_qwen_vl_processor.py b/tests/input/test_qwen_vl_processor.py new file mode 100644 index 000000000..6a3939245 --- /dev/null +++ b/tests/input/test_qwen_vl_processor.py @@ -0,0 +1,248 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np +from PIL import Image + +from fastdeploy.engine.request import Request +from fastdeploy.input.qwen_vl_processor import QwenVLProcessor + + +def mock_pil_image(height, width): + """ + Generate mock random RGB image + + Args: + height: Image height in pixels + width: Image width in pixels + + Returns: + PIL.Image object with random RGB data + """ + rgb_image = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8) + return Image.fromarray(rgb_image) + + +def mock_read_frames(height: int, width: int, nums_frame: int, fps: int): + """ + Generate mock video frames with metadata for testing purposes + + Creates synthetic video data by generating random RGB frames and constructing + corresponding metadata to simulate real video processing. 
+ + Args: + height (int): Height of video frames in pixels + width (int): Width of video frames in pixels + nums_frame (int): Number of frames to generate + fps (int): Frames per second for the mock video + + Returns: + tuple: A tuple containing: + frames (numpy.ndarray): Array of shape (nums_frame, height, width, 3) + containing randomly generated RGB frames + meta (dict): Dictionary with video metadata: + - fps (int): Frames per second (same as input) + - duration (float): Calculated duration in seconds (nums_frame/fps) + - num_of_frame (int): Number of frames (same as nums_frame input) + """ + frames = [] + for _ in range(nums_frame): + frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8) + frames.append(frame) + frames = np.stack(frames, axis=0) + + meta = { + "fps": fps, + "duration": nums_frame / fps, + "num_of_frame": nums_frame, + } + return frames, meta + + +class TestQwenVLProcessor(unittest.TestCase): + """ + Unit tests for Qwen Vision-Language Processor functionality + """ + + def setUp(self): + """ + Initialize test case with: + - Mock configuration + - Patched message parsing and video processing methods + - QwenVLProcessor instance with test parameters + """ + config = MagicMock() + config.vision_config.tokens_per_second = 2 + + self.patcher_parse_image = patch( + "fastdeploy.entrypoints.chat_utils.MultiModalPartParser.parse_image", return_value=mock_pil_image(480, 640) + ) + self.patcher_parse_image.start() + + self.patcher_parse_video = patch( + "fastdeploy.entrypoints.chat_utils.MultiModalPartParser.parse_video", return_value=b"123" + ) + self.patcher_parse_video.start() + + self.patcher_read_frames = patch( + "fastdeploy.input.qwen_mm_processor.process.read_frames", return_value=mock_read_frames(480, 640, 5, 2) + ) + self.patcher_read_frames.start() + + mm_processor_kwargs = { + "video_max_frames": 10, + "video_min_frames": 1, + } + limit_mm_per_prompt = {"image": 1, "video": 1, "audio": 1} + + model_name_or_path = "/ModelData/Qwen2.5-VL-7B-Instruct" + self.processor = QwenVLProcessor( + config=config, + model_name_or_path=model_name_or_path, + limit_mm_per_prompt=limit_mm_per_prompt, + mm_processor_kwargs=mm_processor_kwargs, + reasoning_parser_obj=None, + tool_parser_obj=None, + ) + + def tearDown(self) -> None: + """Clean up test case by stopping all mock patches""" + self.patcher_read_frames.stop() + self.patcher_parse_image.stop() + self.patcher_parse_video.stop() + + def test_process_request(self): + """ + Test processing of Request object with multimodal input + + Validates: + 1. Token ID lengths match position_ids and token_type_ids shapes + 2. Image processing produces expected output dimensions + 3. Video processing produces expected output dimensions + 4. 
Correct counts for images (1) and videos (1) + """ + prompt = { + "request_id": "12345", + "messages": [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "file://demo.jpeg"}}, + {"type": "video_url", "video_url": {"url": "file://3_frame_video.mp4"}}, + {"type": "text", "text": "Describe image and video."}, + ], + } + ], + } + + request = Request.from_dict(prompt) + result = self.processor.process_request(request, 1024 * 100) + + self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0]) + self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0]) + self.assertEqual( + result.multimodal_inputs["images"].shape[0], + sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])), + ) + self.assertEqual( + result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum() + ) + self.assertEqual(result.multimodal_inputs["pic_cnt"], 1) + self.assertEqual(result.multimodal_inputs["video_cnt"], 1) + + def test_process_request_dict(self): + """ + Test processing of dictionary-format request with multimodal input + + Validates: + 1. Token ID lengths match position_ids and token_type_ids shapes + 2. Image processing produces expected output dimensions + 3. Video processing produces expected output dimensions + 4. Correct counts for images (1) and videos (1) + """ + num_generated_token_ids = 10 + request = { + "request_id": "12345", + "metadata": { + "generated_token_ids": [1] * num_generated_token_ids, + }, + "stop": ["stop", "eof"], + "messages": [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "file://demo.jpeg"}}, + {"type": "video_url", "video_url": {"url": "file://3_frame_video.mp4"}}, + {"type": "text", "text": "Describe image and video."}, + ], + } + ], + } + + result = self.processor.process_request_dict(request, 1024 * 100) + + self.assertEqual(result["prompt_token_ids_len"], result["multimodal_inputs"]["position_ids"].shape[0]) + self.assertEqual(result["prompt_token_ids_len"], result["multimodal_inputs"]["token_type_ids"].shape[0]) + self.assertEqual( + result["multimodal_inputs"]["images"].shape[0], + sum(map(lambda x: x.prod(), result["multimodal_inputs"]["grid_thw"])), + ) + self.assertEqual( + result["multimodal_inputs"]["image_type_ids"].shape[0], result["multimodal_inputs"]["grid_thw"][:, 0].sum() + ) + self.assertEqual(result["multimodal_inputs"]["pic_cnt"], 1) + self.assertEqual(result["multimodal_inputs"]["video_cnt"], 1) + + def test_prompt(self): + """ + Test processing of prompt with image and video placeholders + + Validates: + 1. Token ID lengths match position_ids and token_type_ids shapes + 2. Image processing produces expected output dimensions + 3. Video processing produces expected output dimensions + 4. 
Correct counts for images (1) and videos (1) + """ + prompt = { + "request_id": "12345", + "prompt": "<|image@placeholder|><|video@placeholder|>Describe image and video.", + "multimodal_data": { + "image": [mock_pil_image(10, 2100)], + "video": [{"video": b"123", "fps": 5}], + }, + } + + request = Request.from_dict(prompt) + result = self.processor.process_request(request, 1024 * 100) + + self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0]) + self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0]) + self.assertEqual( + result.multimodal_inputs["images"].shape[0], + sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])), + ) + self.assertEqual( + result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum() + ) + self.assertEqual(result.multimodal_inputs["pic_cnt"], 1) + self.assertEqual(result.multimodal_inputs["video_cnt"], 1) + + +if __name__ == "__main__": + unittest.main()