Mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
557 lines
21 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import pickle
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import zmq
|
|
from paddleformers.transformers import AutoTokenizer
|
|
from PIL import Image
|
|
|
|
from fastdeploy.engine.request import ImagePosition
|
|
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
|
|
from fastdeploy.input.ernie4_5_vl_processor import read_video_decord
|
|
from fastdeploy.input.utils import IDS_TYPE_FLAG
|
|
from fastdeploy.multimodal.hasher import MultimodalHasher
|
|
from fastdeploy.utils import data_processor_logger
|
|
|
|
from .image_processor import ImageProcessor
|
|
from .process_video import sample_frames
|
|
|
|
|
|
class DataProcessor:
    """
    Processes multimodal inputs (text, images, videos) into model-ready formats.

    Handles:
    - Tokenization of text with special tokens for visual content
    - Image and video preprocessing
    - Generation of 3D positional embeddings
    - Conversion of chat messages to model inputs

    Attributes:
        tokenizer: Text tokenizer instance
        image_processor: Image/video preprocessor
        image_token: Special token for image placeholders
        video_token: Special token for video placeholders
        vision_start: Token marking start of visual content
    """

    def __init__(
        self,
        model_path: str,
        enable_processor_cache: bool = False,
        video_min_frames: int = 4,
        video_max_frames: int = 768,
        video_target_frames: int = -1,
        video_fps: int = -1,
        tokens_per_second: int = 2,
        tokenizer=None,
        **kwargs,
    ) -> None:
        """
        Initialize the data processor.

        Args:
            model_path: Path to pretrained model
            enable_processor_cache: Whether to reuse preprocessed multimodal features via the processor cache
            video_min_frames: Minimum frames to sample from videos
            video_max_frames: Maximum frames to sample from videos
            video_target_frames: Fixed number of frames to sample (-1 to disable)
            video_fps: Frame sampling rate in frames per second (-1 to disable)
            tokens_per_second: Temporal resolution for positional embeddings
            tokenizer: Pre-built tokenizer instance; loaded from model_path when None
            **kwargs: Additional configuration
        """
        self.min_frames = video_min_frames
        self.max_frames = video_max_frames
        self.target_frames = video_target_frames
        self.fps = video_fps

        # Initialize tokenizer with left padding and fast tokenizer
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left", use_fast=True)
            self.tokenizer.ignored_index = -100  # Set ignored index for loss calculation
        else:
            self.tokenizer = tokenizer
        self.image_processor = ImageProcessor.from_pretrained(model_path)  # Initialize image processor
        self.enable_processor_cache = enable_processor_cache

        # Convolution sizes for patch aggregation
        self.spatial_conv_size = self.image_processor.merge_size
        self.temporal_conv_size = self.image_processor.temporal_patch_size

        # Special tokens and IDs
        self.image_token = "<|image_pad|>"
        self.video_token = "<|video_pad|>"

        self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
        self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token)

        self.vision_start = "<|vision_start|>"
        self.vision_start_id = self.tokenizer.convert_tokens_to_ids(self.vision_start)

        self.tokens_per_second = tokens_per_second

        self.role_prefixes = {
            "system": "",
            "user": "User: ",
            "bot": "Assistant: ",
            "assistant": "Assistant: ",
        }

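    # Illustrative construction sketch (not part of the module): the processor is normally
    # created by FastDeploy's input-processing pipeline, but a standalone instance only needs
    # a model path. The checkpoint directory below is hypothetical.
    #
    #   processor = DataProcessor(
    #       model_path="/path/to/vl_model",
    #       enable_processor_cache=False,
    #       video_fps=2,  # sample videos at 2 frames per second
    #   )
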
    def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None):
        """
        Convert text with image/video placeholders into model inputs.

        Args:
            text: Input text with <|image_pad|> and <|video_pad|> placeholders
            images: List of PIL Images corresponding to image placeholders
            videos: List of video data corresponding to video placeholders
            image_uuid: List of unique identifiers for each image, used for caching or hashing.
            video_uuid: List of unique identifiers for each video, used for caching or hashing.

        Returns:
            Dict containing:
            - input_ids: Token IDs
            - token_type_ids: Type identifiers (text/image/video)
            - position_ids: 3D positional embeddings
            - images: Preprocessed visual features
            - grid_thw: Spatial/temporal dimensions
            - image_type_ids: Visual content type (0=image, 1=video)
        """

        outputs = {
            "input_ids": [],
            "token_type_ids": [],
            "position_ids": [],
            "images": [],
            "grid_thw": [],
            "image_type_ids": [],
            "labels": [],
            "cur_position": 0,
            "video_cnt": 0,
            "fps": [],
            "mm_positions": [],
            "mm_hashes": [],
        }

        # Define placeholders and their lengths
        IMAGE_PLACEHOLDER = "<|image_pad|>"
        VIDEO_PLACEHOLDER = "<|video_pad|>"
        IMAGE_PLACEHOLDER_LEN = len(IMAGE_PLACEHOLDER)
        VIDEO_PLACEHOLDER_LEN = len(VIDEO_PLACEHOLDER)

        # Initialize tracking variables for text parsing
        st, image_idx, video_idx = 0, 0, 0  # Start position, image counter, video counter
        while st < len(text):
            # Find next image or video placeholder in text
            image_pos = text.find(IMAGE_PLACEHOLDER, st)
            image_pos = len(text) if image_pos == -1 else image_pos  # Set to end if not found
            video_pos = text.find(VIDEO_PLACEHOLDER, st)
            video_pos = len(text) if video_pos == -1 else video_pos  # Set to end if not found
            ed = min(image_pos, video_pos)  # End position is first placeholder found

            self._add_text(text[st:ed], outputs)
            if ed == len(text):
                break

            if ed == image_pos:
                image = images[image_idx]
                uuid = image_uuid[image_idx] if image_uuid else None
                if not isinstance(image, tuple):
                    self._add_image(image, outputs, uuid)
                else:
                    self._add_processed_image(image, outputs, uuid)
                image_idx += 1
                st = ed + IMAGE_PLACEHOLDER_LEN
            else:
                item = videos[video_idx]
                uuid = video_uuid[video_idx] if video_uuid else None
                if not isinstance(item, tuple):
                    if isinstance(item, dict):
                        frames, meta = self._load_and_process_video(item["video"], item)
                    else:
                        frames, meta = self._load_and_process_video(item, {})
                    self._add_video(frames, meta, outputs, uuid)
                else:
                    # cached frames are already processed
                    self._add_processed_video(item, outputs, uuid)
                video_idx += 1
                st = ed + VIDEO_PLACEHOLDER_LEN

        return outputs

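    # Minimal usage sketch for text2ids(), assuming a `processor` instance and a PIL image
    # `img` already exist. The exact token counts depend on image resolution and patch sizes,
    # so the shapes noted below are illustrative only.
    #
    #   prompt = "Describe the picture: <|image_pad|>"
    #   out = processor.text2ids(prompt, images=[img])
    #   # out["input_ids"]     -> text token IDs plus one image_token_id per merged patch group
    #   # out["grid_thw"]      -> [[1, h, w]] for a single image
    #   # out["mm_hashes"]     -> one content hash (or the supplied uuid) per visual item
    #   # out["position_ids"]  -> list of (3, n) arrays used for 3D positional embeddings
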
    def request2ids(
        self, request: Dict[str, Any], tgts: List[str] = None
    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
        """
        Convert chat request with multimodal messages into model inputs.

        Args:
            request: Dictionary containing:
                - messages: List of chat messages with text/image/video content
                - request_id: Unique identifier for logging
            tgts: Optional target sequences

        Returns:
            Dict with same structure as text2ids() output
        """

        # Parse and validate chat messages
        messages = parse_chat_messages(request.get("messages"))
        mm_items = []
        for msg in messages:
            role = msg.get("role")
            assert role in self.role_prefixes, f"Unsupported role: {role}"

            # Normalize content to list format
            content = msg.get("content")
            if not isinstance(content, list):
                content = [content]
            # Collect all visual content items
            for item in content:
                if item.get("type") in ["image", "video"]:
                    mm_items.append(item)

        missing_hashes, missing_idx = [], []
        for idx, item in enumerate(mm_items):
            if not item.get("data"):
                # raw data not provided, should be retrieved from processor cache
                missing_hashes.append(item.get("uuid"))
                missing_idx.append(idx)

        if len(missing_hashes) > 0 and not self.enable_processor_cache:
            raise ValueError("Missing items cannot be retrieved without processor cache.")

        if self.enable_processor_cache:
            context = zmq.Context()
            dealer = context.socket(zmq.DEALER)
            dealer.connect("ipc:///dev/shm/processor_cache.ipc")

            missing_items = self.get_processor_cache(dealer, missing_hashes)
            for idx in range(len(missing_items)):
                if not missing_items[idx]:
                    raise ValueError(f"Missing item {idx} not found in processor cache")
                mm_items[missing_idx[idx]]["data"] = missing_items[idx]

        images, videos = [], []
        image_uuid, video_uuid = [], []
        for item in mm_items:
            if item.get("type") == "image":
                images.append(item["data"])
                image_uuid.append(item["uuid"])
            elif item.get("type") == "video":
                videos.append(item["data"])
                video_uuid.append(item["uuid"])
            else:
                raise ValueError(f"Unsupported multimodal type: {item.get('type')}")

        if self.tokenizer.chat_template is None:
            raise ValueError("This model does not support chat template.")

        chat_template_kwargs = request.get("chat_template_kwargs", {})
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=request.get("add_generation_prompt", True),
            **chat_template_kwargs,
        )
        request["prompt_tokens"] = prompt

        outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid)

        if self.enable_processor_cache:
            missing_idx = set(missing_idx)
            hashes_to_cache, items_to_cache = [], []
            for idx in range(len(mm_items)):
                if idx in missing_idx:
                    continue
                meta = {}
                t, h, w = outputs["grid_thw"][idx]
                meta["thw"] = (t, h, w)
                meta["fps"] = outputs["fps"][idx]
                hashes_to_cache.append(outputs["mm_hashes"][idx])
                items_to_cache.append((outputs["images"][idx], meta))
            self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)

        return outputs

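    # Hedged sketch of the request layout request2ids() consumes; the message schema is
    # normalized by parse_chat_messages(), so any field beyond "messages", "role", "content",
    # "type", "data" and "uuid" shown here is illustrative.
    #
    #   request = {
    #       "request_id": "req-0",
    #       "messages": [
    #           {"role": "user", "content": [
    #               {"type": "image", "data": pil_image, "uuid": "content-hash-or-uuid"},
    #               {"type": "text", "text": "What is in this image?"},
    #           ]},
    #       ],
    #   }
    #   outputs = processor.request2ids(request)
    #
    # With enable_processor_cache on, an item may arrive with only a "uuid" and no "data";
    # its preprocessed features are then fetched from the processor cache over ZeroMQ.
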
    def _add_text(self, tokens, outputs: Dict) -> None:
        """
        Add text tokens to model inputs dictionary.

        Args:
            tokens: Text string or already tokenized IDs
            outputs: Dictionary accumulating model inputs

        Note:
            - Handles both raw text and pre-tokenized inputs
            - Updates position IDs for 3D embeddings
        """
        if not tokens:
            return None

        if isinstance(tokens, str):
            tokens_str = self.tokenizer.tokenize(tokens)
            tokens = self.tokenizer.convert_tokens_to_ids(tokens_str)

        num_tokens = len(tokens)
        outputs["input_ids"].extend(tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)

        pos_ids = self._compute_text_positions(outputs["cur_position"], num_tokens)
        outputs["position_ids"].append(pos_ids)
        outputs["cur_position"] = pos_ids.max() + 1

    def _compute_text_positions(self, start_pos: int, num_tokens: int) -> np.ndarray:
        """
        Generate 3D positional embeddings for text tokens.

        Args:
            start_pos: Starting position index
            num_tokens: Number of tokens to generate positions for

        Returns:
            numpy.ndarray: 3D position IDs shaped (3, num_tokens)
        """
        text_array = np.arange(num_tokens).reshape(1, -1)
        text_index = np.broadcast_to(text_array, (3, num_tokens))
        position = text_index + start_pos
        return position

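    # Worked example: text tokens share one position index that is replicated across the
    # three (t, h, w) axes, so for start_pos=5 and num_tokens=4:
    #
    #   _compute_text_positions(5, 4)
    #   # array([[5, 6, 7, 8],
    #   #        [5, 6, 7, 8],
    #   #        [5, 6, 7, 8]])
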
    def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
        """
        Add image data to model inputs dictionary.

        Args:
            img: PIL Image to process
            outputs: Dictionary accumulating model inputs

        Note:
            - Preprocesses image and calculates spatial dimensions
            - Adds image token IDs and type markers
            - Generates appropriate position embeddings
        """
        ret = self.image_processor.preprocess(images=[img.convert("RGB")])
        num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
        grid_thw = ret["grid_thw"].tolist()

        outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
        outputs["input_ids"].extend([self.image_token_id] * num_tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)

        outputs["images"].append(ret["pixel_values"])
        if not uuid:
            outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
        else:
            outputs["mm_hashes"].append(uuid)
        outputs["grid_thw"].append(grid_thw)
        outputs["image_type_ids"].append(0)

        t, h, w = grid_thw
        pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)

        outputs["position_ids"].append(pos_ids)
        outputs["cur_position"] = pos_ids.max() + 1

        outputs["fps"].append(0)

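    # Worked example of the token count above (illustrative numbers): with merge_size=2 and
    # an image whose patch grid is grid_thw = [1, 28, 42], the prompt receives
    #   1 * 28 * 42 // 2**2 = 294
    # image placeholder tokens, i.e. one token per 2x2 block of patches after spatial merging.
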
    def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
        """Add an already-preprocessed image (retrieved from the processor cache) to model inputs."""
        img, meta = img_cache
        num_tokens = img.shape[0] // self.image_processor.merge_size**2

        outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
        outputs["input_ids"].extend([self.image_token_id] * num_tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)

        _, h, w = meta["thw"]
        pos_ids = self._compute_vision_positions(outputs["cur_position"], 1, h, w, 0)
        outputs["position_ids"].append(pos_ids)
        outputs["cur_position"] = pos_ids.max() + 1

        outputs["images"].append(img)
        outputs["mm_hashes"].append(uuid)
        outputs["grid_thw"].append(np.array([[1, h, w]]))
        outputs["image_type_ids"].append(0)

        outputs["fps"].append(0)

    def _add_video(self, frames, meta: Dict, outputs: Dict, uuid: Optional[str]) -> None:
        """
        Add video data to model inputs dictionary.

        Args:
            frames: Video frames as numpy array
            meta: Video metadata containing fps/duration
            outputs: Dictionary accumulating model inputs

        Note:
            - Handles temporal dimension in position embeddings
            - Uses video-specific token IDs and type markers
        """
        ret = self.image_processor.preprocess(images=frames)

        num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
        grid_thw = ret["grid_thw"].tolist()

        outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
        outputs["input_ids"].extend([self.video_token_id] * num_tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)

        outputs["images"].append(ret["pixel_values"])
        if not uuid:
            outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
        else:
            outputs["mm_hashes"].append(uuid)
        outputs["grid_thw"].append(grid_thw)
        outputs["image_type_ids"].extend([1] * grid_thw[0])

        fps = meta["fps"]
        second_per_grid_t = self.temporal_conv_size / fps
        t, h, w = grid_thw
        pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)

        outputs["position_ids"].append(pos_ids)
        outputs["cur_position"] = pos_ids.max() + 1

        outputs["fps"].append(fps)

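    # Worked example of the temporal scaling above: with temporal_patch_size=2 and a video
    # sampled at fps=2, each temporal grid step covers
    #   second_per_grid_t = 2 / 2 = 1.0 second,
    # and with tokens_per_second=2 the temporal position index then advances by
    #   int(second_per_grid_t) * tokens_per_second = 2
    # between consecutive temporal patches (see _compute_vision_positions below).
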
    def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
        """Add an already-preprocessed video (retrieved from the processor cache) to model inputs."""
        frames, meta = frames_cache
        num_tokens = frames.shape[0] // self.image_processor.merge_size**2

        t, h, w = meta["thw"]
        outputs["images"].append(frames)
        outputs["mm_hashes"].append(uuid)
        outputs["grid_thw"].append(np.array([[t, h, w]]))

        outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
        outputs["input_ids"].extend([self.video_token_id] * num_tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
        outputs["image_type_ids"].extend([1] * t)

        fps = meta["fps"]
        second_per_grid_t = self.temporal_conv_size / fps
        pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
        outputs["position_ids"].append(pos_ids)
        outputs["cur_position"] = pos_ids.max() + 1

        outputs["fps"].append(fps)

    def _compute_vision_positions(
        self, start_pos: int, t: int, h: int, w: int, second_per_grid_t: float
    ) -> np.ndarray:
        """
        Generate 3D position IDs for visual inputs.

        Args:
            start_pos: Base position in sequence
            t: Temporal patches (1 for images)
            h: Height in patches
            w: Width in patches
            second_per_grid_t: Time per temporal patch

        Returns:
            np.ndarray: Position IDs for [t,h,w] dimensions
        """
        h //= self.spatial_conv_size
        w //= self.spatial_conv_size

        tn = np.arange(t).reshape(-1, 1)
        tn = np.broadcast_to(tn, (t, h * w))
        tn = tn * int(second_per_grid_t) * self.tokens_per_second
        t_index = tn.flatten()

        hn = np.arange(h).reshape(1, -1, 1)
        h_index = np.broadcast_to(hn, (t, h, w)).flatten()

        wn = np.arange(w).reshape(1, 1, -1)
        w_index = np.broadcast_to(wn, (t, h, w)).flatten()

        position = np.stack([t_index, h_index, w_index]) + start_pos
        return position

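    # Worked example: t=1, h=4, w=4 patches with spatial_conv_size=2 collapse to a 2 x 2
    # grid, so for start_pos=0 and second_per_grid_t=0:
    #
    #   _compute_vision_positions(0, 1, 4, 4, 0)
    #   # array([[0, 0, 0, 0],   # temporal index (single frame)
    #   #        [0, 0, 1, 1],   # row index
    #   #        [0, 1, 0, 1]])  # column index
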
    def _load_and_process_video(self, url: str, item: Dict) -> Tuple[np.ndarray, Dict]:
        """
        Load and preprocess video into frames.

        Args:
            url: Video file path or bytes
            item: Dictionary containing processing parameters

        Returns:
            tuple: (frames, metadata) where:
                - frames: Processed video frames as numpy array
                - metadata: Updated video metadata dictionary
        """
        reader, meta, _ = read_video_decord(url, save_to_disk=False)

        # Apply frame sampling if fps or target_frames specified
        fps = item.get("fps", self.fps)
        num_frames = item.get("target_frames", self.target_frames)

        frame_indices = list(range(meta["num_of_frame"]))
        if fps > 0 or num_frames > 0:
            # Get frame sampling constraints
            min_frames = item.get("min_frames", self.min_frames)
            max_frames = item.get("max_frames", self.max_frames)

            # Sample frames according to specifications
            frame_indices = sample_frames(
                frame_factor=self.temporal_conv_size,  # Ensure divisible by temporal patch size
                min_frames=min_frames,
                max_frames=max_frames,
                metadata=meta,
                fps=fps,
                num_frames=num_frames,
            )

            # Update metadata with new frame count and fps
            meta["num_of_frame"] = len(frame_indices)
            if fps > 0:
                meta["fps"] = fps  # Use the requested sampling fps
                meta["duration"] = len(frame_indices) / fps
            else:
                meta["fps"] = len(frame_indices) / meta["duration"]  # Derive fps from sampled frames

        frames = []
        for idx in frame_indices:
            frame = reader[idx].asnumpy()
            image = Image.fromarray(frame, "RGB")
            frames.append(image)
        frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)

        return frames, meta

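    # Illustrative per-item sampling overrides: a video item may carry its own sampling
    # parameters, which take precedence over the processor-wide defaults set in __init__.
    # The keys mirror the item.get(...) lookups above; the file path is hypothetical.
    #
    #   video_item = {
    #       "video": "/path/to/clip.mp4",
    #       "fps": 2,          # resample to 2 frames per second
    #       "min_frames": 4,
    #       "max_frames": 768,
    #   }
    #   frames, meta = processor._load_and_process_video(video_item["video"], video_item)
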
    def get_processor_cache(self, socket, mm_hashes: list[str]) -> list:
        """
        Fetch cached items corresponding to the given hash values.
        """
        req = pickle.dumps(mm_hashes)
        socket.send_multipart([b"", req])
        _, resp = socket.recv_multipart()
        mm_items = pickle.loads(resp)
        data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}")

        return mm_items

    def update_processor_cache(self, socket, mm_hashes: list[str], mm_items):
        """
        Update the processor cache with the given hash/item pairs.
        """
        req = pickle.dumps((mm_hashes, mm_items))
        socket.send_multipart([b"", req])
        data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}")

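# Hedged sketch of the peer that get_processor_cache/update_processor_cache talk to: both
# methods speak pickle-over-ZeroMQ to a cache service bound at ipc:///dev/shm/processor_cache.ipc.
# The real service lives elsewhere in FastDeploy; the loop below only illustrates the framing
# this client expects (an empty delimiter frame followed by a pickled payload) and is not the
# actual implementation.
#
#   import pickle
#   import zmq
#
#   ctx = zmq.Context()
#   router = ctx.socket(zmq.ROUTER)
#   router.bind("ipc:///dev/shm/processor_cache.ipc")
#   cache = {}  # mm_hash -> (features, meta)
#
#   while True:
#       identity, _, payload = router.recv_multipart()
#       msg = pickle.loads(payload)
#       if isinstance(msg, tuple):  # update: (mm_hashes, mm_items)
#           cache.update(zip(*msg))
#       else:  # lookup: list of mm_hashes; missing entries come back as None
#           items = [cache.get(h) for h in msg]
#           router.send_multipart([identity, b"", pickle.dumps(items)])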