Sync v2.0 version of code to github repo
@@ -18,11 +18,12 @@
""" process.py """
import copy
import io
import os
from collections import defaultdict
from typing import Any, Dict, List, Union

import numpy as np
-from paddlenlp.transformers.image_utils import ChannelDimension
+from paddleformers.transformers.image_utils import ChannelDimension
from PIL import Image

@@ -31,6 +32,8 @@ from .image_preprocessor.image_preprocessor_adaptive import AdaptiveImageProcess
from .process_video import read_frames_decord, read_video_decord
from .utils.io_utils import RAW_IMAGE_DIR, get_downloadable
from .utils.render_timestamp import render_frame_timestamp
+from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
+from fastdeploy.entrypoints.chat_utils import parse_chat_messages

IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3}

@@ -94,9 +97,11 @@ class DataProcessor:
        video_max_frames: int = 180,
        video_min_frames: int = 16,
        video_fps: int = 2,
        **kwargs
    ) -> None:
        # Tokenizer and image preprocessor
-        self.tokenizer = ErnieVLTokenizer.from_pretrained(tokenizer_name, verbose=False)
+        self.model_name_or_path = tokenizer_name
+        self._load_tokenizer()
        self.tokenizer.ignored_index = -100
        self.image_preprocessor = AdaptiveImageProcessor.from_pretrained(image_preprocessor_name)

@@ -125,6 +130,8 @@ class DataProcessor:
        self.video_start = self.VID_START
        self.video_end = self.VID_END
        self.image_patch_id = self.tokenizer.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>")
+        self.image_start_id = self.tokenizer.convert_tokens_to_ids(self.image_start)
+        self.video_start_id = self.tokenizer.convert_tokens_to_ids(self.video_start)

        self.token_type_mapping = self._build_token_type_mapping()
        self.is_training = True
@@ -145,11 +152,12 @@ class DataProcessor:
        """Enable evaluation mode (doesn't produce labels)."""
        self.is_training = False

-    def process(self, messages: List[Dict[str, Any]]) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
+    def text2ids(self, text, images=None, videos=None):
        """
-        Convert chat messages into model inputs.
+        Convert chat text into model inputs.
        Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
        """

        outputs = {
            "input_ids": [],
            "token_type_ids": [],
@@ -162,37 +170,94 @@ class DataProcessor:
            "pic_cnt": 0,
            "video_cnt": 0,
        }
        self._add_special_token(self.cls_token, outputs)

        IMAGE_PLACEHOLDER = "<|image@placeholder|>"
        VIDEO_PLACEHOLDER = "<|video@placeholder|>"
        IMAGE_PLACEHOLDER_LEN = len(IMAGE_PLACEHOLDER)
        VIDEO_PLACEHOLDER_LEN = len(VIDEO_PLACEHOLDER)
        st, image_idx, video_idx = 0, 0, 0
        while st < len(text):
            image_pos = text.find(IMAGE_PLACEHOLDER, st)
            image_pos = len(text) if image_pos == -1 else image_pos
            video_pos = text.find(VIDEO_PLACEHOLDER, st)
            video_pos = len(text) if video_pos == -1 else video_pos
            ed = min(image_pos, video_pos)

            self._add_text(text[st:ed], outputs)
            if ed == len(text):
                break

            if ed == image_pos:
                self._add_image(images[image_idx], outputs)
                image_idx += 1
                st = ed + IMAGE_PLACEHOLDER_LEN
            else:
                item = videos[video_idx]
                if isinstance(item, dict):
                    frames = self._load_and_process_video(item["video"], item)
                else:
                    frames = self._load_and_process_video(item, {})

                self._add_video(frames, outputs)
                video_idx += 1
                st = ed + VIDEO_PLACEHOLDER_LEN

        return outputs

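Note: the placeholder scan in `text2ids` can be exercised on its own. Below is a minimal sketch (plain Python, no tokenizer or model required; `split_prompt` is an invented name for illustration) that walks a prompt the same way, yielding text segments and media slots in order:

```python
# Illustrative re-statement of the text2ids placeholder scan (not repo code).
IMAGE_PLACEHOLDER = "<|image@placeholder|>"
VIDEO_PLACEHOLDER = "<|video@placeholder|>"

def split_prompt(text):
    """Yield ("text", segment), ("image", idx) and ("video", idx) items in order."""
    st, image_idx, video_idx = 0, 0, 0
    while st < len(text):
        image_pos = text.find(IMAGE_PLACEHOLDER, st)
        image_pos = len(text) if image_pos == -1 else image_pos
        video_pos = text.find(VIDEO_PLACEHOLDER, st)
        video_pos = len(text) if video_pos == -1 else video_pos
        ed = min(image_pos, video_pos)  # next placeholder, or end of text
        if text[st:ed]:
            yield ("text", text[st:ed])
        if ed == len(text):
            break
        if ed == image_pos:
            yield ("image", image_idx)
            image_idx += 1
            st = ed + len(IMAGE_PLACEHOLDER)
        else:
            yield ("video", video_idx)
            video_idx += 1
            st = ed + len(VIDEO_PLACEHOLDER)

print(list(split_prompt("Look: <|image@placeholder|> then <|video@placeholder|>!")))
# [('text', 'Look: '), ('image', 0), ('text', ' then '), ('video', 0), ('text', '!')]
```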
    def request2ids(self, request: Dict[str, Any]) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
        """
        Convert chat messages into model inputs.
        Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
        """

        outputs = {
            "input_ids": [],
            "token_type_ids": [],
            "position_ids": [],
            "images": [],
            "grid_thw": [],
            "image_type_ids": [],
            "labels": [],
            "cur_position": 0,
            "pic_cnt": 0,
            "video_cnt": 0,
        }

        messages = parse_chat_messages(request.get("messages"))
        image_message_list = []
        for msg in messages:
            role = msg.get("role")
            assert role in self.role_prefixes, f"Unsupported role: {role}"
            prefix = self.role_prefixes[role]
            if prefix:
                self._add_text(prefix, outputs)

            content_items = msg.get("content")
            if not isinstance(content_items, list):
                content_items = [content_items]

            for item in content_items:
                if isinstance(item, str) or item.get("type") == "text":
                    text = item if isinstance(item, str) else item.get("text", "")
                    self._add_text(text, outputs)
                elif item.get("type") == "image_url" or item.get("type") == "image":
                    self._add_image(item, outputs)
                elif item.get("type") == "video_url" or item.get("type") == "video":
                    self._add_video(item, outputs)

            if role in ("user", "system"):
                self._add_text("\n", outputs)
            else:
                self._add_special_token(self.sep_token, outputs)

        if not self.is_training:
            # Append assistant prefix in eval
            self._add_text(self.role_prefixes["bot"], outputs)

        if isinstance(item, dict) and item.get("type") in ["image", "video"]:
            image_message_list.append(item)

        prompt_token_ids = self.apply_chat_template(request)
        image_start_index = 0
        image_message_index = 0
        for i in range(len(prompt_token_ids)):
            if prompt_token_ids[i] in [self.image_start_id, self.video_start_id]:
                self._add_text(prompt_token_ids[image_start_index:i + 1], outputs)
                image_start_index = i + 1
                image_message = image_message_list[image_message_index]
                if image_message["type"] == "image":
                    img = image_message.get("image")
                    if img is None:
                        continue
                    outputs["pic_cnt"] += 1
                    self._add_image(img, outputs)
                elif image_message["type"] == "video":
                    video_bytes = image_message.get("video")
                    if video_bytes is None:
                        continue
                    frames = self._load_and_process_video(video_bytes, image_message)
                    outputs["video_cnt"] += 1
                    self._add_video(frames, outputs)
                image_message_index += 1
        self._add_text(prompt_token_ids[image_start_index:], outputs)
        return outputs

    def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
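Note: a request that exercises `request2ids` would look roughly like the sketch below. Constructing `DataProcessor` needs real model files, so that part is left commented with placeholder paths; only the request shape (content items of type "image"/"video" carrying "image"/"video" payloads) follows from the diff above.

```python
from PIL import Image

# Hypothetical call shape for request2ids; paths are placeholders.
img = Image.new("RGB", (448, 448), color=(127, 127, 127))  # stand-in image
request = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image", "image": img},
            ],
        }
    ],
    "add_generation_prompt": True,
}
# processor = DataProcessor(tokenizer_name="<model_dir>", image_preprocessor_name="<model_dir>")
# outputs = processor.request2ids(request)
# outputs["input_ids"], outputs["images"], outputs["grid_thw"], ...
```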
@@ -203,8 +268,9 @@ class DataProcessor:
        outputs["position_ids"].append([pos] * 3)
        outputs["cur_position"] += 1

-    def _add_text(self, text: str, outputs: Dict) -> None:
-        tokens = self.tokenizer.encode(text, add_special_tokens=False)["input_ids"]
+    def _add_text(self, tokens, outputs: Dict) -> None:
+        if isinstance(tokens, str):
+            tokens = self.tokenizer.encode(tokens, add_special_tokens=False)["input_ids"]
        outputs["input_ids"].extend(tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * len(tokens))

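Note: `_add_text` now accepts either a raw string or a pre-tokenized ID list, which is what lets `request2ids` splice slices of `prompt_token_ids` straight into the stream. A stand-in with a stub tokenizer (one token per character, purely illustrative) shows the same dispatch:

```python
# Illustrative str-or-ids dispatch, mirroring the reworked _add_text.
class StubTokenizer:
    def encode(self, text, add_special_tokens=False):
        return {"input_ids": [ord(c) for c in text]}  # pretend: 1 char = 1 token

def add_text(tokens, outputs, tokenizer=StubTokenizer()):
    if isinstance(tokens, str):
        tokens = tokenizer.encode(tokens, add_special_tokens=False)["input_ids"]
    outputs["input_ids"].extend(tokens)
    outputs["token_type_ids"].extend([0] * len(tokens))  # 0 == IDS_TYPE_FLAG["text"]

out = {"input_ids": [], "token_type_ids": []}
add_text("hi", out)        # string path: tokenized first
add_text([101, 102], out)  # pre-tokenized path: spliced as-is
print(out["input_ids"])    # [104, 105, 101, 102]
```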
@@ -213,25 +279,7 @@ class DataProcessor:
            outputs["position_ids"].append([start + i] * 3)
        outputs["cur_position"] += len(tokens)

-    def _add_image(self, item: Dict, outputs: Dict) -> None:
-        url_info = item.get("image_url", {})
-        w = url_info.get("image_width", None)
-        h = url_info.get("image_height", None)
-
-        if "image" in item:
-            img = item["image"]
-        else:
-            url = url_info.get("url")
-            data = get_downloadable(url, download_dir=RAW_IMAGE_DIR, save_to_disk=False)
-            img = Image.open(io.BytesIO(data) if isinstance(data, bytes) else data)
-
-        if w and h:
-            img = img.resize((w, h))
-
-        outputs["pic_cnt"] += 1
-        self._add_text(f"Picture {outputs['pic_cnt']}:", outputs)
-        self._add_special_token(self.IMG_START, outputs)
-
+    def _add_image(self, img, outputs: Dict) -> None:
        patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
            img.height,
            img.width,
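Note: with `_add_image` now taking an already-decoded `PIL.Image`, the download/decode/optional-resize that used to live inside it becomes the caller's responsibility. A caller-side sketch loosely mirroring the removed block (`load_image` is an invented helper; in the repo the bytes would come via `get_downloadable`):

```python
import io
from PIL import Image

def load_image(data, width=None, height=None):
    """Decode bytes (or open a path/file object) into a PIL image, optionally resizing."""
    img = Image.open(io.BytesIO(data) if isinstance(data, bytes) else data)
    if width and height:
        img = img.resize((width, height))
    return img

# e.g.: img = load_image(open("cat.png", "rb").read(), width=448, height=448)
```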
@@ -260,21 +308,7 @@ class DataProcessor:
        outputs["grid_thw"].append(ret["image_grid_thw"])
        outputs["image_type_ids"].append(0)

        self._add_special_token(self.IMG_END, outputs)

-    def _add_video(self, item: Dict, outputs: Dict) -> None:
-        url_info = item.get("video_url", {})
-        url = url_info.get("url")
-        outputs["video_cnt"] += 1
-        self._add_text(f"Video {outputs['video_cnt']}:", outputs)
-        self._add_special_token(self.VID_START, outputs)
-
-        if "video" in item:
-            video_path = item["video"]
-            frames = self._load_and_process_video(video_path, item)
-        else:
-            video_path = get_downloadable(url, save_to_disk=False)
-            frames = self._load_and_process_video(video_path, item)
+    def _add_video(self, frames, outputs: Dict) -> None:
        patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
            frames[0].height,
            frames[0].width,
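Note: `_add_video` likewise now receives decoded frames; the dict-or-path dispatch that feeds it lives in `text2ids` (the `item = videos[video_idx]` branch earlier in this diff). The same dispatch restated compactly, with a dummy loader standing in for `_load_and_process_video`:

```python
def resolve_video(item, load):
    """Normalize a video entry (bare path/bytes, or {"video": ..., **options}) to frames."""
    if isinstance(item, dict):
        return load(item["video"], item)  # per-item options travel with the call
    return load(item, {})                 # bare path/bytes: no options

frames = resolve_video({"video": "clip.mp4", "fps": 2}, lambda v, opts: [v, opts])
print(frames)  # ['clip.mp4', {'video': 'clip.mp4', 'fps': 2}]
```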
@@ -305,8 +339,6 @@ class DataProcessor:
        outputs["position_ids"].extend(pos_ids)
        outputs["cur_position"] = np.max(pos_ids) + 1

        self._add_special_token(self.VID_END, outputs)

    def _load_and_process_video(self, url: str, item: Dict) -> List[Image.Image]:
        reader, meta, path = read_video_decord(url, save_to_disk=False)

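Note: position IDs in this processor are triplets, one value per (time, height, width) axis: a text token repeats a single running counter on all three axes (`[pos] * 3` in the hunks above), visual tokens enumerate their grid offsets, and the cursor then jumps past the largest coordinate used (`np.max(pos_ids) + 1`). A toy illustration with assumed sizes:

```python
import numpy as np

pos_ids, cur = [], 0

# Two text tokens: the counter repeats on all three axes.
for _ in range(2):
    pos_ids.append([cur] * 3)
    cur += 1

# A 1x2x2 visual patch grid: (t, h, w) offsets from the current base.
start, (t_len, h_len, w_len) = cur, (1, 2, 2)
for ti in range(t_len):
    for hi in range(h_len):
        for wi in range(w_len):
            pos_ids.append([start + ti, start + hi, start + wi])
cur = int(np.max(pos_ids)) + 1  # same advance rule as _add_video above

print(pos_ids)  # [[0,0,0], [1,1,1], [2,2,2], [2,2,3], [2,3,2], [2,3,3]]
print(cur)      # 4: next free position after the largest coordinate used
```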
@@ -386,3 +418,38 @@ class DataProcessor:

        coords = list(zip(time_idx, h_idx, w_idx))
        return [[start_idx + ti, start_idx + hi, start_idx + wi] for ti, hi, wi in coords]

    def _load_tokenizer(self):
        """
        Load the tokenizer.

        Returns:
            tokenizer (AutoTokenizer)
        """
        vocab_file_names = ["tokenizer.model", "spm.model", "ernie_token_100k.model"]
        for i in range(len(vocab_file_names)):
            if os.path.exists(os.path.join(self.model_name_or_path, vocab_file_names[i])):
                ErnieBotTokenizer.resource_files_names["vocab_file"] = vocab_file_names[i]
                break
        self.tokenizer = ErnieBotTokenizer.from_pretrained(self.model_name_or_path)

    def apply_chat_template(self, request):
        """
        Convert multi-turn messages into an ID sequence.

        Args:
            request: A request dict containing a 'messages' field,
                or a list of message dicts directly

        Returns:
            List of token IDs
        """
        if self.tokenizer.chat_template is None:
            raise ValueError("This model does not support chat_template.")

        prompt_token_str = self.tokenizer.apply_chat_template(
            request, tokenize=False, add_generation_prompt=request.get("add_generation_prompt", True)
        ).replace("<|image@placeholder|>", "").replace("<|video@placeholder|>", "")
        tokens = self.tokenizer.tokenize(prompt_token_str)
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        return token_ids
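Note: `_load_tokenizer` simply probes the model directory for the first known vocab filename, points `ErnieBotTokenizer.resource_files_names["vocab_file"]` at it, and then delegates to `from_pretrained`. The selection step on its own:

```python
import os

VOCAB_FILE_NAMES = ["tokenizer.model", "spm.model", "ernie_token_100k.model"]

def pick_vocab_file(model_dir):
    """Return the first known vocab filename present in model_dir, else None."""
    for name in VOCAB_FILE_NAMES:
        if os.path.exists(os.path.join(model_dir, name)):
            return name
    return None

print(pick_vocab_file("."))  # None unless one of the files exists in cwd
```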