"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

""" process.py """

import copy
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import numpy as np
from paddleformers.transformers.image_utils import ChannelDimension
from PIL import Image

from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
from fastdeploy.utils import data_processor_logger

from .image_preprocessor.image_preprocessor_adaptive import AdaptiveImageProcessor
from .process_video import read_frames_decord, read_video_decord
from .utils.render_timestamp import render_frame_timestamp

IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3}


def fancy_print(input_ids, tokenizer, image_patch_id=None):
    """
    Render ``input_ids`` as readable text, collapsing image-patch runs.

    Each maximal run of ``image_patch_id`` tokens is replaced by a single
    ``<|IMAGE@N|>`` marker (N = run length); every other contiguous chunk of
    ids is decoded with ``tokenizer.decode``.

    Args:
        input_ids: Sequence of token ids.
        tokenizer: The model tokenizer; only ``decode`` is used.
        image_patch_id: Id of the image placeholder token. When ``None`` no
            id is treated as an image patch.

    Returns:
        str: The rendered text.
    """
    res = ""
    text_ids = []
    real_image_token_len = 0
    for token_id in input_ids:
        if token_id == image_patch_id:
            if text_ids:
                res += tokenizer.decode(text_ids)
                text_ids = []
            real_image_token_len += 1
        else:
            if real_image_token_len != 0:
                res += f"<|IMAGE@{real_image_token_len}|>"
                real_image_token_len = 0
            text_ids.append(token_id)
    # Flush whichever run is still open at the end. A trailing run of image
    # patches used to be dropped silently.
    if text_ids:
        res += tokenizer.decode(text_ids)
    if real_image_token_len != 0:
        res += f"<|IMAGE@{real_image_token_len}|>"
    return res


class DataProcessor:
    """
    Processes multimodal chat messages into model-ready inputs,
    handling text, images, and videos with 3D positional embeddings.
    """

    CLS_TOKEN = "<|begin_of_sentence|>"
    SEP_TOKEN = "<|end_of_sentence|>"
    # Written as the label at every SEP position (see _extract_labels).
    # NOTE(review): assumed to be the tokenizer's end-of-sequence token; confirm against the vocab.
    EOS_TOKEN = "</s>"
    IMG_START = "<|IMAGE_START|>"
    IMG_END = "<|IMAGE_END|>"
    VID_START = "<|VIDEO_START|>"
    VID_END = "<|VIDEO_END|>"

    def __init__(
        self,
        tokenizer_name: str,
        image_preprocessor_name: str,
        spatial_conv_size: int = 2,
        temporal_conv_size: int = 2,
        image_min_pixels: int = 4 * 28 * 28,
        image_max_pixels: int = 6177 * 28 * 28,
        video_min_pixels: int = 299 * 28 * 28,
        video_max_pixels: int = 1196 * 28 * 28,
        video_target_frames: int = -1,
        video_frames_sample: str = "leading",
        video_max_frames: int = 180,
        video_min_frames: int = 16,
        video_fps: int = 2,
        **kwargs,
    ) -> None:
        """
        Args:
            tokenizer_name: Path/name loaded with ``ErnieBotTokenizer.from_pretrained``.
            image_preprocessor_name: Path/name loaded with
                ``AdaptiveImageProcessor.from_pretrained``.
            spatial_conv_size: Spatial merge factor; one token covers
                ``spatial_conv_size ** 2`` patches.
            temporal_conv_size: Temporal merge factor applied to video frames.
            image_min_pixels, image_max_pixels: Pixel budget for image resizing.
            video_min_pixels, video_max_pixels: Pixel budget for video frame resizing.
            video_target_frames: Fixed number of frames to sample (<= 0 disables).
            video_frames_sample: Sampling strategy forwarded to the frame reader.
            video_max_frames, video_min_frames: Bounds on the sampled frame
                count (<= 0 disables the bound).
            video_fps: Sampling rate used when no target frame count is set.
            **kwargs: Accepted and ignored.
        """
        # Tokenizer and image preprocessor
        self.model_name_or_path = tokenizer_name
        self._load_tokenizer()
        self.tokenizer.ignored_index = -100
        self.image_preprocessor = AdaptiveImageProcessor.from_pretrained(image_preprocessor_name)

        # Convolution sizes for patch aggregation
        self.spatial_conv_size = spatial_conv_size
        self.temporal_conv_size = temporal_conv_size

        # Pixel constraints
        self.image_min_pixels = image_min_pixels
        self.image_max_pixels = image_max_pixels
        self.video_min_pixels = video_min_pixels
        self.video_max_pixels = video_max_pixels

        # Video sampling parameters
        self.target_frames = video_target_frames
        self.frames_sample = video_frames_sample
        self.max_frames = video_max_frames
        self.min_frames = video_min_frames
        self.fps = video_fps

        # Special tokens and IDs
        self.cls_token = self.CLS_TOKEN
        self.sep_token = self.SEP_TOKEN
        self.eos_token = self.EOS_TOKEN
        self.image_start = self.IMG_START
        self.image_end = self.IMG_END
        self.video_start = self.VID_START
        self.video_end = self.VID_END
        self.image_patch_id = self.tokenizer.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>")
        self.image_start_id = self.tokenizer.convert_tokens_to_ids(self.image_start)
        self.video_start_id = self.tokenizer.convert_tokens_to_ids(self.video_start)
        self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.sep_token)
        self.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.eos_token)

        # Must run after image_patch_id is set: the mapping uses it as a key.
        self.token_type_mapping = self._build_token_type_mapping()
        self.is_training = True
        self.role_prefixes = {
            "system": "",
            "user": "User: ",
            "bot": "Assistant: ",
            "assistant": "Assistant: ",
        }

    def _build_token_type_mapping(self) -> Dict[Any, int]:
        """
        Map tokens to their ``IDS_TYPE_FLAG`` type.

        Image/video boundary tokens (by string) and the image patch id (by
        int) map to the image type; every other key defaults to text.
        """
        mapping = defaultdict(lambda: IDS_TYPE_FLAG["text"])
        for token in (
            self.IMG_START,
            self.IMG_END,
            self.VID_START,
            self.VID_END,
        ):
            mapping[token] = IDS_TYPE_FLAG["image"]
        mapping[self.image_patch_id] = IDS_TYPE_FLAG["image"]
        return mapping

    def train(self) -> None:
        """Enable training mode (produces labels)."""
        self.is_training = True

    def eval(self) -> None:
        """Enable evaluation mode (doesn't produce labels)."""
        self.is_training = False

    def text2ids(self, text, images=None, videos=None):
        """
        Convert a raw prompt string containing media placeholders into model inputs.

        Occurrences of ``<|image@placeholder|>`` and ``<|video@placeholder|>``
        in ``text`` are consumed in order from ``images`` and ``videos``. A
        video item is either the video source itself, or a dict whose
        ``"video"`` entry is the source and whose other keys may override the
        frame-sampling settings.

        Returns a dict with input_ids, token_type_ids, position_ids, images,
        grid_thw, image_type_ids, labels, cur_position, pic_cnt and video_cnt.
        """
        outputs = {
            "input_ids": [],
            "token_type_ids": [],
            "position_ids": [],
            "images": [],
            "grid_thw": [],
            "image_type_ids": [],
            "labels": [],
            "cur_position": 0,
            "pic_cnt": 0,
            "video_cnt": 0,
        }

        IMAGE_PLACEHOLDER = "<|image@placeholder|>"
        VIDEO_PLACEHOLDER = "<|video@placeholder|>"
        IMAGE_PLACEHOLDER_LEN = len(IMAGE_PLACEHOLDER)
        VIDEO_PLACEHOLDER_LEN = len(VIDEO_PLACEHOLDER)
        st, image_idx, video_idx = 0, 0, 0
        while st < len(text):
            image_pos = text.find(IMAGE_PLACEHOLDER, st)
            image_pos = len(text) if image_pos == -1 else image_pos
            video_pos = text.find(VIDEO_PLACEHOLDER, st)
            video_pos = len(text) if video_pos == -1 else video_pos
            ed = min(image_pos, video_pos)

            self._add_text(text[st:ed], outputs)
            if ed == len(text):
                break

            if ed == image_pos:
                # Keep the media counters in sync, as request2ids does.
                outputs["pic_cnt"] += 1
                self._add_image(images[image_idx], outputs)
                image_idx += 1
                st = ed + IMAGE_PLACEHOLDER_LEN
            else:
                item = videos[video_idx]
                if isinstance(item, dict):
                    frames = self._load_and_process_video(item["video"], item)
                else:
                    frames = self._load_and_process_video(item, {})

                outputs["video_cnt"] += 1
                self._add_video(frames, outputs)
                video_idx += 1
                st = ed + VIDEO_PLACEHOLDER_LEN

        return outputs

    def request2ids(
        self, request: Dict[str, Any], tgts: Optional[List[str]] = None
    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
        """
        Convert chat messages into model inputs.

        The prompt is rendered with ``apply_chat_template``. Each image/video
        start token in the rendered ids is paired, in order, with the next
        ``image``/``video`` content item from ``request["messages"]``.

        Args:
            request: Request dict with a ``"messages"`` list. It is mutated:
                ``apply_chat_template`` stores ``"text_after_process"`` in it.
            tgts: Target strings, one per SEP token; required in training mode.

        Returns a dict with input_ids, token_type_ids, position_ids, images,
        grid_thw, image_type_ids, labels, cur_position, pic_cnt and video_cnt.
        """
        outputs = {
            "input_ids": [],
            "token_type_ids": [],
            "position_ids": [],
            "images": [],
            "grid_thw": [],
            "image_type_ids": [],
            "labels": [],
            "cur_position": 0,
            "pic_cnt": 0,
            "video_cnt": 0,
        }

        messages = parse_chat_messages(request.get("messages"))
        image_message_list = []
        for msg in messages:
            role = msg.get("role")
            assert role in self.role_prefixes, f"Unsupported role: {role}"
            content_items = msg.get("content")
            if not isinstance(content_items, list):
                content_items = [content_items]
            for item in content_items:
                if isinstance(item, dict) and item.get("type") in [
                    "image",
                    "video",
                ]:
                    image_message_list.append(item)

        prompt_token_ids = self.apply_chat_template(request)
        if len(prompt_token_ids) == 0:
            raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")

        image_start_index = 0
        image_message_index = 0
        for i, token_id in enumerate(prompt_token_ids):
            if token_id not in (self.image_start_id, self.video_start_id):
                continue
            # Emit the text up to and including the media start token.
            self._add_text(prompt_token_ids[image_start_index : i + 1], outputs)
            image_start_index = i + 1
            image_message = image_message_list[image_message_index]
            # Advance before the early `continue`s below. Otherwise a message
            # with no payload shifts every later media item onto the wrong token.
            image_message_index += 1
            if image_message["type"] == "image":
                img = image_message.get("image")
                if img is None:
                    continue
                outputs["pic_cnt"] += 1
                self._add_image(img, outputs)
            elif image_message["type"] == "video":
                video_bytes = image_message.get("video")
                if video_bytes is None:
                    continue
                frames = self._load_and_process_video(video_bytes, image_message)
                outputs["video_cnt"] += 1
                self._add_video(frames, outputs)

        self._add_text(prompt_token_ids[image_start_index:], outputs)

        if self.is_training:
            assert tgts, "training must give tgt !"
            self._extract_labels(outputs, tgts)
        return outputs

    def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
        """Append a single special token (by string or id) at the next 1D position."""
        token_id = token if isinstance(token, int) else self.tokenizer.convert_tokens_to_ids(token)
        outputs["input_ids"].append(token_id)
        outputs["token_type_ids"].append(self.token_type_mapping[token])
        pos = outputs["cur_position"]
        outputs["position_ids"].append([pos] * 3)
        outputs["cur_position"] += 1

    def _add_text(self, tokens, outputs: Dict) -> None:
        """
        Append text tokens, each with an identical [p, p, p] position.

        ``tokens`` is either a string (encoded without special tokens) or an
        already-tokenized list of ids.
        """
        if isinstance(tokens, str):
            tokens = self.tokenizer.encode(tokens, add_special_tokens=False)["input_ids"]
        outputs["input_ids"].extend(tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * len(tokens))

        start = outputs["cur_position"]
        for i in range(len(tokens)):
            outputs["position_ids"].append([start + i] * 3)
        outputs["cur_position"] += len(tokens)

    def _add_image(self, img, outputs: Dict) -> None:
        """
        Append one image's placeholder tokens, 3D positions and pixel data.

        ``img`` must provide ``height``, ``width`` and ``convert("RGB")``
        (for example a ``PIL.Image.Image``).
        """
        # The second element of the resize result is the (h, w) patch grid.
        patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
            img.height,
            img.width,
            min_pixels=self.image_min_pixels,
            max_pixels=self.image_max_pixels,
        )[1]
        num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)

        outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)

        pos_ids = self._compute_3d_positions(1, patches_h, patches_w, outputs["cur_position"])
        outputs["position_ids"].extend(pos_ids)
        # Keep cur_position a plain int rather than a NumPy scalar.
        outputs["cur_position"] = int(np.max(pos_ids)) + 1

        # Preprocess pixels
        ret = self.image_preprocessor.preprocess(
            images=[img.convert("RGB")],
            do_normalize=False,
            do_rescale=False,
            predetermined_grid_thw=np.array([[patches_h, patches_w]]),
            do_convert_rgb=True,
            input_data_format=ChannelDimension.LAST,
        )
        outputs["images"].append(ret["pixel_values"])
        outputs["grid_thw"].append(ret["image_grid_thw"])
        outputs["image_type_ids"].append(0)

    def _add_video(self, frames, outputs: Dict) -> None:
        """
        Append one video's placeholder tokens, 3D positions and pixel data.

        All frames are resized to the patch grid computed from the first
        frame. ``frames`` must be non-empty; its length is assumed to be a
        multiple of ``temporal_conv_size`` (as ``_load_and_process_video``
        guarantees).
        """
        patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
            frames[0].height,
            frames[0].width,
            min_pixels=self.video_min_pixels,
            max_pixels=self.video_max_pixels,
        )[1]
        num_frames = len(frames)
        num_tokens = (num_frames * patches_h * patches_w) // (self.spatial_conv_size**2 * self.temporal_conv_size)

        pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
        ret = self.image_preprocessor.preprocess(
            images=None,
            videos=pixel_stack,
            do_normalize=False,
            do_rescale=False,
            predetermined_grid_thw=np.array([[patches_h, patches_w]] * num_frames),
            do_convert_rgb=True,
            input_data_format=ChannelDimension.LAST,
        )
        outputs["images"].append(ret["pixel_values_videos"])
        outputs["grid_thw"].append(ret["video_grid_thw"])
        outputs["image_type_ids"].extend([1] * num_frames)

        outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)

        pos_ids = self._compute_3d_positions(num_frames, patches_h, patches_w, outputs["cur_position"])
        outputs["position_ids"].extend(pos_ids)
        # Keep cur_position a plain int rather than a NumPy scalar.
        outputs["cur_position"] = int(np.max(pos_ids)) + 1

    def _extract_labels(self, outputs: Dict, tgts: List[str]) -> None:
        """
        Build ``outputs["labels"]`` for training.

        Every position is ``tokenizer.ignored_index`` except the span before
        each SEP token: the span holds the tokenized k-th target, and the SEP
        slot itself holds ``eos_token_id``.

        Raises:
            AssertionError: if the number of SEP tokens differs from ``len(tgts)``.
        """
        # Only read, so no copy is needed.
        input_ids = outputs["input_ids"]
        labels = [self.tokenizer.ignored_index] * len(input_ids)

        tgt_count = input_ids.count(self.sep_token_id)
        assert tgt_count == len(tgts), f"len(tgts) != len(src) {len(tgts)} vs {tgt_count}"

        tgt_index = 0
        for i, token_id in enumerate(input_ids):
            if token_id == self.sep_token_id:
                labels_token = self.tokenizer.tokenize(tgts[tgt_index])
                labels_token_id = self.tokenizer.convert_tokens_to_ids(labels_token)
                n = len(labels_token_id)
                # Clamp so that a target longer than the preceding sequence
                # cannot produce a negative slice start (which would wrap
                # around and change the length of `labels`).
                start = max(i - n, 0)
                labels[start:i] = labels_token_id[n - (i - start) :]
                labels[i] = self.eos_token_id
                # Move to the next target for the next SEP token.
                tgt_index += 1

        outputs["labels"] = labels

    def _load_and_process_video(self, url: str, item: Dict) -> List[Image.Image]:
        """
        Decode a video, sample frames and stamp each frame with its timestamp.

        Sampling settings in ``item`` (fps, min_frames, max_frames,
        target_frames, frames_sample) override the processor defaults. The
        frame list is padded by repeating the last frame until its length is
        a multiple of ``temporal_conv_size``.
        """
        reader, meta, path = read_video_decord(url, save_to_disk=False)

        video_frame_args = dict()
        video_frame_args["fps"] = item.get("fps", self.fps)
        video_frame_args["min_frames"] = item.get("min_frames", self.min_frames)
        video_frame_args["max_frames"] = item.get("max_frames", self.max_frames)
        video_frame_args["target_frames"] = item.get("target_frames", self.target_frames)
        video_frame_args["frames_sample"] = item.get("frames_sample", self.frames_sample)

        video_frame_args = self._set_video_frame_args(video_frame_args, meta)

        frames_data, _, timestamps = read_frames_decord(
            path,
            reader,
            meta,
            target_frames=video_frame_args["target_frames"],
            target_fps=video_frame_args["fps"],
            frames_sample=video_frame_args["frames_sample"],
            save_to_disk=False,
        )

        frames: List[Image.Image] = []
        for img_array, ts in zip(frames_data, timestamps):
            frames.append(render_frame_timestamp(img_array, ts))

        # Pad to a multiple of temporal_conv_size so frames merge evenly in
        # time. With the default temporal_conv_size == 2 this is the old
        # "make it even" rule.
        remainder = len(frames) % self.temporal_conv_size
        if frames and remainder:
            frames.extend(copy.deepcopy(frames[-1]) for _ in range(self.temporal_conv_size - remainder))
        return frames

    def _set_video_frame_args(self, video_frame_args, video_meta):
        """
        Resolve the final frame-sampling arguments from the given settings.

        Priority: target_frames > (min_frames, max_frames) > fps. When fps
        sampling would fall outside [min_frames, max_frames], the violated
        bound becomes ``target_frames`` and fps is disabled (set to -1).

        Raises:
            ValueError: on contradictory settings.
        """
        if video_frame_args["target_frames"] > 0:
            if video_frame_args["fps"] >= 0:
                raise ValueError("fps must be negative if target_frames is given")
            if (
                video_frame_args["min_frames"] > 0
                and video_frame_args["target_frames"] < video_frame_args["min_frames"]
            ):
                raise ValueError("target_frames must be larger than min_frames")
            if (
                video_frame_args["max_frames"] > 0
                and video_frame_args["target_frames"] > video_frame_args["max_frames"]
            ):
                raise ValueError("target_frames must be smaller than max_frames")
        else:
            if video_frame_args["fps"] < 0:
                raise ValueError("Must provide either positive target_fps or positive target_frames.")
            # Number of frames that sampling at the requested fps would yield.
            frames_to_extract = int(video_meta["duration"] * video_frame_args["fps"])
            # If that count is outside [min_frames, max_frames], switch to a
            # fixed target_frames equal to the violated bound.
            if (
                video_frame_args["min_frames"] > 0
                and video_frame_args["max_frames"] > 0
                and video_frame_args["min_frames"] > video_frame_args["max_frames"]
            ):
                raise ValueError("min_frames must be smaller than max_frames")
            if video_frame_args["min_frames"] > 0 and frames_to_extract < video_frame_args["min_frames"]:
                video_frame_args["target_frames"] = video_frame_args["min_frames"]
                video_frame_args["fps"] = -1
            if video_frame_args["max_frames"] > 0 and frames_to_extract > video_frame_args["max_frames"]:
                video_frame_args["target_frames"] = video_frame_args["max_frames"]
                video_frame_args["fps"] = -1

        return video_frame_args

    def _compute_3d_positions(self, t: int, h: int, w: int, start_idx: int) -> List[List[int]]:
        """
        Return one ``[time, row, col]`` position per merged token, offset by ``start_idx``.

        Time is downsampled by ``temporal_conv_size`` (a single frame stays a
        single step). Height and width are downsampled by
        ``spatial_conv_size``. Ordering is time-major, then row, then column.
        """
        # Downsample time if needed
        t_eff = t // self.temporal_conv_size if t != 1 else 1
        gh, gw = h // self.spatial_conv_size, w // self.spatial_conv_size
        time_idx = np.repeat(np.arange(t_eff), gh * gw)
        h_idx = np.tile(np.repeat(np.arange(gh), gw), t_eff)
        w_idx = np.tile(np.arange(gw), t_eff * gh)

        coords = list(zip(time_idx, h_idx, w_idx))
        return [[start_idx + ti, start_idx + hi, start_idx + wi] for ti, hi, wi in coords]

    def _load_tokenizer(self):
        """
        Load the tokenizer into ``self.tokenizer``.

        The first vocab file name found in ``self.model_name_or_path`` is
        registered on ``ErnieBotTokenizer.resource_files_names``. Note that
        this mutates a class-level attribute shared by every instance.
        """
        vocab_file_names = [
            "tokenizer.model",
            "spm.model",
            "ernie_token_100k.model",
        ]
        for vocab_file_name in vocab_file_names:
            if os.path.exists(os.path.join(self.model_name_or_path, vocab_file_name)):
                ErnieBotTokenizer.resource_files_names["vocab_file"] = vocab_file_name
                break
        self.tokenizer = ErnieBotTokenizer.from_pretrained(self.model_name_or_path)

    def apply_chat_template(self, request):
        """
        Render the request with the tokenizer's chat template and tokenize it.

        Media placeholders (``<|image@placeholder|>``,
        ``<|video@placeholder|>``) are removed before tokenization. The
        rendered template (placeholders included) is stored in
        ``request["text_after_process"]``.

        Args:
            request: Request dict passed to ``tokenizer.apply_chat_template``;
                its optional ``add_generation_prompt`` (default True) and
                ``request_id`` keys are read.

        Returns:
            List of token ids as returned by ``tokenizer.convert_tokens_to_ids``.

        Raises:
            ValueError: if the tokenizer has no chat template.
        """
        if self.tokenizer.chat_template is None:
            raise ValueError("This model does not support chat_template.")

        prompt_token_template = self.tokenizer.apply_chat_template(
            request,
            tokenize=False,
            add_generation_prompt=request.get("add_generation_prompt", True),
        )
        prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace(
            "<|video@placeholder|>", ""
        )
        request["text_after_process"] = prompt_token_template
        tokens = self.tokenizer.tokenize(prompt_token_str)
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        # The old f-string `{expr, }` formatted a 1-tuple instead of the id.
        data_processor_logger.info(
            f"req_id:{request.get('request_id', '')}, tokens: {tokens}, token_ids: {token_ids}"
        )
        return token_ids