Files
FastDeploy/fastdeploy/input/ernie4_5_vl_processor/process.py
luukunn 18f4977aec
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
[fix]update apply_chat_template (#4137)
* update apply_chat_template

* fix unittest

* fix unittest

* fix

* fix

* fix unit test

* fix

* fix unit test

* add unit test
2025-09-24 18:56:32 +08:00

513 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
""" process.py """
import copy
import os
from collections import defaultdict
from typing import Any, Dict, List, Union
import numpy as np
from paddleformers.transformers.image_utils import ChannelDimension
from PIL import Image
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.input.utils import IDS_TYPE_FLAG
from fastdeploy.utils import data_processor_logger
from .image_preprocessor.image_preprocessor_adaptive import AdaptiveImageProcessor
from .process_video import read_frames_decord, read_video_decord
from .utils.render_timestamp import render_frame_timestamp
def fancy_print(input_ids, tokenizer, image_patch_id=None):
    """
    Render a token-id sequence as readable text for debugging.

    Consecutive occurrences of ``image_patch_id`` are collapsed into a single
    ``<|IMAGE@n|>`` marker, where ``n`` is the length of the run; every other
    run of ids is decoded with ``tokenizer.decode``.

    Args:
        input_ids: sequence of token ids.
        tokenizer: object exposing ``decode(list_of_ids) -> str``.
        image_patch_id: id of the image placeholder token. When ``None``, no id
            is treated as an image patch.

    Returns:
        The rendered string.
    """
    parts = []
    text_ids = []
    real_image_token_len = 0
    for token_id in input_ids:
        if token_id == image_patch_id:
            if text_ids:
                parts.append(tokenizer.decode(text_ids))
                text_ids = []
            real_image_token_len += 1
        else:
            if real_image_token_len != 0:
                parts.append(f"<|IMAGE@{real_image_token_len}|>")
                real_image_token_len = 0
            text_ids.append(token_id)
    # Flush whichever run is still open at the end. At most one of the two is
    # non-empty here. A trailing image run used to be dropped silently.
    if real_image_token_len != 0:
        parts.append(f"<|IMAGE@{real_image_token_len}|>")
    if text_ids:
        parts.append(tokenizer.decode(text_ids))
    return "".join(parts)


class DataProcessor:
    """
    Processes multimodal chat messages into model-ready inputs,
    handling text, images, and videos with 3D positional embeddings.
    """

    CLS_TOKEN = "<|begin_of_sentence|>"
    SEP_TOKEN = "<|end_of_sentence|>"
    EOS_TOKEN = "</s>"
    IMG_START = "<|IMAGE_START|>"
    IMG_END = "<|IMAGE_END|>"
    VID_START = "<|VIDEO_START|>"
    VID_END = "<|VIDEO_END|>"

    def __init__(
        self,
        tokenizer_name: str,
        image_preprocessor_name: str,
        spatial_conv_size: int = 2,
        temporal_conv_size: int = 2,
        image_min_pixels: int = 4 * 28 * 28,
        image_max_pixels: int = 6177 * 28 * 28,
        video_min_pixels: int = 299 * 28 * 28,
        video_max_pixels: int = 1196 * 28 * 28,
        video_target_frames: int = -1,
        video_frames_sample: str = "leading",
        video_max_frames: int = 180,
        video_min_frames: int = 16,
        video_fps: int = 2,
        **kwargs,
    ) -> None:
        """
        Load the tokenizer and image preprocessor and store processing settings.

        Args:
            tokenizer_name: directory/name passed to ``Ernie4_5Tokenizer.from_pretrained``.
            image_preprocessor_name: name passed to ``AdaptiveImageProcessor.from_pretrained``.
            spatial_conv_size: spatial merge factor. The patch count is divided by its square.
            temporal_conv_size: temporal merge factor applied to video frames.
            image_min_pixels / image_max_pixels: pixel bounds forwarded to
                ``get_smarted_resize`` for images.
            video_min_pixels / video_max_pixels: the same bounds for video frames.
            video_target_frames, video_frames_sample, video_max_frames,
            video_min_frames, video_fps: default frame-sampling settings, used when
                a video item does not override them.
            **kwargs: accepted and ignored.
        """
        # Tokenizer and image preprocessor
        self.model_name_or_path = tokenizer_name
        self._load_tokenizer()
        self.tokenizer.ignored_index = -100
        self.image_preprocessor = AdaptiveImageProcessor.from_pretrained(image_preprocessor_name)

        # Convolution sizes for patch aggregation
        self.spatial_conv_size = spatial_conv_size
        self.temporal_conv_size = temporal_conv_size

        # Pixel constraints
        self.image_min_pixels = image_min_pixels
        self.image_max_pixels = image_max_pixels
        self.video_min_pixels = video_min_pixels
        self.video_max_pixels = video_max_pixels

        # Video sampling parameters
        self.target_frames = video_target_frames
        self.frames_sample = video_frames_sample
        self.max_frames = video_max_frames
        self.min_frames = video_min_frames
        self.fps = video_fps

        # Special tokens and IDs
        self.cls_token = self.CLS_TOKEN
        self.sep_token = self.SEP_TOKEN
        self.eos_token = self.EOS_TOKEN
        self.image_start = self.IMG_START
        self.image_end = self.IMG_END
        self.video_start = self.VID_START
        self.video_end = self.VID_END
        self.image_patch_id = self.tokenizer.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>")
        self.image_start_id = self.tokenizer.convert_tokens_to_ids(self.image_start)
        self.video_start_id = self.tokenizer.convert_tokens_to_ids(self.video_start)
        self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.sep_token)
        self.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.eos_token)

        # Requires image_patch_id, so it must come after the ids above.
        self.token_type_mapping = self._build_token_type_mapping()
        self.is_training = True
        # Only the keys are used in this file, to validate message roles.
        self.role_prefixes = {
            "system": "",
            "user": "User: ",
            "bot": "Assistant: ",
            "assistant": "Assistant: ",
        }

    def _build_token_type_mapping(self) -> Dict[Any, int]:
        """
        Build the token -> token-type lookup used by ``_add_special_token``.

        Unknown keys default to the text type. The image/video boundary tokens
        (looked up by their string form) and the image patch id map to the image type.
        """
        mapping = defaultdict(lambda: IDS_TYPE_FLAG["text"])
        for token in (
            self.IMG_START,
            self.IMG_END,
            self.VID_START,
            self.VID_END,
        ):
            mapping[token] = IDS_TYPE_FLAG["image"]
        mapping[self.image_patch_id] = IDS_TYPE_FLAG["image"]
        return mapping

    def train(self) -> None:
        """Enable training mode (produces labels)."""
        self.is_training = True

    def eval(self) -> None:
        """Enable evaluation mode (doesn't produce labels)."""
        self.is_training = False

    def text2ids(self, text, images=None, videos=None):
        """
        Convert a raw prompt string with inline placeholders into model inputs.

        Each ``<|image@placeholder|>`` / ``<|video@placeholder|>`` is replaced, in
        order of appearance, by the next entry of ``images`` / ``videos``. The text
        between placeholders is tokenized.

        Args:
            text: prompt string.
            images: indexable sequence of images, consumed in placeholder order.
            videos: indexable sequence of videos. Each item is either a dict with a
                ``"video"`` key plus optional sampling overrides, or the video
                source itself.

        Returns:
            Dict with input_ids, token_type_ids, position_ids, images, grid_thw,
            image_type_ids, labels (always empty here), cur_position, pic_cnt and
            video_cnt.

        Raises:
            ValueError: if ``text`` has more image/video placeholders than
                images/videos were supplied.
        """
        outputs = {
            "input_ids": [],
            "token_type_ids": [],
            "position_ids": [],
            "images": [],
            "grid_thw": [],
            "image_type_ids": [],
            "labels": [],
            "cur_position": 0,
            "pic_cnt": 0,
            "video_cnt": 0,
        }

        IMAGE_PLACEHOLDER = "<|image@placeholder|>"
        VIDEO_PLACEHOLDER = "<|video@placeholder|>"
        IMAGE_PLACEHOLDER_LEN = len(IMAGE_PLACEHOLDER)
        VIDEO_PLACEHOLDER_LEN = len(VIDEO_PLACEHOLDER)

        # Validate up front. A mismatch then fails with a clear message instead of
        # a TypeError/IndexError halfway through processing.
        num_images_given = 0 if images is None else len(images)
        num_videos_given = 0 if videos is None else len(videos)
        num_image_placeholders = text.count(IMAGE_PLACEHOLDER)
        num_video_placeholders = text.count(VIDEO_PLACEHOLDER)
        if num_image_placeholders > num_images_given:
            raise ValueError(
                f"text contains {num_image_placeholders} image placeholder(s) "
                f"but only {num_images_given} image(s) were provided"
            )
        if num_video_placeholders > num_videos_given:
            raise ValueError(
                f"text contains {num_video_placeholders} video placeholder(s) "
                f"but only {num_videos_given} video(s) were provided"
            )

        st, image_idx, video_idx = 0, 0, 0
        while st < len(text):
            image_pos = text.find(IMAGE_PLACEHOLDER, st)
            image_pos = len(text) if image_pos == -1 else image_pos
            video_pos = text.find(VIDEO_PLACEHOLDER, st)
            video_pos = len(text) if video_pos == -1 else video_pos
            ed = min(image_pos, video_pos)

            self._add_text(text[st:ed], outputs)
            if ed == len(text):
                break

            if ed == image_pos:
                outputs["pic_cnt"] += 1
                self._add_image(images[image_idx], outputs)
                image_idx += 1
                st = ed + IMAGE_PLACEHOLDER_LEN
            else:
                item = videos[video_idx]
                if isinstance(item, dict):
                    frames = self._load_and_process_video(item["video"], item)
                else:
                    frames = self._load_and_process_video(item, {})
                outputs["video_cnt"] += 1
                self._add_video(frames, outputs)
                video_idx += 1
                st = ed + VIDEO_PLACEHOLDER_LEN

        return outputs

    def request2ids(
        self, request: Dict[str, Any], tgts: List[str] = None
    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
        """
        Convert chat messages into model inputs.

        The request is rendered with the chat template. Each IMAGE_START /
        VIDEO_START token found in the rendered ids is paired, in order, with the
        next image/video content item from ``request["messages"]``.

        Args:
            request: dict with ``"messages"`` and optionally ``"chat_template_kwargs"``,
                ``"add_generation_prompt"`` and ``"request_id"``.
            tgts: target strings, required in training mode (one per separator token).

        Returns:
            Dict with input_ids, token_type_ids, position_ids, images, grid_thw,
            image_type_ids, labels, cur_position, pic_cnt and video_cnt.

        Raises:
            ValueError: if the rendered prompt is empty, or contains more
                image/video start tokens than there are image/video items.
        """
        outputs = {
            "input_ids": [],
            "token_type_ids": [],
            "position_ids": [],
            "images": [],
            "grid_thw": [],
            "image_type_ids": [],
            "labels": [],
            "cur_position": 0,
            "pic_cnt": 0,
            "video_cnt": 0,
        }

        messages = parse_chat_messages(request.get("messages"))
        image_message_list = []
        for msg in messages:
            role = msg.get("role")
            assert role in self.role_prefixes, f"Unsupported role: {role}"
            content_items = msg.get("content")
            if not isinstance(content_items, list):
                content_items = [content_items]
            for item in content_items:
                if isinstance(item, dict) and item.get("type") in [
                    "image",
                    "video",
                ]:
                    image_message_list.append(item)

        # `or {}` also covers an explicit None value.
        chat_template_kwargs = request.get("chat_template_kwargs") or {}
        prompt_token_ids = self.apply_chat_template(request, **chat_template_kwargs)
        if len(prompt_token_ids) == 0:
            raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")

        mm_start_ids = {self.image_start_id, self.video_start_id}
        image_start_index = 0
        image_message_index = 0
        for i, token_id in enumerate(prompt_token_ids):
            if token_id not in mm_start_ids:
                continue
            # Emit the text up to and including the start token itself.
            self._add_text(prompt_token_ids[image_start_index : i + 1], outputs)
            image_start_index = i + 1

            if image_message_index >= len(image_message_list):
                raise ValueError(
                    "Invalid input: the prompt contains more image/video start tokens "
                    "than image/video items in messages"
                )
            image_message = image_message_list[image_message_index]
            # Advance even when the payload is missing. Otherwise later start tokens
            # would be paired with this same (stale) message.
            image_message_index += 1

            if image_message["type"] == "image":
                img = image_message.get("image")
                if img is not None:
                    outputs["pic_cnt"] += 1
                    self._add_image(img, outputs)
            elif image_message["type"] == "video":
                video_bytes = image_message.get("video")
                if video_bytes is not None:
                    frames = self._load_and_process_video(video_bytes, image_message)
                    outputs["video_cnt"] += 1
                    self._add_video(frames, outputs)

        self._add_text(prompt_token_ids[image_start_index:], outputs)

        if self.is_training:
            assert tgts, "training must give tgt !"
            self._extract_labels(outputs, tgts)
        return outputs

    def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
        """Append one token (string or id) with a single shared 3D position."""
        token_id = token if isinstance(token, int) else self.tokenizer.convert_tokens_to_ids(token)
        outputs["input_ids"].append(token_id)
        outputs["token_type_ids"].append(self.token_type_mapping[token])
        pos = outputs["cur_position"]
        outputs["position_ids"].append([pos] * 3)
        outputs["cur_position"] += 1

    def _add_text(self, tokens, outputs: Dict) -> None:
        """
        Append text tokens to ``outputs``.

        ``tokens`` is either a string, which is encoded without special tokens, or
        a sequence of ids. Each token gets the same position on all three axes,
        with positions increasing by one per token.
        """
        if isinstance(tokens, str):
            tokens = self.tokenizer.encode(tokens, add_special_tokens=False)["input_ids"]
        start = outputs["cur_position"]
        outputs["input_ids"].extend(tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * len(tokens))
        outputs["position_ids"].extend([start + i] * 3 for i in range(len(tokens)))
        outputs["cur_position"] = start + len(tokens)

    def _add_image(self, img, outputs: Dict) -> None:
        """
        Append one image: placeholder ids, 3D positions and preprocessed pixels.

        The number of placeholder tokens is the resized patch grid divided by
        ``spatial_conv_size ** 2``. ``image_type_ids`` receives a single 0.
        """
        # The second element of get_smarted_resize(...) is unpacked as the patch grid.
        patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
            img.height,
            img.width,
            min_pixels=self.image_min_pixels,
            max_pixels=self.image_max_pixels,
        )[1]
        num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)

        outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)

        pos_ids = self._compute_3d_positions(1, patches_h, patches_w, outputs["cur_position"])
        outputs["position_ids"].extend(pos_ids)
        outputs["cur_position"] = np.max(pos_ids) + 1

        # Preprocess pixels
        ret = self.image_preprocessor.preprocess(
            images=[img.convert("RGB")],
            do_normalize=False,
            do_rescale=False,
            predetermined_grid_thw=np.array([[patches_h, patches_w]]),
            do_convert_rgb=True,
            input_data_format=ChannelDimension.LAST,
        )
        outputs["images"].append(ret["pixel_values"])
        outputs["grid_thw"].append(ret["image_grid_thw"])
        outputs["image_type_ids"].append(0)

    def _add_video(self, frames, outputs: Dict) -> None:
        """
        Append one video (list of frames): pixels, placeholder ids and 3D positions.

        The first frame's size determines the patch grid for all frames.
        ``image_type_ids`` receives a 1 per frame.
        """
        patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
            frames[0].height,
            frames[0].width,
            min_pixels=self.video_min_pixels,
            max_pixels=self.video_max_pixels,
        )[1]
        num_frames = len(frames)
        num_tokens = (num_frames * patches_h * patches_w) // (self.spatial_conv_size**2 * self.temporal_conv_size)

        pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
        ret = self.image_preprocessor.preprocess(
            images=None,
            videos=pixel_stack,
            do_normalize=False,
            do_rescale=False,
            predetermined_grid_thw=np.array([[patches_h, patches_w]] * num_frames),
            do_convert_rgb=True,
            input_data_format=ChannelDimension.LAST,
        )
        outputs["images"].append(ret["pixel_values_videos"])
        outputs["grid_thw"].append(ret["video_grid_thw"])
        outputs["image_type_ids"].extend([1] * num_frames)

        outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)

        pos_ids = self._compute_3d_positions(num_frames, patches_h, patches_w, outputs["cur_position"])
        outputs["position_ids"].extend(pos_ids)
        outputs["cur_position"] = np.max(pos_ids) + 1

    def _extract_labels(self, outputs: Dict, tgts: List[str]) -> None:
        """
        Build ``outputs["labels"]`` from target strings.

        For the k-th separator token in ``input_ids``, the tokens of ``tgts[k]``
        are written into the label positions immediately before it. The separator
        position itself gets ``eos_token_id``. All other positions get
        ``tokenizer.ignored_index``. ``labels`` always has the same length as
        ``input_ids``.
        """
        input_ids = outputs["input_ids"]  # read-only here, so no copy is needed
        labels = [self.tokenizer.ignored_index] * len(input_ids)
        tgt_count = input_ids.count(self.sep_token_id)
        assert tgt_count == len(tgts), f"len(tgts) != len(src) {len(tgts)} vs {tgt_count}"

        tgt_index = 0
        for i, token_id in enumerate(input_ids):
            if token_id != self.sep_token_id:
                continue
            labels_token = self.tokenizer.tokenize(tgts[tgt_index])
            labels_token_id = self.tokenizer.convert_tokens_to_ids(labels_token)
            # Clamp at 0. A negative slice start would make the assignment insert
            # elements and desync labels from input_ids. When clamped, keep the
            # tail of the target so it still ends right before the separator.
            start = max(0, i - len(labels_token_id))
            labels[start:i] = labels_token_id[len(labels_token_id) - (i - start) :]
            labels[i] = self.eos_token_id  # </s>
            tgt_index += 1

        outputs["labels"] = labels

    def _load_and_process_video(self, url: str, item: Dict) -> "List[Image.Image]":
        """
        Decode a video and return its sampled frames with rendered timestamps.

        Sampling settings are taken from ``item`` when present, otherwise from the
        processor defaults, and then resolved by ``_set_video_frame_args``.
        The result is padded to an even frame count by duplicating the last frame.
        """
        reader, meta, path = read_video_decord(url, save_to_disk=False)

        video_frame_args = {
            "fps": item.get("fps", self.fps),
            "min_frames": item.get("min_frames", self.min_frames),
            "max_frames": item.get("max_frames", self.max_frames),
            "target_frames": item.get("target_frames", self.target_frames),
            "frames_sample": item.get("frames_sample", self.frames_sample),
        }
        video_frame_args = self._set_video_frame_args(video_frame_args, meta)

        frames_data, _, timestamps = read_frames_decord(
            path,
            reader,
            meta,
            target_frames=video_frame_args["target_frames"],
            target_fps=video_frame_args["fps"],
            frames_sample=video_frame_args["frames_sample"],
            save_to_disk=False,
        )

        frames = [render_frame_timestamp(img_array, ts) for img_array, ts in zip(frames_data, timestamps)]
        # Ensure an even number of frames for temporal conv.
        # NOTE(review): padding to 2 assumes temporal_conv_size == 2.
        if len(frames) % 2 != 0:
            frames.append(copy.deepcopy(frames[-1]))
        return frames

    def _set_video_frame_args(self, video_frame_args, video_meta):
        """
        Determine the final frame-sampling parameters from the given arguments and
        their priority.

        Priority: target_frames > (min_frames, max_frames) > fps.
        ``video_frame_args`` is updated in place and also returned.

        Raises:
            ValueError: on contradictory settings (see the messages below).
        """
        if video_frame_args["target_frames"] > 0:
            if video_frame_args["fps"] >= 0:
                raise ValueError("fps must be negative if target_frames is given")
            if (
                video_frame_args["min_frames"] > 0
                and video_frame_args["target_frames"] < video_frame_args["min_frames"]
            ):
                raise ValueError("target_frames must be larger than min_frames")
            if (
                video_frame_args["max_frames"] > 0
                and video_frame_args["target_frames"] > video_frame_args["max_frames"]
            ):
                raise ValueError("target_frames must be smaller than max_frames")
        else:
            if video_frame_args["fps"] < 0:
                raise ValueError("Must provide either positive target_fps or positive target_frames.")
            # First compute how many frames sampling at `fps` would yield.
            frames_to_extract = int(video_meta["duration"] * video_frame_args["fps"])
            # If that count is outside [min_frames, max_frames], switch to a fixed
            # target_frames equal to the violated bound and disable fps.
            if (
                video_frame_args["min_frames"] > 0
                and video_frame_args["max_frames"] > 0
                and video_frame_args["min_frames"] > video_frame_args["max_frames"]
            ):
                raise ValueError("min_frames must be smaller than max_frames")
            if video_frame_args["min_frames"] > 0 and frames_to_extract < video_frame_args["min_frames"]:
                video_frame_args["target_frames"] = video_frame_args["min_frames"]
                video_frame_args["fps"] = -1
            if video_frame_args["max_frames"] > 0 and frames_to_extract > video_frame_args["max_frames"]:
                video_frame_args["target_frames"] = video_frame_args["max_frames"]
                video_frame_args["fps"] = -1

        return video_frame_args

    def _compute_3d_positions(self, t: int, h: int, w: int, start_idx: int) -> List[List[int]]:
        """
        Return one ``[time, row, col]`` position triple per merged patch, offset by ``start_idx``.

        Time is downsampled by ``temporal_conv_size`` unless ``t == 1``. Height and
        width are downsampled by ``spatial_conv_size``. The order is time-major,
        then row, then column.
        """
        t_eff = t // self.temporal_conv_size if t != 1 else 1
        gh, gw = h // self.spatial_conv_size, w // self.spatial_conv_size
        time_idx = np.repeat(np.arange(t_eff), gh * gw)
        h_idx = np.tile(np.repeat(np.arange(gh), gw), t_eff)
        w_idx = np.tile(np.arange(gw), t_eff * gh)
        return [[start_idx + ti, start_idx + hi, start_idx + wi] for ti, hi, wi in zip(time_idx, h_idx, w_idx)]

    def _load_tokenizer(self):
        """
        Load ``self.tokenizer`` (an Ernie4_5Tokenizer) from ``self.model_name_or_path``.

        The first known vocab file present in that directory is registered as
        ``vocab_file`` before loading. The assignment goes through the
        ``Ernie4_5Tokenizer`` class object, so it is visible beyond this instance.
        Returns nothing; the tokenizer is stored on ``self``.
        """
        vocab_file_names = [
            "tokenizer.model",
            "spm.model",
            "ernie_token_100k.model",
        ]
        for vocab_file_name in vocab_file_names:
            if os.path.exists(os.path.join(self.model_name_or_path, vocab_file_name)):
                Ernie4_5Tokenizer.resource_files_names["vocab_file"] = vocab_file_name
                break
        self.tokenizer = Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path)

    def apply_chat_template(self, request, **kwargs):
        """
        Render ``request`` with the tokenizer's chat template and tokenize the result.

        The rendered text, with placeholders still present, is stored in
        ``request["text_after_process"]``. The image/video placeholder strings are
        removed before tokenization.

        Args:
            request: dict passed to ``tokenizer.apply_chat_template``.
                ``add_generation_prompt`` is read from it (default True).
            **kwargs: extra chat-template arguments. An ``add_generation_prompt``
                entry here overrides the request field. ``tokenize`` is ignored,
                because text output is required here.

        Returns:
            List of token ids produced by ``tokenizer.convert_tokens_to_ids``.

        Raises:
            ValueError: if the tokenizer has no chat template.
        """
        if self.tokenizer.chat_template is None:
            raise ValueError("This model does not support chat_template.")

        # Both keywords are passed explicitly below. Removing them from kwargs
        # prevents a "multiple values for keyword argument" TypeError.
        kwargs.pop("tokenize", None)
        add_generation_prompt = kwargs.pop("add_generation_prompt", request.get("add_generation_prompt", True))

        prompt_token_template = self.tokenizer.apply_chat_template(
            request,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            **kwargs,
        )
        prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace(
            "<|video@placeholder|>", ""
        )
        request["text_after_process"] = prompt_token_template
        tokens = self.tokenizer.tokenize(prompt_token_str)
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        data_processor_logger.info(
            f"req_id:{request.get('request_id', '')} tokens: {tokens}, token_ids: {token_ids}"
        )
        return token_ids