Sync v2.0 version of code to github repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -18,11 +18,12 @@
""" process.py """
import copy
import io
import os
from collections import defaultdict
from typing import Any, Dict, List, Union
import numpy as np
from paddlenlp.transformers.image_utils import ChannelDimension
from paddleformers.transformers.image_utils import ChannelDimension
from PIL import Image
@@ -31,6 +32,8 @@ from .image_preprocessor.image_preprocessor_adaptive import AdaptiveImageProcess
from .process_video import read_frames_decord, read_video_decord
from .utils.io_utils import RAW_IMAGE_DIR, get_downloadable
from .utils.render_timestamp import render_frame_timestamp
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3}
@@ -94,9 +97,11 @@ class DataProcessor:
video_max_frames: int = 180,
video_min_frames: int = 16,
video_fps: int = 2,
**kwargs
) -> None:
# Tokenizer and image preprocessor
self.tokenizer = ErnieVLTokenizer.from_pretrained(tokenizer_name, verbose=False)
self.model_name_or_path = tokenizer_name
self._load_tokenizer()
self.tokenizer.ignored_index = -100
self.image_preprocessor = AdaptiveImageProcessor.from_pretrained(image_preprocessor_name)
@@ -125,6 +130,8 @@ class DataProcessor:
self.video_start = self.VID_START
self.video_end = self.VID_END
self.image_patch_id = self.tokenizer.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>")
self.image_start_id = self.tokenizer.convert_tokens_to_ids(self.image_start)
self.video_start_id = self.tokenizer.convert_tokens_to_ids(self.video_start)
self.token_type_mapping = self._build_token_type_mapping()
self.is_training = True
@@ -145,11 +152,12 @@ class DataProcessor:
"""Enable evaluation mode (doesn't produce labels)."""
self.is_training = False
def process(self, messages: List[Dict[str, Any]]) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
def text2ids(self, text, images=None, videos=None):
"""
Convert chat messages into model inputs.
Convert chat text into model inputs.
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
"""
outputs = {
"input_ids": [],
"token_type_ids": [],
@@ -162,37 +170,94 @@ class DataProcessor:
"pic_cnt": 0,
"video_cnt": 0,
}
self._add_special_token(self.cls_token, outputs)
IMAGE_PLACEHOLDER = "<|image@placeholder|>"
VIDEO_PLACEHOLDER = "<|video@placeholder|>"
IMAGE_PLACEHOLDER_LEN = len(IMAGE_PLACEHOLDER)
VIDEO_PLACEHOLDER_LEN = len(VIDEO_PLACEHOLDER)
st, image_idx, video_idx = 0, 0, 0
while st < len(text):
image_pos = text.find(IMAGE_PLACEHOLDER, st)
image_pos = len(text) if image_pos == -1 else image_pos
video_pos = text.find(VIDEO_PLACEHOLDER, st)
video_pos = len(text) if video_pos == -1 else video_pos
ed = min(image_pos, video_pos)
self._add_text(text[st:ed], outputs)
if ed == len(text):
break
if ed == image_pos:
self._add_image(images[image_idx], outputs)
image_idx += 1
st = ed + IMAGE_PLACEHOLDER_LEN
else:
item = videos[video_idx]
if isinstance(item, dict):
frames = self._load_and_process_video(item["video"], item)
else:
frames = self._load_and_process_video(item, {})
self._add_video(frames, outputs)
video_idx += 1
st = ed + VIDEO_PLACEHOLDER_LEN
return outputs
def request2ids(self, request: Dict[str, Any]) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
"""
Convert chat messages into model inputs.
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
"""
outputs = {
"input_ids": [],
"token_type_ids": [],
"position_ids": [],
"images": [],
"grid_thw": [],
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
}
messages = parse_chat_messages(request.get("messages"))
image_message_list = []
for msg in messages:
role = msg.get("role")
assert role in self.role_prefixes, f"Unsupported role: {role}"
prefix = self.role_prefixes[role]
if prefix:
self._add_text(prefix, outputs)
content_items = msg.get("content")
if not isinstance(content_items, list):
content_items = [content_items]
for item in content_items:
if isinstance(item, str) or item.get("type") == "text":
text = item if isinstance(item, str) else item.get("text", "")
self._add_text(text, outputs)
elif item.get("type") == "image_url" or item.get("type") == "image":
self._add_image(item, outputs)
elif item.get("type") == "video_url" or item.get("type") == "video":
self._add_video(item, outputs)
if role in ("user", "system"):
self._add_text("\n", outputs)
else:
self._add_special_token(self.sep_token, outputs)
if not self.is_training:
# Append assistant prefix in eval
self._add_text(self.role_prefixes["bot"], outputs)
if isinstance(item, dict) and item.get("type") in ["image", "video"]:
image_message_list.append(item)
prompt_token_ids = self.apply_chat_template(request)
image_start_index = 0
image_message_index = 0
for i in range(len(prompt_token_ids)):
if prompt_token_ids[i] in [self.image_start_id, self.video_start_id]:
self._add_text(prompt_token_ids[image_start_index:i + 1], outputs)
image_start_index = i + 1
image_message = image_message_list[image_message_index]
if image_message["type"] == "image":
img = image_message.get("image")
if img is None:
continue
outputs["pic_cnt"] += 1
self._add_image(img, outputs)
elif image_message["type"] == "video":
video_bytes = image_message.get("video")
if video_bytes is None:
continue
frames = self._load_and_process_video(video_bytes, image_message)
outputs["video_cnt"] += 1
self._add_video(frames, outputs)
image_message_index += 1
self._add_text(prompt_token_ids[image_start_index:], outputs)
return outputs
def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
@@ -203,8 +268,9 @@ class DataProcessor:
outputs["position_ids"].append([pos] * 3)
outputs["cur_position"] += 1
def _add_text(self, text: str, outputs: Dict) -> None:
tokens = self.tokenizer.encode(text, add_special_tokens=False)["input_ids"]
def _add_text(self, tokens, outputs: Dict) -> None:
if isinstance(tokens, str):
tokens = self.tokenizer.encode(tokens, add_special_tokens=False)["input_ids"]
outputs["input_ids"].extend(tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * len(tokens))
@@ -213,25 +279,7 @@ class DataProcessor:
outputs["position_ids"].append([start + i] * 3)
outputs["cur_position"] += len(tokens)
def _add_image(self, item: Dict, outputs: Dict) -> None:
url_info = item.get("image_url", {})
w = url_info.get("image_width", None)
h = url_info.get("image_height", None)
if "image" in item:
img = item["image"]
else:
url = url_info.get("url")
data = get_downloadable(url, download_dir=RAW_IMAGE_DIR, save_to_disk=False)
img = Image.open(io.BytesIO(data) if isinstance(data, bytes) else data)
if w and h:
img = img.resize((w, h))
outputs["pic_cnt"] += 1
self._add_text(f"Picture {outputs['pic_cnt']}:", outputs)
self._add_special_token(self.IMG_START, outputs)
def _add_image(self, img, outputs: Dict) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
img.height,
img.width,
@@ -260,21 +308,7 @@ class DataProcessor:
outputs["grid_thw"].append(ret["image_grid_thw"])
outputs["image_type_ids"].append(0)
self._add_special_token(self.IMG_END, outputs)
def _add_video(self, item: Dict, outputs: Dict) -> None:
url_info = item.get("video_url", {})
url = url_info.get("url")
outputs["video_cnt"] += 1
self._add_text(f"Video {outputs['video_cnt']}:", outputs)
self._add_special_token(self.VID_START, outputs)
if "video" in item:
video_path = item["video"]
frames = self._load_and_process_video(video_path, item)
else:
video_path = get_downloadable(url, save_to_disk=False)
frames = self._load_and_process_video(video_path, item)
def _add_video(self, frames, outputs: Dict) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
frames[0].height,
frames[0].width,
@@ -305,8 +339,6 @@ class DataProcessor:
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1
self._add_special_token(self.VID_END, outputs)
def _load_and_process_video(self, url: str, item: Dict) -> List[Image.Image]:
reader, meta, path = read_video_decord(url, save_to_disk=False)
@@ -386,3 +418,38 @@ class DataProcessor:
coords = list(zip(time_idx, h_idx, w_idx))
return [[start_idx + ti, start_idx + hi, start_idx + wi] for ti, hi, wi in coords]
def _load_tokenizer(self):
"""
load tokenizer
Returns:
tokenizer (AutoTokenizer)
"""
vocab_file_names = ["tokenizer.model", "spm.model", "ernie_token_100k.model"]
for i in range(len(vocab_file_names)):
if os.path.exists(os.path.join(self.model_name_or_path, vocab_file_names[i])):
ErnieBotTokenizer.resource_files_names["vocab_file"] = vocab_file_names[i]
break
self.tokenizer = ErnieBotTokenizer.from_pretrained(self.model_name_or_path)
def apply_chat_template(self, request):
"""
Convert multi-turn messages into ID sequences.
Args:
messages: Either a request dict containing 'messages' field,
or a list of message dicts directly
Returns:
List of token IDs as strings (converted from token objects)
"""
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
prompt_token_str = self.tokenizer.apply_chat_template(
request, tokenize=False, add_generation_prompt=request.get("add_generation_prompt", True)
).replace("<|image@placeholder|>", "").replace("<|video@placeholder|>", "")
tokens = self.tokenizer.tokenize(prompt_token_str)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
return token_ids