Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[Feature] mm support prefix cache (#4134)
* support mm prefix caching
* update code
* fix mm_hashes
* support encoder cache
* add encoder cache
* update code
* update encoder cache
* fix features bug
* fix worker bug
* support processor cache, need to optimize yet
* refactor multimodal data cache
* update code
* update code
* update v1 scheduler
* update code
* update code
* update codestyle
* support turn off processor cache and encoder cache
* update pre-commit
* fix code
* solve review
* update code
* update code
* update test case
* set processor cache in GiB
* update test case
* support mm prefix caching for qwen model
* fix code style check
* update pre-commit
* fix unit test
* fix unit test
* add ci test case
* fix rescheduled bug
* change text_after_process to prompt_tokens
* fix unit test
* fix chat template
* change model path
* [EP] fix adapter bugs (#4572)
* Update expert_service.py
* Update common_engine.py
* Update expert_service.py
* fix v1 hang bug (#4573)
* fix import image_ops error on some platforms (#4559)
* [CLI]Update parameters in bench latecy cli tool and fix collect-env cli tool (#4558)
* add collect-env
* del files
* [Graph Optimization] Add dy_runnable and introduce cudagraph_switch_threshold for cudagraph mode switching (#4578)
* add new branch for sot
* reorder
* fix batch bug
* [XPU]Moe uses a new operator (#4585)
* [XPU]Moe uses a new operator
* [XPU]Moe uses a new operator
* update response
* [Feature] Support Paddle-OCR (#4396)
* init
* update code
* fix code style & disable thinking
* adapt for common_engine.update_mm_requests_chunk_size
* use 3d rope
* use flash_attn_unpadded
* opt siglip
* update to be compatible with the latest codebase
* fix typo
* optim OCR performance
* fix bug
* fix bug
* fix bug
* fix bug
* normlize name
* modify xpu rope
* revert logger
* fix bug
* fix bug
* fix bug
* support default_v1
* optim performance
* fix bug

---------

Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com>
Co-authored-by: zhangyue66 <zhangyue66@baidu.com>

* [DataProcessor] add reasoning_tokens into usage info (#4520)
* add reasoning_tokens into usage info initial commit
* add unit tests
* modify unit test
* modify and add unit tests
* fix unit test
* move steam usage to processor
* modify processor
* modify test_logprobs
* modify test_logprobs.py
* modify stream reasoning tokens accumulation
* fix unit test
* perf: Optimize task queue communication from engine to worker (#4531)
* perf: Optimize task queue communication from engine to worker
* perf: get_tasks to numpy
* perf: get_tasks remove to_numpy
* fix: request & replace ENV
* remove test_e2w_perf.py
* fix code style

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>

* Clean up ports after processing results (#4587)
* [CI] Add /re-run command in PR comments to restart failed CI workflows (#4593)
* [Others] api server exits when worker process is dead (#3271)
* [fix] fix terminal hangs when worker process is dead
* [chore] change sleep time of monitor
* [chore] remove redundant comments
* update docs

---------

Co-authored-by: ApplEOFDiscord <wwy640130@163.com>
Co-authored-by: ApplEOFDiscord <31272106+ApplEOFDiscord@users.noreply.github.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Co-authored-by: yinwei <yinwei_hust@163.com>
Co-authored-by: JYChen <zoooo0820@qq.com>
Co-authored-by: qwes5s5 <45442318+qwes5s5@users.noreply.github.com>
Co-authored-by: Ryan <zihaohuang@aliyun.com>
Co-authored-by: yyssys <atyangshuang@foxmail.com>
Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com>
Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com>
Co-authored-by: zhangyue66 <zhangyue66@baidu.com>
Co-authored-by: kxz2002 <115912648+kxz2002@users.noreply.github.com>
Co-authored-by: SunLei <sunlei5788@gmail.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Zhang Yulong <35552275+ZhangYulongg@users.noreply.github.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: 李泳桦 <39643373+liyonghua0910@users.noreply.github.com>
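Usage sketch (not part of this commit's diff): a minimal, hedged example of how the new enable_processor_cache flag and the per-item uuid fields fit together, based only on the constructor signatures and request fields visible in the diff below. The import path, model path, method name, and message contents are assumptions/placeholders, and a cache service must already be listening on ipc:///dev/shm/processor_cache.ipc for uuid-only items to resolve.

# Hypothetical usage sketch; import path and values are assumed, not taken from the commit.
from fastdeploy.input.ernie4_5_vl_processor.process import DataProcessor

processor = DataProcessor(
    tokenizer_name="/path/to/model",           # placeholder path
    image_preprocessor_name="/path/to/model",  # placeholder path
    enable_processor_cache=True,               # flag added by this commit
)
processor.eval()

request = {
    "request_id": "demo-0",
    "add_generation_prompt": True,
    "messages": [
        {
            "role": "user",
            "content": [
                # A multimodal item may carry only a uuid; with the processor
                # cache enabled, its preprocessed features are fetched over IPC
                # instead of being recomputed from raw data.
                {"type": "image", "uuid": "img-123", "data": None},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
}
outputs = processor.request2ids(request)  # input_ids, images, grid_thw, mm_hashes, ...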
@@ -37,6 +37,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
mm_processor_kwargs=None,
reasoning_parser_obj=None,
tool_parser_obj=None,
enable_processor_cache=False,
):
data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
tokenizer_path = model_name_or_path
@@ -46,6 +47,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
self.ernie4_5_processor = DataProcessor(
tokenizer_name=tokenizer_path,
image_preprocessor_name=preprocessor_path,
enable_processor_cache=enable_processor_cache,
**processor_kwargs,
)
self.ernie4_5_processor.eval()
@@ -18,16 +18,20 @@
""" process.py """
import copy
import os
import pickle
from collections import defaultdict
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import zmq
from paddleformers.transformers.image_utils import ChannelDimension
from PIL import Image

from fastdeploy.engine.request import ImagePosition
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.input.utils import IDS_TYPE_FLAG
from fastdeploy.multimodal.hasher import MultimodalHasher
from fastdeploy.utils import data_processor_logger

from .image_preprocessor.image_preprocessor_adaptive import AdaptiveImageProcessor
@@ -84,6 +88,7 @@ class DataProcessor:
self,
tokenizer_name: str,
image_preprocessor_name: str,
enable_processor_cache: bool = False,
spatial_conv_size: int = 2,
temporal_conv_size: int = 2,
image_min_pixels: int = 4 * 28 * 28,
@@ -102,6 +107,7 @@ class DataProcessor:
self._load_tokenizer()
self.tokenizer.ignored_index = -100
self.image_preprocessor = AdaptiveImageProcessor.from_pretrained(image_preprocessor_name)
self.enable_processor_cache = enable_processor_cache

# Convolution sizes for patch aggregation
self.spatial_conv_size = spatial_conv_size
@@ -163,10 +169,18 @@ class DataProcessor:
"""Enable evaluation mode (doesn't produce labels)."""
self.is_training = False

def text2ids(self, text, images=None, videos=None):
def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None):
"""
Convert chat text into model inputs.
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.

Args:
text (str): The chat text containing placeholders for images and videos.
images (list, optional): List of images to be processed and inserted at image placeholders.
videos (list, optional): List of videos to be processed and inserted at video placeholders.
image_uuid (list, optional): List of unique identifiers for each image, used for caching or hashing.
video_uuid (list, optional): List of unique identifiers for each video, used for caching or hashing.
Returns:
dict: A dictionary with keys input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels, etc.
"""

outputs = {
@@ -178,8 +192,9 @@ class DataProcessor:
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
"mm_positions": [],
"mm_hashes": [],
}

IMAGE_PLACEHOLDER = "<|image@placeholder|>"
@@ -199,17 +214,27 @@ class DataProcessor:
break

if ed == image_pos:
self._add_image(images[image_idx], outputs)
image = images[image_idx]
uuid = image_uuid[image_idx] if image_uuid else None
if not isinstance(image, tuple):
self._add_image(image, outputs, uuid)
else:
# cached images are already processed
self._add_processed_image(image, outputs, uuid)
image_idx += 1
st = ed + IMAGE_PLACEHOLDER_LEN
else:
item = videos[video_idx]
if isinstance(item, dict):
frames = self._load_and_process_video(item["video"], item)
uuid = video_uuid[video_idx] if video_uuid else None
if not isinstance(item, tuple):
if isinstance(item, dict):
frames = self._load_and_process_video(item["video"], item)
else:
frames = self._load_and_process_video(item, {})
self._add_video(frames, outputs, uuid)
else:
frames = self._load_and_process_video(item, {})

self._add_video(frames, outputs)
# cached frames are already processed
self._add_processed_video(item, outputs, uuid)
video_idx += 1
st = ed + VIDEO_PLACEHOLDER_LEN
@@ -223,66 +248,82 @@ class DataProcessor:
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
"""

outputs = {
"input_ids": [],
"token_type_ids": [],
"position_ids": [],
"images": [],
"grid_thw": [],
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
}

messages = parse_chat_messages(request.get("messages"))
image_message_list = []
mm_items = []
for msg in messages:
role = msg.get("role")
assert role in self.role_prefixes, f"Unsupported role: {role}"
content_items = msg.get("content")
if not isinstance(content_items, list):
content_items = [content_items]
for item in content_items:
if isinstance(item, dict) and item.get("type") in [
"image",
"video",
]:
image_message_list.append(item)
content = msg.get("content")
if not isinstance(content, list):
content = [content]
for item in content:
if item.get("type") in ["image", "video"]:
mm_items.append(item)

missing_hashes, missing_idx = [], []
for idx, item in enumerate(mm_items):
if not item.get("data"):
# raw data not provided, should be retrieved from processor cache
missing_hashes.append(item.get("uuid"))
missing_idx.append(idx)

if len(missing_hashes) > 0 and not self.enable_processor_cache:
raise ValueError("Missing items cannot be retrieved without processor cache.")

if self.enable_processor_cache:
context = zmq.Context()
dealer = context.socket(zmq.DEALER)
dealer.connect("ipc:///dev/shm/processor_cache.ipc")

missing_items = self.get_processor_cache(dealer, missing_hashes)
for idx in range(len(missing_items)):
if not missing_items[idx]:
raise ValueError(f"Missing item {idx} not found in processor cache")
mm_items[missing_idx[idx]]["data"] = missing_items[idx]
images, videos = [], []
image_uuid, video_uuid = [], []
for item in mm_items:
if item.get("type") == "image":
images.append(item["data"])
image_uuid.append(item["uuid"])
elif item.get("type") == "video":
videos.append(item["data"])
video_uuid.append(item["uuid"])
else:
raise ValueError(f"Unsupported multimodal type: {item.get('type')}")

if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat template.")

chat_template_kwargs = request.get("chat_template_kwargs", {})
prompt_token_ids = self.apply_chat_template(request, **chat_template_kwargs)
if len(prompt_token_ids) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
image_start_index = 0
image_message_index = 0
for i in range(len(prompt_token_ids)):
if prompt_token_ids[i] in [
self.image_start_id,
self.video_start_id,
]:
self._add_text(prompt_token_ids[image_start_index : i + 1], outputs)
image_start_index = i + 1
image_message = image_message_list[image_message_index]
if image_message["type"] == "image":
img = image_message.get("image")
if img is None:
continue
outputs["pic_cnt"] += 1
self._add_image(img, outputs)
elif image_message["type"] == "video":
video_bytes = image_message.get("video")
if video_bytes is None:
continue
frames = self._load_and_process_video(video_bytes, image_message)
outputs["video_cnt"] += 1
self._add_video(frames, outputs)
image_message_index += 1
self._add_text(prompt_token_ids[image_start_index:], outputs)
prompt = self.tokenizer.apply_chat_template(
request,
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
**chat_template_kwargs,
)
request["prompt_tokens"] = prompt

outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid)

if self.enable_processor_cache:
missing_idx = set(missing_idx)
hashes_to_cache, items_to_cache = [], []
for idx in range(len(mm_items)):
if idx in missing_idx:
continue
meta = {}
t, h, w = outputs["grid_thw"][idx][0]
meta["thw"] = (t, h, w)
hashes_to_cache.append(outputs["mm_hashes"][idx])
items_to_cache.append((outputs["images"][idx], meta))
self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)

if self.is_training:
assert tgts, "training must give tgt !"
assert tgts, "Training must give tgt"
self._extract_labels(outputs, tgts)

return outputs

def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
@@ -304,7 +345,7 @@ class DataProcessor:
outputs["position_ids"].append([start + i] * 3)
outputs["cur_position"] += len(tokens)

def _add_image(self, img, outputs: Dict) -> None:
def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
img.height,
img.width,
@@ -313,6 +354,7 @@ class DataProcessor:
)[1]
num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)

outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)

@@ -330,10 +372,32 @@ class DataProcessor:
input_data_format=ChannelDimension.LAST,
)
outputs["images"].append(ret["pixel_values"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(ret["image_grid_thw"])
outputs["image_type_ids"].append(0)

def _add_video(self, frames, outputs: Dict) -> None:
def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
img, meta = img_cache
num_tokens = img.shape[0] // (self.spatial_conv_size**2)

outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)

_, h, w = meta["thw"]
pos_ids = self._compute_3d_positions(1, h, w, outputs["cur_position"])
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1

outputs["images"].append(img)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[1, h, w]]))
outputs["image_type_ids"].append(0)

def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
frames[0].height,
frames[0].width,
@@ -354,9 +418,14 @@ class DataProcessor:
input_data_format=ChannelDimension.LAST,
)
outputs["images"].append(ret["pixel_values_videos"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values_videos"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(ret["video_grid_thw"])
outputs["image_type_ids"].extend([1] * num_frames)

outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)

@@ -364,6 +433,24 @@ class DataProcessor:
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1

def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
frames, meta = frames_cache
num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size)

t, h, w = meta["thw"]
outputs["images"].append(frames)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[t, h, w]]))

outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["image_type_ids"].extend([1] * t)

pos_ids = self._compute_3d_positions(t, h, w, outputs["cur_position"])
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1

def _extract_labels(self, outputs: Dict, tgts: List[str]) -> None:
input_ids = copy.deepcopy(outputs["input_ids"])
labels = [self.tokenizer.ignored_index] * len(input_ids)
@@ -480,33 +567,22 @@ class DataProcessor:
break
self.tokenizer = Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path)

def apply_chat_template(self, request, **kwargs):
def get_processor_cache(self, socket, mm_hashes: list[str]) -> list:
"""
Convert multi-turn messages into ID sequences.

Args:
messages: Either a request dict containing 'messages' field,
or a list of message dicts directly

Returns:
List of token IDs as strings (converted from token objects)
get cache correspond to given hash values
"""
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
req = pickle.dumps(mm_hashes)
socket.send_multipart([b"", req])
_, resp = socket.recv_multipart()
mm_items = pickle.loads(resp)
data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}")

prompt_token_template = self.tokenizer.apply_chat_template(
request,
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
**kwargs,
)
prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace(
"<|video@placeholder|>", ""
)
request["prompt_tokens"] = prompt_token_template
tokens = self.tokenizer.tokenize(prompt_token_str)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
data_processor_logger.info(
f"req_id:{request.get('request_id', ''), } tokens: {tokens}, token_ids: {token_ids}"
)
return token_ids
return mm_items

def update_processor_cache(self, socket, mm_hashes: list[str], mm_items):
"""
update cache data
"""
req = pickle.dumps((mm_hashes, mm_items))
socket.send_multipart([b"", req])
data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}")
@@ -46,6 +46,7 @@ class InputPreprocessor:
limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
tool_parser: str = None,
enable_processor_cache: bool = False,
) -> None:
self.model_config = model_config
self.model_name_or_path = self.model_config.model
@@ -53,6 +54,7 @@ class InputPreprocessor:
self.limit_mm_per_prompt = limit_mm_per_prompt
self.mm_processor_kwargs = mm_processor_kwargs
self.tool_parser = tool_parser
self.enable_processor_cache = enable_processor_cache

def create_processor(self):
reasoning_parser_obj = None
@@ -104,6 +106,19 @@ class InputPreprocessor:
mm_processor_kwargs=self.mm_processor_kwargs,
reasoning_parser_obj=reasoning_parser_obj,
tool_parser_obj=tool_parser_obj,
enable_processor_cache=self.enable_processor_cache,
)
elif "PaddleOCRVL" in architecture:
from fastdeploy.input.paddleocr_vl_processor import (
PaddleOCRVLProcessor,
)

self.processor = PaddleOCRVLProcessor(
config=self.model_config,
model_name_or_path=self.model_name_or_path,
limit_mm_per_prompt=self.limit_mm_per_prompt,
mm_processor_kwargs=self.mm_processor_kwargs,
reasoning_parser_obj=reasoning_parser_obj,
)
elif "PaddleOCRVL" in architecture:
from fastdeploy.input.paddleocr_vl_processor import (
@@ -126,5 +141,6 @@ class InputPreprocessor:
limit_mm_per_prompt=self.limit_mm_per_prompt,
mm_processor_kwargs=self.mm_processor_kwargs,
reasoning_parser_obj=reasoning_parser_obj,
enable_processor_cache=self.enable_processor_cache,
)
return self.processor
@@ -15,17 +15,23 @@
# limitations under the License.
"""

from typing import Any, Dict, List, Tuple, Union
import pickle
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import zmq
from paddleformers.transformers import AutoTokenizer
from PIL import Image

from fastdeploy.engine.request import ImagePosition
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.ernie4_5_vl_processor import read_video_decord
from fastdeploy.input.utils import IDS_TYPE_FLAG
from fastdeploy.multimodal.hasher import MultimodalHasher
from fastdeploy.utils import data_processor_logger

from .image_processor import ImageProcessor
from .process_video import read_frames, sample_frames
from .process_video import sample_frames


class DataProcessor:
@@ -49,8 +55,11 @@ class DataProcessor:
def __init__(
self,
model_path: str,
enable_processor_cache: bool = False,
video_min_frames: int = 4,
video_max_frames: int = 768,
video_target_frames: int = -1,
video_fps: int = -1,
tokens_per_second: int = 2,
tokenizer=None,
**kwargs,
@@ -67,6 +76,8 @@ class DataProcessor:
"""
self.min_frames = video_min_frames
self.max_frames = video_max_frames
self.target_frames = video_target_frames
self.fps = video_fps

# Initialize tokenizer with left padding and fast tokenizer
if tokenizer is None:
@@ -75,6 +86,7 @@ class DataProcessor:
else:
self.tokenizer = tokenizer
self.image_processor = ImageProcessor.from_pretrained(model_path) # Initialize image processor
self.enable_processor_cache = enable_processor_cache

# Convolution sizes for patch aggregation
self.spatial_conv_size = self.image_processor.merge_size
@@ -99,41 +111,7 @@ class DataProcessor:
"assistant": "Assistant: ",
}

def _pack_outputs(self, outputs):
"""
Pack and convert all output data into numpy arrays with appropriate types.

Args:
outputs (dict): Dictionary containing model outputs with keys:
- images: List of visual features
- grid_thw: List of spatial dimensions
- image_type_ids: List of content type indicators
- input_ids: List of token IDs
- token_type_ids: List of type identifiers
- position_ids: List of position embeddings

Returns:
dict: Processed outputs with all values converted to numpy arrays
"""
# Process visual outputs - stack if exists or set to None if empty
if not outputs["images"]:
outputs["images"] = None # No images case
outputs["grid_thw"] = None # No spatial dimensions
outputs["image_type_ids"] = None # No type IDs
else:
outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically
outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions
outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array

# Convert all outputs to numpy arrays with appropriate types
outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64
outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64
outputs["position_ids"] = np.concatenate(
outputs["position_ids"], axis=1, dtype=np.int64
) # Concatenate position IDs
return outputs

def text2ids(self, text, images=None, videos=None):
def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None):
"""
Convert text with image/video placeholders into model inputs.

@@ -141,6 +119,8 @@ class DataProcessor:
text: Input text with <|image@placeholder|> and <|video@placeholder|> markers
images: List of PIL Images corresponding to image placeholders
videos: List of video data corresponding to video placeholders
image_uuid: List of unique identifiers for each image, used for caching or hashing.
video_uuid: List of unique identifiers for each video, used for caching or hashing.

Returns:
Dict containing:
@@ -161,8 +141,10 @@ class DataProcessor:
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
"fps": [],
"mm_positions": [],
"mm_hashes": [],
}

# Define placeholders and their lengths
@@ -186,23 +168,30 @@ class DataProcessor:
break

if ed == image_pos:
outputs["pic_cnt"] += 1
self._add_image(images[image_idx], outputs)
image = images[image_idx]
uuid = image_uuid[image_idx] if image_uuid else None
if not isinstance(image, tuple):
self._add_image(image, outputs, uuid)
else:
self._add_processed_image(image, outputs, uuid)
image_idx += 1
st = ed + IMAGE_PLACEHOLDER_LEN
else:
item = videos[video_idx]
if isinstance(item, dict):
frames, meta = self._load_and_process_video(item["video"], item)
uuid = video_uuid[video_idx] if video_uuid else None
if not isinstance(item, tuple):
if isinstance(item, dict):
frames, meta = self._load_and_process_video(item["video"], item)
else:
frames, meta = self._load_and_process_video(item, {})
self._add_video(frames, meta, outputs, uuid)
else:
frames, meta = self._load_and_process_video(item, {})

outputs["video_cnt"] += 1
self._add_video(frames, meta, outputs)
# cached frames are already processed
self._add_processed_video(item, outputs, uuid)
video_idx += 1
st = ed + VIDEO_PLACEHOLDER_LEN

return self._pack_outputs(outputs)
return outputs

def request2ids(
self, request: Dict[str, Any], tgts: List[str] = None
@@ -220,74 +209,84 @@ class DataProcessor:
Dict with same structure as text2ids() output
"""

outputs = {
"input_ids": [],
"token_type_ids": [],
"position_ids": [],
"images": [],
"grid_thw": [],
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
}

# Parse and validate chat messages
messages = parse_chat_messages(request.get("messages"))
image_message_list = [] # Store visual content messages

mm_items = []
for msg in messages:
role = msg.get("role")
assert role in self.role_prefixes, f"Unsupported role: {role}"

# Normalize content to list format
content_items = msg.get("content")
if not isinstance(content_items, list):
content_items = [content_items]

content = msg.get("content")
if not isinstance(content, list):
content = [content]
# Collect all visual content items
for item in content_items:
if isinstance(item, dict) and item.get("type") in ["image", "video"]:
image_message_list.append(item)
for item in content:
if item.get("type") in ["image", "video"]:
mm_items.append(item)

raw_messages = request["messages"]
request["messages"] = messages
missing_hashes, missing_idx = [], []
for idx, item in enumerate(mm_items):
if not item.get("data"):
# raw data not provided, should be retrieved from processor cache
missing_hashes.append(item.get("uuid"))
missing_idx.append(idx)

prompt_token_ids = self.apply_chat_template(request)
if len(prompt_token_ids) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
request["messages"] = raw_messages
if len(missing_hashes) > 0 and not self.enable_processor_cache:
raise ValueError("Missing items cannot be retrieved without processor cache.")

vision_start_index = 0
vision_message_index = 0
for i in range(len(prompt_token_ids)):
if prompt_token_ids[i] == self.vision_start_id:
self._add_text(prompt_token_ids[vision_start_index : i + 1], outputs)
if self.enable_processor_cache:
context = zmq.Context()
dealer = context.socket(zmq.DEALER)
dealer.connect("ipc:///dev/shm/processor_cache.ipc")

vision_start_index = i + 1
image_message = image_message_list[vision_message_index]
missing_items = self.get_processor_cache(dealer, missing_hashes)
for idx in range(len(missing_items)):
if not missing_items[idx]:
raise ValueError(f"Missing item {idx} not found in processor cache")
mm_items[missing_idx[idx]]["data"] = missing_items[idx]
if image_message["type"] == "image":
|
||||
img = image_message.get("image")
|
||||
if img is None:
|
||||
continue
|
||||
outputs["pic_cnt"] += 1
|
||||
self._add_image(img, outputs)
|
||||
images, videos = [], []
|
||||
image_uuid, video_uuid = [], []
|
||||
for item in mm_items:
|
||||
if item.get("type") == "image":
|
||||
images.append(item["data"])
|
||||
image_uuid.append(item["uuid"])
|
||||
elif item.get("type") == "video":
|
||||
videos.append(item["data"])
|
||||
video_uuid.append(item["uuid"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
|
||||
|
||||
elif image_message["type"] == "video":
|
||||
video_bytes = image_message.get("video")
|
||||
if video_bytes is None:
|
||||
continue
|
||||
frames, meta = self._load_and_process_video(video_bytes, image_message)
|
||||
if self.tokenizer.chat_template is None:
|
||||
raise ValueError("This model does not support chat template.")
|
||||
|
||||
outputs["video_cnt"] += 1
|
||||
self._add_video(frames, meta, outputs)
|
||||
chat_template_kwargs = request.get("chat_template_kwargs", {})
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=request.get("add_generation_prompt", True),
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
request["prompt_tokens"] = prompt
|
||||
|
||||
vision_message_index += 1
|
||||
outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid)
|
||||
|
||||
self._add_text(prompt_token_ids[vision_start_index:], outputs)
|
||||
return self._pack_outputs(outputs)
|
||||
if self.enable_processor_cache:
|
||||
missing_idx = set(missing_idx)
|
||||
hashes_to_cache, items_to_cache = [], []
|
||||
for idx in range(len(mm_items)):
|
||||
if idx in missing_idx:
|
||||
continue
|
||||
meta = {}
|
||||
t, h, w = outputs["grid_thw"][idx]
|
||||
meta["thw"] = (t, h, w)
|
||||
meta["fps"] = outputs["fps"][idx]
|
||||
hashes_to_cache.append(outputs["mm_hashes"][idx])
|
||||
items_to_cache.append((outputs["images"][idx], meta))
|
||||
self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
|
||||
|
||||
return outputs
|
||||
|
||||
def _add_text(self, tokens, outputs: Dict) -> None:
|
||||
"""
|
||||
@@ -312,9 +311,9 @@ class DataProcessor:
outputs["input_ids"].extend(tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)

position_ids = self._compute_text_positions(outputs["cur_position"], num_tokens)
outputs["position_ids"].append(position_ids)
outputs["cur_position"] = position_ids.max() + 1
pos_ids = self._compute_text_positions(outputs["cur_position"], num_tokens)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1

def _compute_text_positions(self, start_pos: int, num_tokens: int) -> np.ndarray:
"""
@@ -332,7 +331,7 @@ class DataProcessor:
position = text_index + start_pos
return position

def _add_image(self, img, outputs: Dict) -> None:
def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
"""
Add image data to model inputs dictionary.

@@ -349,20 +348,47 @@ class DataProcessor:
num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
grid_thw = ret["grid_thw"].tolist()

outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)

outputs["images"].append(ret["pixel_values"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(grid_thw)
outputs["image_type_ids"].append(0)

t, h, w = grid_thw
position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)

outputs["position_ids"].append(position_ids)
outputs["cur_position"] = position_ids.max() + 1
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1

def _add_video(self, frames, meta: Dict, outputs: Dict) -> None:
outputs["fps"].append(0)

def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
img, meta = img_cache
num_tokens = img.shape[0] // self.image_processor.merge_size**2

outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)

_, h, w = meta["thw"]
pos_ids = self._compute_vision_positions(outputs["cur_position"], 1, h, w, 0)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1

outputs["images"].append(img)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[1, h, w]]))
outputs["image_type_ids"].append(0)

outputs["fps"].append(0)

def _add_video(self, frames, meta: Dict, outputs: Dict, uuid: Optional[str]) -> None:
"""
Add video data to model inputs dictionary.
@@ -380,20 +406,49 @@ class DataProcessor:
num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
grid_thw = ret["grid_thw"].tolist()

outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.video_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)

outputs["images"].append(ret["pixel_values"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(grid_thw)
outputs["image_type_ids"].extend([1] * grid_thw[0])

fps = meta["fps"]
second_per_grid_t = self.temporal_conv_size / fps
t, h, w = grid_thw
position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)

outputs["position_ids"].append(position_ids)
outputs["cur_position"] = position_ids.max() + 1
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1

outputs["fps"].append(fps)

def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
frames, meta = frames_cache
num_tokens = frames.shape[0] // self.image_processor.merge_size**2

t, h, w = meta["thw"]
outputs["images"].append(frames)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[t, h, w]]))

outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["image_type_ids"].extend([1] * t)

fps = meta["fps"]
second_per_grid_t = self.temporal_conv_size / fps
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1

outputs["fps"].append(fps)

def _compute_vision_positions(
self, start_pos: int, t: int, h: int, w: int, second_per_grid_t: float
@@ -441,20 +496,20 @@ class DataProcessor:
- frames: Processed video frames as numpy array
- metadata: Updated video metadata dictionary
"""
frames, meta = read_frames(url)
reader, meta, _ = read_video_decord(url, save_to_disk=False)

# Apply frame sampling if fps or target_frames specified
fps = item.get("fps", None)
num_frames = item.get("target_frames", None)
fps = item.get("fps", self.fps)
num_frames = item.get("target_frames", self.target_frames)

if fps is not None or num_frames is not None:
frame_indices = list(range(meta["num_of_frame"]))
if fps > 0 or num_frames > 0:
# Get frame sampling constraints
min_frames = item.get("min_frames", self.min_frames)
max_frames = item.get("max_frames", self.max_frames)

# Sample frames according to specifications
frames = sample_frames(
video=frames,
frame_indices = sample_frames(
frame_factor=self.temporal_conv_size, # Ensure divisible by temporal patch size
min_frames=min_frames,
max_frames=max_frames,
@@ -464,42 +519,38 @@ class DataProcessor:
)

# Update metadata with new frame count and fps
meta["num_of_frame"] = frames.shape[0]
meta["num_of_frame"] = len(frame_indices)
if fps is not None:
meta["fps"] = fps # Use specified fps
meta["duration"] = frames.shape[0] / fps
meta["duration"] = len(frame_indices) / fps
else:
meta["fps"] = frames.shape[0] / meta["duration"] # Calculate fps from sampled frames
meta["fps"] = len(frame_indices) / meta["duration"] # Calculate fps from sampled frames

frames = []
for idx in frame_indices:
frame = reader[idx].asnumpy()
image = Image.fromarray(frame, "RGB")
frames.append(image)
frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)

return frames, meta

def apply_chat_template(self, request):
def get_processor_cache(self, socket, mm_hashes: list[str]) -> list:
"""
Apply chat template to convert messages into token sequence.

Args:
request: Dictionary containing chat messages

Returns:
List of token IDs

Raises:
ValueError: If model doesn't support chat templates
get cache correspond to given hash values
"""
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
req = pickle.dumps(mm_hashes)
socket.send_multipart([b"", req])
_, resp = socket.recv_multipart()
mm_items = pickle.loads(resp)
data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}")

raw_prompt = self.tokenizer.apply_chat_template(
request["messages"],
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
)
prompt_token_str = raw_prompt.replace(self.image_token, "").replace(self.video_token, "")
request["prompt_tokens"] = raw_prompt
return mm_items

tokens = self.tokenizer.tokenize(prompt_token_str)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
data_processor_logger.info(
f"req_id:{request.get('request_id', ''), } prompt: {raw_prompt} tokens: {tokens}, token_ids: {token_ids}"
)
return token_ids
def update_processor_cache(self, socket, mm_hashes: list[str], mm_items):
"""
update cache data
"""
req = pickle.dumps((mm_hashes, mm_items))
socket.send_multipart([b"", req])
data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}")
@@ -18,50 +18,9 @@ import math
from typing import Optional, Union

import numpy as np
from PIL import Image

from fastdeploy.input.ernie4_5_vl_processor import read_video_decord


def read_frames(video_path):
"""
Read and decode video frames from the given path

This function reads a video file and decodes it into individual RGB frames
using decord video reader. It also extracts video metadata including fps,
duration and frame count.

Args:
video_path (str): Path to the video file or bytes object containing video data

Returns:
tuple: A tuple containing:
frames (numpy.ndarray): Array of shape (num_frames, height, width, 3)
containing decoded RGB video frames
meta (dict): Dictionary containing video metadata:
- fps (float): Frames per second
- duration (float): Video duration in seconds
- num_of_frame (int): Total number of frames
- width (int): Frame width in pixels
- height (int): Frame height in pixels

Note:
- The function uses decord library for efficient video reading
- All frames are converted to RGB format regardless of input format
"""
reader, meta, _ = read_video_decord(video_path, save_to_disk=False)

frames = []
for i in range(meta["num_of_frame"]):
frame = reader[i].asnumpy()
image = Image.fromarray(frame, "RGB")
frames.append(image)
frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
return frames, meta


def sample_frames(
video: np.ndarray,
frame_factor: int,
min_frames: int,
max_frames: int,
@@ -73,7 +32,6 @@ def sample_frames(
Sample frames from video according to specified criteria.

Args:
video: Input video frames as numpy array
frame_factor: Ensure sampled frames are multiples of this factor
min_frames: Minimum number of frames to sample
max_frames: Maximum number of frames to sample
@@ -89,18 +47,15 @@ def sample_frames(
or if required metadata is missing,
or if requested frames exceed available frames
"""
if fps is not None and num_frames is not None:
if fps > 0 and num_frames > 0:
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")

if fps is None and num_frames is None:
return video

total_num_frames = video.shape[0]
total_num_frames = metadata["num_of_frame"]

# If num_frames is not given but fps is, calculate num_frames from fps
if num_frames is not None:
if num_frames > 0:
num_frames = round(num_frames / frame_factor) * frame_factor
elif fps is not None:
elif fps > 0:
if metadata is None:
raise ValueError(
"Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
@@ -110,7 +65,6 @@ def sample_frames(
num_frames = total_num_frames / metadata["fps"] * fps
num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames)
num_frames = math.floor(num_frames / frame_factor) * frame_factor

if num_frames > total_num_frames:
raise ValueError(
f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
@@ -118,14 +72,11 @@ def sample_frames(
)

# Calculate frame indices based on sampling strategy
if num_frames is not None:
if num_frames > 0:
# Evenly spaced sampling for target frame count
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(np.int32)
else:
# Keep all frames if no sampling requested
indices = np.arange(0, total_num_frames).astype(np.int32)

# Apply frame selection
video = video[indices]

return video
return indices
@@ -47,6 +47,7 @@ class QwenVLProcessor(TextProcessor):
mm_processor_kwargs=None,
reasoning_parser_obj=None,
tool_parser_obj=None,
enable_processor_cache=False,
):
"""
Initialize QwenVLProcessor instance.
@@ -65,6 +66,7 @@ class QwenVLProcessor(TextProcessor):
processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs)
self.processor = DataProcessor(
model_path=model_name_or_path,
enable_processor_cache=enable_processor_cache,
tokens_per_second=config.vision_config.tokens_per_second,
tokenizer=self.tokenizer,
**processor_kwargs,
@@ -271,7 +273,7 @@ class QwenVLProcessor(TextProcessor):

return request

def append_completion_tokens(self, outputs, completion_token_ids):
def append_completion_tokens(self, multimodal_inputs, completion_token_ids):
"""
Append completion tokens to existing outputs.

@@ -279,19 +281,14 @@ class QwenVLProcessor(TextProcessor):
outputs: Current model outputs
completion_token_ids: completion tokens to append
"""
out = {"input_ids": [], "token_type_ids": [], "position_ids": [], "cur_position": outputs["cur_position"]}
self.processor._add_text(completion_token_ids, out)

outputs["input_ids"] = np.concatenate(
[outputs["input_ids"], np.array(out["input_ids"], dtype=np.int64)], axis=0
)
outputs["token_type_ids"] = np.concatenate(
[outputs["token_type_ids"], np.array(out["token_type_ids"], dtype=np.int64)], axis=0
)
outputs["position_ids"] = np.concatenate(
[outputs["position_ids"], out["position_ids"][0]], axis=1, dtype=np.int64
)
outputs["cur_position"] = out["cur_position"]
num_tokens = len(completion_token_ids)
multimodal_inputs["input_ids"].extend(completion_token_ids)
multimodal_inputs["token_type_ids"].extend([0] * num_tokens)

pos_ids = self.processor._compute_text_positions(multimodal_inputs["cur_position"], num_tokens)
multimodal_inputs["position_ids"].append(pos_ids)
multimodal_inputs["cur_position"] += num_tokens

def pack_outputs(self, outputs):
"""
@@ -303,7 +300,24 @@ class QwenVLProcessor(TextProcessor):
Returns:
dict: Packed output dictionary with all required fields
"""
if not outputs["images"]:
outputs["images"] = None # No images case
outputs["grid_thw"] = None # No spatial dimensions
outputs["image_type_ids"] = None # No type IDs
else:
outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically
outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions
outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array

# Convert all outputs to numpy arrays with appropriate types
outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64
outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64
outputs["position_ids"] = np.concatenate(
outputs["position_ids"], axis=1, dtype=np.int64
) # Concatenate position ID

outputs["image_patch_id"] = self.processor.image_token_id
outputs["video_patch_id"] = self.processor.video_token_id
outputs["position_ids"] = outputs["position_ids"].transpose(1, 0)

return outputs
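For orientation, the processor-cache service that get_processor_cache()/update_processor_cache() talk to is not part of this diff. The sketch below is a hypothetical counterpart, reconstructed only from the client-side framing visible above (pickled payloads over a ZMQ DEALER/ROUTER pair on ipc:///dev/shm/processor_cache.ipc, with an empty delimiter frame); the real cache service in FastDeploy may be structured differently.

# Hypothetical cache-side sketch of the IPC protocol implied by the client code above.
import pickle
import zmq

cache = {}  # mm_hash -> (features, meta); an in-memory stand-in for the real store

context = zmq.Context()
router = context.socket(zmq.ROUTER)
router.bind("ipc:///dev/shm/processor_cache.ipc")

while True:
    # DEALER sends [b"", payload]; ROUTER sees [identity, b"", payload]
    identity, _, payload = router.recv_multipart()
    msg = pickle.loads(payload)
    if isinstance(msg, tuple):
        # update_processor_cache(): payload is (mm_hashes, mm_items); no reply is read by the client
        for mm_hash, item in zip(*msg):
            cache[mm_hash] = item
    else:
        # get_processor_cache(): payload is a list of hashes; reply with cached items (or None)
        items = [cache.get(mm_hash) for mm_hash in msg]
        router.send_multipart([identity, b"", pickle.dumps(items)])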