| """
 | ||
| # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 | ||
| #
 | ||
| # Licensed under the Apache License, Version 2.0 (the "License");
 | ||
| # you may not use this file except in compliance with the License.
 | ||
| # You may obtain a copy of the License at
 | ||
| #
 | ||
| #     http://www.apache.org/licenses/LICENSE-2.0
 | ||
| #
 | ||
| # Unless required by applicable law or agreed to in writing, software
 | ||
| # distributed under the License is distributed on an "AS IS" BASIS,
 | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||
| 
 | ||
| # See the License for the specific language governing permissions and
 | ||
| # limitations under the License.
 | ||
| """
 | ||
| 
 | ||
| """ process.py """
 | ||
import copy
import os
from collections import defaultdict
from typing import Any, Dict, List, Union

import numpy as np
from paddleformers.transformers.image_utils import ChannelDimension
from PIL import Image

from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.input.utils import IDS_TYPE_FLAG
from fastdeploy.utils import data_processor_logger

from .image_preprocessor.image_preprocessor_adaptive import AdaptiveImageProcessor
from .process_video import read_frames_decord, read_video_decord
from .utils.render_timestamp import render_frame_timestamp


def fancy_print(input_ids, tokenizer, image_patch_id=None):
    """
    Decode input_ids to text, collapsing each run of image patch tokens
    into a single "<|IMAGE@N|>" marker, where N is the run length.

    Args:
        input_ids: sequence of token IDs to render.
        tokenizer: tokenizer used to decode the text tokens.
        image_patch_id: token ID that marks an image patch.
    """
    i = 0
    res = ""
    text_ids = []
    real_image_token_len = 0
    while i < len(input_ids):
        if input_ids[i] == image_patch_id:
            if len(text_ids) > 0:
                res += tokenizer.decode(text_ids)
                text_ids = []
            real_image_token_len += 1
        else:
            if real_image_token_len != 0:
                res += f"<|IMAGE@{real_image_token_len}|>"
                real_image_token_len = 0
            text_ids.append(input_ids[i])
        i += 1
    if real_image_token_len != 0:
        # Flush a trailing run of image tokens; otherwise it would be dropped
        res += f"<|IMAGE@{real_image_token_len}|>"
    if len(text_ids) > 0:
        res += tokenizer.decode(text_ids)
    return res


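# Usage sketch (hypothetical IDs, assuming a loaded tokenizer whose image
# patch token has ID 9):
#     fancy_print([1, 2, 9, 9, 9, 3], tok, image_patch_id=9)
# decodes tokens 1-2, emits "<|IMAGE@3|>" for the three patch tokens, then
# decodes token 3 -- handy for eyeballing how many patch tokens an image
# expanded into.

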
class DataProcessor:
    """
    Processes multimodal chat messages into model-ready inputs,
    handling text, images, and videos with 3D positional embeddings.
    """

    CLS_TOKEN = "<|begin_of_sentence|>"
    SEP_TOKEN = "<|end_of_sentence|>"
    EOS_TOKEN = "</s>"
    IMG_START = "<|IMAGE_START|>"
    IMG_END = "<|IMAGE_END|>"
    VID_START = "<|VIDEO_START|>"
    VID_END = "<|VIDEO_END|>"

    def __init__(
        self,
        tokenizer_name: str,
        image_preprocessor_name: str,
        spatial_conv_size: int = 2,
        temporal_conv_size: int = 2,
        image_min_pixels: int = 4 * 28 * 28,
        image_max_pixels: int = 6177 * 28 * 28,
        video_min_pixels: int = 299 * 28 * 28,
        video_max_pixels: int = 1196 * 28 * 28,
        video_target_frames: int = -1,
        video_frames_sample: str = "leading",
        video_max_frames: int = 180,
        video_min_frames: int = 16,
        video_fps: int = 2,
        **kwargs,
    ) -> None:
        # Tokenizer and image preprocessor
        self.model_name_or_path = tokenizer_name
        self._load_tokenizer()
        self.tokenizer.ignored_index = -100
        self.image_preprocessor = AdaptiveImageProcessor.from_pretrained(image_preprocessor_name)

        # Convolution sizes for patch aggregation
        self.spatial_conv_size = spatial_conv_size
        self.temporal_conv_size = temporal_conv_size

        # Pixel constraints
        self.image_min_pixels = image_min_pixels
        self.image_max_pixels = image_max_pixels
        self.video_min_pixels = video_min_pixels
        self.video_max_pixels = video_max_pixels

        # Video sampling parameters
        self.target_frames = video_target_frames
        self.frames_sample = video_frames_sample
        self.max_frames = video_max_frames
        self.min_frames = video_min_frames
        self.fps = video_fps

        # Special tokens and IDs
        self.cls_token = self.CLS_TOKEN
        self.sep_token = self.SEP_TOKEN
        self.eos_token = self.EOS_TOKEN
        self.image_start = self.IMG_START
        self.image_end = self.IMG_END
        self.video_start = self.VID_START
        self.video_end = self.VID_END
        self.image_patch_id = self.tokenizer.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>")
        self.image_start_id = self.tokenizer.convert_tokens_to_ids(self.image_start)
        self.video_start_id = self.tokenizer.convert_tokens_to_ids(self.video_start)
        self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.sep_token)
        self.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.eos_token)

        self.token_type_mapping = self._build_token_type_mapping()
        self.is_training = True
        self.role_prefixes = {
            "system": "",
            "user": "User: ",
            "bot": "Assistant: ",
            "assistant": "Assistant: ",
        }

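    # Example construction (hypothetical local paths; both arguments would
    # typically point at the same downloaded checkpoint directory):
    #     processor = DataProcessor(
    #         tokenizer_name="/models/ernie-4.5-vl",
    #         image_preprocessor_name="/models/ernie-4.5-vl",
    #     )
    #     processor.eval()  # inference: request2ids() without labels
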
    def _build_token_type_mapping(self) -> Dict[Any, int]:
        mapping = defaultdict(lambda: IDS_TYPE_FLAG["text"])
        for token in (
            self.IMG_START,
            self.IMG_END,
            self.VID_START,
            self.VID_END,
        ):
            mapping[token] = IDS_TYPE_FLAG["image"]
        mapping[self.image_patch_id] = IDS_TYPE_FLAG["image"]
        return mapping

    def train(self) -> None:
        """Enable training mode (produces labels)."""
        self.is_training = True

    def eval(self) -> None:
        """Enable evaluation mode (does not produce labels)."""
        self.is_training = False

    def text2ids(self, text, images=None, videos=None):
        """
        Convert a prompt string into model inputs, expanding each
        "<|image@placeholder|>" / "<|video@placeholder|>" marker with the
        corresponding entry of `images` / `videos`.
        Returns a dict with input_ids, token_type_ids, position_ids, images,
        grid_thw, image_type_ids, labels.
        """
        outputs = {
            "input_ids": [],
            "token_type_ids": [],
            "position_ids": [],
            "images": [],
            "grid_thw": [],
            "image_type_ids": [],
            "labels": [],
            "cur_position": 0,
            "pic_cnt": 0,
            "video_cnt": 0,
        }

        IMAGE_PLACEHOLDER = "<|image@placeholder|>"
        VIDEO_PLACEHOLDER = "<|video@placeholder|>"
        IMAGE_PLACEHOLDER_LEN = len(IMAGE_PLACEHOLDER)
        VIDEO_PLACEHOLDER_LEN = len(VIDEO_PLACEHOLDER)
        st, image_idx, video_idx = 0, 0, 0
        while st < len(text):
            image_pos = text.find(IMAGE_PLACEHOLDER, st)
            image_pos = len(text) if image_pos == -1 else image_pos
            video_pos = text.find(VIDEO_PLACEHOLDER, st)
            video_pos = len(text) if video_pos == -1 else video_pos
            ed = min(image_pos, video_pos)

            self._add_text(text[st:ed], outputs)
            if ed == len(text):
                break

            if ed == image_pos:
                self._add_image(images[image_idx], outputs)
                image_idx += 1
                st = ed + IMAGE_PLACEHOLDER_LEN
            else:
                item = videos[video_idx]
                if isinstance(item, dict):
                    frames = self._load_and_process_video(item["video"], item)
                else:
                    frames = self._load_and_process_video(item, {})

                self._add_video(frames, outputs)
                video_idx += 1
                st = ed + VIDEO_PLACEHOLDER_LEN

        return outputs

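    # Usage sketch (hypothetical prompt; `pil_img` is a PIL.Image loaded by
    # the caller): each placeholder consumes one entry of the matching list.
    #     out = processor.text2ids(
    #         "Describe <|image@placeholder|> briefly.", images=[pil_img]
    #     )
    #     # out["input_ids"] interleaves text tokens with one
    #     # <|IMAGE_PLACEHOLDER|> patch token per aggregated patch.
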
    def request2ids(
        self, request: Dict[str, Any], tgts: List[str] = None
    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
        """
        Convert chat messages into model inputs.
        Returns a dict with input_ids, token_type_ids, position_ids, images,
        grid_thw, image_type_ids, labels.
        """
        outputs = {
            "input_ids": [],
            "token_type_ids": [],
            "position_ids": [],
            "images": [],
            "grid_thw": [],
            "image_type_ids": [],
            "labels": [],
            "cur_position": 0,
            "pic_cnt": 0,
            "video_cnt": 0,
        }

        messages = parse_chat_messages(request.get("messages"))
        image_message_list = []
        for msg in messages:
            role = msg.get("role")
            assert role in self.role_prefixes, f"Unsupported role: {role}"
            content_items = msg.get("content")
            if not isinstance(content_items, list):
                content_items = [content_items]
            for item in content_items:
                if isinstance(item, dict) and item.get("type") in [
                    "image",
                    "video",
                ]:
                    image_message_list.append(item)

        prompt_token_ids = self.apply_chat_template(request)
        if len(prompt_token_ids) == 0:
            raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
        image_start_index = 0
        image_message_index = 0
        for i in range(len(prompt_token_ids)):
            if prompt_token_ids[i] in [
                self.image_start_id,
                self.video_start_id,
            ]:
                self._add_text(prompt_token_ids[image_start_index : i + 1], outputs)
                image_start_index = i + 1
                image_message = image_message_list[image_message_index]
                if image_message["type"] == "image":
                    img = image_message.get("image")
                    if img is None:
                        continue
                    outputs["pic_cnt"] += 1
                    self._add_image(img, outputs)
                elif image_message["type"] == "video":
                    video_bytes = image_message.get("video")
                    if video_bytes is None:
                        continue
                    frames = self._load_and_process_video(video_bytes, image_message)
                    outputs["video_cnt"] += 1
                    self._add_video(frames, outputs)
                image_message_index += 1
        self._add_text(prompt_token_ids[image_start_index:], outputs)

        if self.is_training:
            assert tgts, "Training mode requires non-empty target texts (tgts)."
            self._extract_labels(outputs, tgts)
        return outputs

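    # Example request (hypothetical; mirrors the OpenAI-style chat schema that
    # parse_chat_messages() consumes, with media already decoded upstream):
    #     request = {
    #         "request_id": "req-0",
    #         "messages": [{
    #             "role": "user",
    #             "content": [
    #                 {"type": "image", "image": pil_img},
    #                 {"type": "text", "text": "What is in this image?"},
    #             ],
    #         }],
    #     }
    #     out = processor.request2ids(request)
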
    def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
        token_id = token if isinstance(token, int) else self.tokenizer.convert_tokens_to_ids(token)
        outputs["input_ids"].append(token_id)
        outputs["token_type_ids"].append(self.token_type_mapping[token])
        pos = outputs["cur_position"]
        outputs["position_ids"].append([pos] * 3)
        outputs["cur_position"] += 1

    def _add_text(self, tokens, outputs: Dict) -> None:
        if isinstance(tokens, str):
            tokens = self.tokenizer.encode(tokens, add_special_tokens=False)["input_ids"]
        outputs["input_ids"].extend(tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * len(tokens))

        # Text tokens advance all three position axes (t, h, w) in lockstep
        start = outputs["cur_position"]
        for i in range(len(tokens)):
            outputs["position_ids"].append([start + i] * 3)
        outputs["cur_position"] += len(tokens)

    def _add_image(self, img, outputs: Dict) -> None:
        patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
            img.height,
            img.width,
            min_pixels=self.image_min_pixels,
            max_pixels=self.image_max_pixels,
        )[1]
        num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)

        outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)

        pos_ids = self._compute_3d_positions(1, patches_h, patches_w, outputs["cur_position"])
        outputs["position_ids"].extend(pos_ids)
        outputs["cur_position"] = np.max(pos_ids) + 1

        # Preprocess pixels
        ret = self.image_preprocessor.preprocess(
            images=[img.convert("RGB")],
            do_normalize=False,
            do_rescale=False,
            predetermined_grid_thw=np.array([[patches_h, patches_w]]),
            do_convert_rgb=True,
            input_data_format=ChannelDimension.LAST,
        )
        outputs["images"].append(ret["pixel_values"])
        outputs["grid_thw"].append(ret["image_grid_thw"])
        outputs["image_type_ids"].append(0)

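    # Worked example (hypothetical sizes): an image resized to a 32x32 patch
    # grid with spatial_conv_size=2 yields (32 * 32) // 2**2 = 256 placeholder
    # tokens, i.e. one token per 2x2 block of patches.
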
    def _add_video(self, frames, outputs: Dict) -> None:
        patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
            frames[0].height,
            frames[0].width,
            min_pixels=self.video_min_pixels,
            max_pixels=self.video_max_pixels,
        )[1]
        num_frames = len(frames)
        num_tokens = (num_frames * patches_h * patches_w) // (self.spatial_conv_size**2 * self.temporal_conv_size)

        pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
        ret = self.image_preprocessor.preprocess(
            images=None,
            videos=pixel_stack,
            do_normalize=False,
            do_rescale=False,
            predetermined_grid_thw=np.array([[patches_h, patches_w]] * num_frames),
            do_convert_rgb=True,
            input_data_format=ChannelDimension.LAST,
        )
        outputs["images"].append(ret["pixel_values_videos"])
        outputs["grid_thw"].append(ret["video_grid_thw"])
        outputs["image_type_ids"].extend([1] * num_frames)

        outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)

        pos_ids = self._compute_3d_positions(num_frames, patches_h, patches_w, outputs["cur_position"])
        outputs["position_ids"].extend(pos_ids)
        outputs["cur_position"] = np.max(pos_ids) + 1

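    # Worked example (hypothetical sizes): 8 frames at a 32x32 patch grid with
    # spatial_conv_size=2 and temporal_conv_size=2 yield
    # (8 * 32 * 32) // (2**2 * 2) = 1024 placeholder tokens -- one token per
    # 2x2x2 spatio-temporal block.
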
    def _extract_labels(self, outputs: Dict, tgts: List[str]) -> None:
        input_ids = copy.deepcopy(outputs["input_ids"])
        labels = [self.tokenizer.ignored_index] * len(input_ids)

        tgt_count = input_ids.count(self.sep_token_id)
        assert tgt_count == len(tgts), f"Number of targets ({len(tgts)}) must match number of separator tokens ({tgt_count})"

        tgt_index = 0
        for i, token_id in enumerate(input_ids):
            if token_id == self.sep_token_id:
                labels_token = self.tokenizer.tokenize(tgts[tgt_index])
                labels_token_id = self.tokenizer.convert_tokens_to_ids(labels_token)
                labels[i - len(labels_token_id) : i] = labels_token_id
                labels[i] = self.eos_token_id  # </s>
                tgt_index += 1

        outputs["labels"] = labels

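    # Label layout sketch (hypothetical single-turn sample): for input IDs
    # [..prompt.., t1, t2, SEP] and target "t1 t2", the labels become
    # [-100, ..., -100, t1, t2, EOS] -- prompt positions stay at
    # ignored_index (-100) and the separator position is replaced by </s>.
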
    def _load_and_process_video(self, url: str, item: Dict) -> List[Image.Image]:
        reader, meta, path = read_video_decord(url, save_to_disk=False)

        video_frame_args = dict()
        video_frame_args["fps"] = item.get("fps", self.fps)
        video_frame_args["min_frames"] = item.get("min_frames", self.min_frames)
        video_frame_args["max_frames"] = item.get("max_frames", self.max_frames)
        video_frame_args["target_frames"] = item.get("target_frames", self.target_frames)
        video_frame_args["frames_sample"] = item.get("frames_sample", self.frames_sample)

        video_frame_args = self._set_video_frame_args(video_frame_args, meta)

        frames_data, _, timestamps = read_frames_decord(
            path,
            reader,
            meta,
            target_frames=video_frame_args["target_frames"],
            target_fps=video_frame_args["fps"],
            frames_sample=video_frame_args["frames_sample"],
            save_to_disk=False,
        )

        frames: List[Image.Image] = []
        for img_array, ts in zip(frames_data, timestamps):
            frames.append(render_frame_timestamp(img_array, ts))
        # Ensure an even number of frames for the temporal convolution
        if len(frames) % 2 != 0:
            frames.append(copy.deepcopy(frames[-1]))
        return frames

    def _set_video_frame_args(self, video_frame_args, video_meta):
        """
        Resolve the final frame-sampling arguments from the given parameters
        according to their precedence.
        """
        # Precedence: video_target_frames > (video_min_frames, video_max_frames) > video_fps
        if video_frame_args["target_frames"] > 0:
            if video_frame_args["fps"] >= 0:
                raise ValueError("fps must be negative if target_frames is given")
            if (
                video_frame_args["min_frames"] > 0
                and video_frame_args["target_frames"] < video_frame_args["min_frames"]
            ):
                raise ValueError("target_frames must be larger than min_frames")
            if (
                video_frame_args["max_frames"] > 0
                and video_frame_args["target_frames"] > video_frame_args["max_frames"]
            ):
                raise ValueError("target_frames must be smaller than max_frames")
        else:
            if video_frame_args["fps"] < 0:
                raise ValueError("Must provide either positive target_fps or positive target_frames.")
            # First compute how many frames sampling at video_fps would produce
            frames_to_extract = int(video_meta["duration"] * video_frame_args["fps"])
            # If that count falls outside [min_frames, max_frames], clamp by
            # switching to target_frames at the violated bound
            if (
                video_frame_args["min_frames"] > 0
                and video_frame_args["max_frames"] > 0
                and video_frame_args["min_frames"] > video_frame_args["max_frames"]
            ):
                raise ValueError("min_frames must be smaller than max_frames")
            if video_frame_args["min_frames"] > 0 and frames_to_extract < video_frame_args["min_frames"]:
                video_frame_args["target_frames"] = video_frame_args["min_frames"]
                video_frame_args["fps"] = -1
            if video_frame_args["max_frames"] > 0 and frames_to_extract > video_frame_args["max_frames"]:
                video_frame_args["target_frames"] = video_frame_args["max_frames"]
                video_frame_args["fps"] = -1

        return video_frame_args

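    # Worked example (hypothetical metadata): a 90 s video with fps=2 would
    # sample int(90 * 2) = 180 frames; with max_frames=100 that exceeds the
    # bound, so the resolver sets target_frames=100 and fps=-1, i.e. exactly
    # 100 frames are drawn instead of fixed-rate sampling.
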
    def _compute_3d_positions(self, t: int, h: int, w: int, start_idx: int) -> List[List[int]]:
        # Downsample time if needed
        t_eff = t // self.temporal_conv_size if t != 1 else 1
        gh, gw = h // self.spatial_conv_size, w // self.spatial_conv_size
        time_idx = np.repeat(np.arange(t_eff), gh * gw)
        h_idx = np.tile(np.repeat(np.arange(gh), gw), t_eff)
        w_idx = np.tile(np.arange(gw), t_eff * gh)

        coords = list(zip(time_idx, h_idx, w_idx))
        return [[start_idx + ti, start_idx + hi, start_idx + wi] for ti, hi, wi in coords]

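    # Worked example (hypothetical grid): t=1, h=4, w=4 with
    # spatial_conv_size=2 gives a 1x2x2 grid of aggregated positions; with
    # start_idx=10 the returned IDs are
    #     [[10, 10, 10], [10, 10, 11], [10, 11, 10], [10, 11, 11]]
    # -- the three axes encode (time, height, width) offsets from start_idx.
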
    def _load_tokenizer(self):
        """
        Load the tokenizer, picking whichever known vocab file exists in the
        model directory, and store it on self.tokenizer (Ernie4_5Tokenizer).
        """
        vocab_file_names = [
            "tokenizer.model",
            "spm.model",
            "ernie_token_100k.model",
        ]
        for vocab_file_name in vocab_file_names:
            if os.path.exists(os.path.join(self.model_name_or_path, vocab_file_name)):
                Ernie4_5Tokenizer.resource_files_names["vocab_file"] = vocab_file_name
                break
        self.tokenizer = Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path)

    def apply_chat_template(self, request):
        """
        Convert multi-turn messages into a token ID sequence.

        Args:
            request: request dict containing a 'messages' field, plus optional
                'add_generation_prompt' and 'chat_template' overrides

        Returns:
            List of token IDs (ints); media placeholders are stripped from the
            rendered template before tokenization
        """
        if self.tokenizer.chat_template is None:
            raise ValueError("This model does not support chat_template.")

        prompt_token_template = self.tokenizer.apply_chat_template(
            request,
            tokenize=False,
            add_generation_prompt=request.get("add_generation_prompt", True),
            chat_template=request.get("chat_template", None),
        )
        prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace(
            "<|video@placeholder|>", ""
        )
        request["text_after_process"] = prompt_token_template
        tokens = self.tokenizer.tokenize(prompt_token_str)
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        data_processor_logger.info(
            f"req_id: {request.get('request_id', '')}, tokens: {tokens}, token_ids: {token_ids}"
        )
        return token_ids

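# Minimal usage sketch (hypothetical request; assumes `processor` was built as
# in the class-level example above): render the chat template and inspect the
# prompt token IDs that request2ids() consumes.
#     ids = processor.apply_chat_template(
#         {"messages": [{"role": "user", "content": "Hello"}]}
#     )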