[Feature] Support prefix caching for multimodal (mm) inputs (#4134)

* support mm prefix caching

* update code

* fix mm_hashes

* support encoder cache

* add encoder cache

* update code

* update encoder cache

* fix features bug

* fix worker bug

* support processor cache (still needs optimization)

* refactor multimodal data cache

* update code

* update code

* update v1 scheduler

* update code

* update code

* update codestyle

* support turning off processor cache and encoder cache

* update pre-commit

* fix code

* address review comments

* update code

* update code

* update test case

* set processor cache in GiB

* update test case

* support mm prefix caching for qwen model

* fix code style check

* update pre-commit

* fix unit test

* fix unit test

* add ci test case

* fix rescheduled bug

* change text_after_process to prompt_tokens

* fix unit test

* fix chat template

* change model path

* [EP] fix adapter bugs (#4572)

* Update expert_service.py

* Update common_engine.py

* Update expert_service.py

* fix v1 hang bug (#4573)

* fix import image_ops error on some platforms (#4559)

* [CLI] Update parameters in bench latency CLI tool and fix collect-env CLI tool (#4558)

* add collect-env

* del files

* [Graph Optimization] Add dy_runnable and introduce cudagraph_switch_threshold for cudagraph mode switching (#4578)

* add new branch for sot

* reorder

* fix batch bug

* [XPU] MoE uses a new operator (#4585)

* [XPU] MoE uses a new operator

* [XPU] MoE uses a new operator

* update response

* [Feature] Support Paddle-OCR (#4396)

* init

* update code

* fix code style & disable thinking

* adapt for common_engine.update_mm_requests_chunk_size

* use 3d rope

* use flash_attn_unpadded

* opt siglip

* update to be compatible with the latest codebase

* fix typo

* optimize OCR performance

* fix bug

* fix bug

* fix bug

* fix bug

* normalize name

* modify xpu rope

* revert logger

* fix bug

* fix bug

* fix bug

* support default_v1

* optimize performance

* fix bug

---------

Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com>
Co-authored-by: zhangyue66 <zhangyue66@baidu.com>

* [DataProcessor] add reasoning_tokens into usage info (#4520)

* add reasoning_tokens into usage info initial commit

* add unit tests

* modify unit test

* modify and add unit tests

* fix unit test

* move stream usage to processor

* modify processor

* modify test_logprobs

* modify test_logprobs.py

* modify stream reasoning tokens accumulation

* fix unit test

* perf: Optimize task queue communication from engine to worker (#4531)

* perf: Optimize task queue communication from engine to worker

* perf: get_tasks to numpy

* perf: get_tasks remove to_numpy

* fix: request & replace ENV

* remove test_e2w_perf.py

* fix code style

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>

* Clean up ports after processing results (#4587)

* [CI] Add /re-run command in PR comments to restart failed CI workflows (#4593)

* [Others] API server exits when worker process is dead (#3271)

* [fix] fix terminal hangs when worker process is dead

* [chore] change sleep time of monitor

* [chore] remove redundant comments

* update docs

---------

Co-authored-by: ApplEOFDiscord <wwy640130@163.com>
Co-authored-by: ApplEOFDiscord <31272106+ApplEOFDiscord@users.noreply.github.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Co-authored-by: yinwei <yinwei_hust@163.com>
Co-authored-by: JYChen <zoooo0820@qq.com>
Co-authored-by: qwes5s5 <45442318+qwes5s5@users.noreply.github.com>
Co-authored-by: Ryan <zihaohuang@aliyun.com>
Co-authored-by: yyssys <atyangshuang@foxmail.com>
Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com>
Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com>
Co-authored-by: zhangyue66 <zhangyue66@baidu.com>
Co-authored-by: kxz2002 <115912648+kxz2002@users.noreply.github.com>
Co-authored-by: SunLei <sunlei5788@gmail.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Zhang Yulong <35552275+ZhangYulongg@users.noreply.github.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: 李泳桦 <39643373+liyonghua0910@users.noreply.github.com>
Author: kevin
Date: 2025-10-27 17:39:51 +08:00
Committed by: GitHub
Parent: a4fb3d4ff0
Commit: 8aab4e367f
40 changed files with 1741 additions and 545 deletions

View File

@@ -37,6 +37,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
mm_processor_kwargs=None,
reasoning_parser_obj=None,
tool_parser_obj=None,
enable_processor_cache=False,
):
data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
tokenizer_path = model_name_or_path
@@ -46,6 +47,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
self.ernie4_5_processor = DataProcessor(
tokenizer_name=tokenizer_path,
image_preprocessor_name=preprocessor_path,
enable_processor_cache=enable_processor_cache,
**processor_kwargs,
)
self.ernie4_5_processor.eval()

View File

@@ -18,16 +18,20 @@
""" process.py """
import copy
import os
import pickle
from collections import defaultdict
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import zmq
from paddleformers.transformers.image_utils import ChannelDimension
from PIL import Image
from fastdeploy.engine.request import ImagePosition
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.input.utils import IDS_TYPE_FLAG
from fastdeploy.multimodal.hasher import MultimodalHasher
from fastdeploy.utils import data_processor_logger
from .image_preprocessor.image_preprocessor_adaptive import AdaptiveImageProcessor
@@ -84,6 +88,7 @@ class DataProcessor:
self,
tokenizer_name: str,
image_preprocessor_name: str,
enable_processor_cache: bool = False,
spatial_conv_size: int = 2,
temporal_conv_size: int = 2,
image_min_pixels: int = 4 * 28 * 28,
@@ -102,6 +107,7 @@ class DataProcessor:
self._load_tokenizer()
self.tokenizer.ignored_index = -100
self.image_preprocessor = AdaptiveImageProcessor.from_pretrained(image_preprocessor_name)
self.enable_processor_cache = enable_processor_cache
# Convolution sizes for patch aggregation
self.spatial_conv_size = spatial_conv_size
@@ -163,10 +169,18 @@ class DataProcessor:
"""Enable evaluation mode (doesn't produce labels)."""
self.is_training = False
def text2ids(self, text, images=None, videos=None):
def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None):
"""
Convert chat text into model inputs.
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
Args:
text (str): The chat text containing placeholders for images and videos.
images (list, optional): List of images to be processed and inserted at image placeholders.
videos (list, optional): List of videos to be processed and inserted at video placeholders.
image_uuid (list, optional): List of unique identifiers for each image, used for caching or hashing.
video_uuid (list, optional): List of unique identifiers for each video, used for caching or hashing.
Returns:
dict: A dictionary with keys input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels, etc.
"""
outputs = {
@@ -178,8 +192,9 @@ class DataProcessor:
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
"mm_positions": [],
"mm_hashes": [],
}
IMAGE_PLACEHOLDER = "<|image@placeholder|>"
@@ -199,17 +214,27 @@ class DataProcessor:
break
if ed == image_pos:
self._add_image(images[image_idx], outputs)
image = images[image_idx]
uuid = image_uuid[image_idx] if image_uuid else None
if not isinstance(image, tuple):
self._add_image(image, outputs, uuid)
else:
# cached images are already processed
self._add_processed_image(image, outputs, uuid)
image_idx += 1
st = ed + IMAGE_PLACEHOLDER_LEN
else:
item = videos[video_idx]
if isinstance(item, dict):
frames = self._load_and_process_video(item["video"], item)
uuid = video_uuid[video_idx] if video_uuid else None
if not isinstance(item, tuple):
if isinstance(item, dict):
frames = self._load_and_process_video(item["video"], item)
else:
frames = self._load_and_process_video(item, {})
self._add_video(frames, outputs, uuid)
else:
frames = self._load_and_process_video(item, {})
self._add_video(frames, outputs)
# cached frames are already processed
self._add_processed_video(item, outputs, uuid)
video_idx += 1
st = ed + VIDEO_PLACEHOLDER_LEN
@@ -223,66 +248,82 @@ class DataProcessor:
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
"""
outputs = {
"input_ids": [],
"token_type_ids": [],
"position_ids": [],
"images": [],
"grid_thw": [],
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
}
messages = parse_chat_messages(request.get("messages"))
image_message_list = []
mm_items = []
for msg in messages:
role = msg.get("role")
assert role in self.role_prefixes, f"Unsupported role: {role}"
content_items = msg.get("content")
if not isinstance(content_items, list):
content_items = [content_items]
for item in content_items:
if isinstance(item, dict) and item.get("type") in [
"image",
"video",
]:
image_message_list.append(item)
content = msg.get("content")
if not isinstance(content, list):
content = [content]
for item in content:
if item.get("type") in ["image", "video"]:
mm_items.append(item)
missing_hashes, missing_idx = [], []
for idx, item in enumerate(mm_items):
if not item.get("data"):
# raw data not provided, should be retrieved from processor cache
missing_hashes.append(item.get("uuid"))
missing_idx.append(idx)
if len(missing_hashes) > 0 and not self.enable_processor_cache:
raise ValueError("Missing items cannot be retrieved without processor cache.")
if self.enable_processor_cache:
context = zmq.Context()
dealer = context.socket(zmq.DEALER)
dealer.connect("ipc:///dev/shm/processor_cache.ipc")
missing_items = self.get_processor_cache(dealer, missing_hashes)
for idx in range(len(missing_items)):
if not missing_items[idx]:
raise ValueError(f"Missing item {idx} not found in processor cache")
mm_items[missing_idx[idx]]["data"] = missing_items[idx]
images, videos = [], []
image_uuid, video_uuid = [], []
for item in mm_items:
if item.get("type") == "image":
images.append(item["data"])
image_uuid.append(item["uuid"])
elif item.get("type") == "video":
videos.append(item["data"])
video_uuid.append(item["uuid"])
else:
raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat template.")
chat_template_kwargs = request.get("chat_template_kwargs", {})
prompt_token_ids = self.apply_chat_template(request, **chat_template_kwargs)
if len(prompt_token_ids) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
image_start_index = 0
image_message_index = 0
for i in range(len(prompt_token_ids)):
if prompt_token_ids[i] in [
self.image_start_id,
self.video_start_id,
]:
self._add_text(prompt_token_ids[image_start_index : i + 1], outputs)
image_start_index = i + 1
image_message = image_message_list[image_message_index]
if image_message["type"] == "image":
img = image_message.get("image")
if img is None:
continue
outputs["pic_cnt"] += 1
self._add_image(img, outputs)
elif image_message["type"] == "video":
video_bytes = image_message.get("video")
if video_bytes is None:
continue
frames = self._load_and_process_video(video_bytes, image_message)
outputs["video_cnt"] += 1
self._add_video(frames, outputs)
image_message_index += 1
self._add_text(prompt_token_ids[image_start_index:], outputs)
prompt = self.tokenizer.apply_chat_template(
request,
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
**chat_template_kwargs,
)
request["prompt_tokens"] = prompt
outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid)
if self.enable_processor_cache:
missing_idx = set(missing_idx)
hashes_to_cache, items_to_cache = [], []
for idx in range(len(mm_items)):
if idx in missing_idx:
continue
meta = {}
t, h, w = outputs["grid_thw"][idx][0]
meta["thw"] = (t, h, w)
hashes_to_cache.append(outputs["mm_hashes"][idx])
items_to_cache.append((outputs["images"][idx], meta))
self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
if self.is_training:
assert tgts, "training must give tgt !"
assert tgts, "Training must give tgt"
self._extract_labels(outputs, tgts)
return outputs
def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
@@ -304,7 +345,7 @@ class DataProcessor:
outputs["position_ids"].append([start + i] * 3)
outputs["cur_position"] += len(tokens)
def _add_image(self, img, outputs: Dict) -> None:
def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
img.height,
img.width,
@@ -313,6 +354,7 @@ class DataProcessor:
)[1]
num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
@@ -330,10 +372,32 @@ class DataProcessor:
input_data_format=ChannelDimension.LAST,
)
outputs["images"].append(ret["pixel_values"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(ret["image_grid_thw"])
outputs["image_type_ids"].append(0)
def _add_video(self, frames, outputs: Dict) -> None:
def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
img, meta = img_cache
num_tokens = img.shape[0] // (self.spatial_conv_size**2)
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
_, h, w = meta["thw"]
pos_ids = self._compute_3d_positions(1, h, w, outputs["cur_position"])
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1
outputs["images"].append(img)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[1, h, w]]))
outputs["image_type_ids"].append(0)
def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
frames[0].height,
frames[0].width,
@@ -354,9 +418,14 @@ class DataProcessor:
input_data_format=ChannelDimension.LAST,
)
outputs["images"].append(ret["pixel_values_videos"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values_videos"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(ret["video_grid_thw"])
outputs["image_type_ids"].extend([1] * num_frames)
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
@@ -364,6 +433,24 @@ class DataProcessor:
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1
def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
frames, meta = frames_cache
num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size)
t, h, w = meta["thw"]
outputs["images"].append(frames)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[t, h, w]]))
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["image_type_ids"].extend([1] * t)
pos_ids = self._compute_3d_positions(t, h, w, outputs["cur_position"])
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1
def _extract_labels(self, outputs: Dict, tgts: List[str]) -> None:
input_ids = copy.deepcopy(outputs["input_ids"])
labels = [self.tokenizer.ignored_index] * len(input_ids)
@@ -480,33 +567,22 @@ class DataProcessor:
break
self.tokenizer = Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path)
def apply_chat_template(self, request, **kwargs):
def get_processor_cache(self, socket, mm_hashes: list[str]) -> list:
"""
Convert multi-turn messages into ID sequences.
Args:
messages: Either a request dict containing 'messages' field,
or a list of message dicts directly
Returns:
List of token IDs as strings (converted from token objects)
Get cached items corresponding to the given hash values.
"""
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
req = pickle.dumps(mm_hashes)
socket.send_multipart([b"", req])
_, resp = socket.recv_multipart()
mm_items = pickle.loads(resp)
data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}")
prompt_token_template = self.tokenizer.apply_chat_template(
request,
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
**kwargs,
)
prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace(
"<|video@placeholder|>", ""
)
request["prompt_tokens"] = prompt_token_template
tokens = self.tokenizer.tokenize(prompt_token_str)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
data_processor_logger.info(
f"req_id:{request.get('request_id', ''), } tokens: {tokens}, token_ids: {token_ids}"
)
return token_ids
return mm_items
def update_processor_cache(self, socket, mm_hashes: list[str], mm_items):
"""
update cache data
"""
req = pickle.dumps((mm_hashes, mm_items))
socket.send_multipart([b"", req])
data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}")

View File

@@ -46,6 +46,7 @@ class InputPreprocessor:
limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
tool_parser: str = None,
enable_processor_cache: bool = False,
) -> None:
self.model_config = model_config
self.model_name_or_path = self.model_config.model
@@ -53,6 +54,7 @@ class InputPreprocessor:
self.limit_mm_per_prompt = limit_mm_per_prompt
self.mm_processor_kwargs = mm_processor_kwargs
self.tool_parser = tool_parser
self.enable_processor_cache = enable_processor_cache
def create_processor(self):
reasoning_parser_obj = None
@@ -104,6 +106,19 @@ class InputPreprocessor:
mm_processor_kwargs=self.mm_processor_kwargs,
reasoning_parser_obj=reasoning_parser_obj,
tool_parser_obj=tool_parser_obj,
enable_processor_cache=self.enable_processor_cache,
)
elif "PaddleOCRVL" in architecture:
from fastdeploy.input.paddleocr_vl_processor import (
PaddleOCRVLProcessor,
)
self.processor = PaddleOCRVLProcessor(
config=self.model_config,
model_name_or_path=self.model_name_or_path,
limit_mm_per_prompt=self.limit_mm_per_prompt,
mm_processor_kwargs=self.mm_processor_kwargs,
reasoning_parser_obj=reasoning_parser_obj,
)
elif "PaddleOCRVL" in architecture:
from fastdeploy.input.paddleocr_vl_processor import (
@@ -126,5 +141,6 @@ class InputPreprocessor:
limit_mm_per_prompt=self.limit_mm_per_prompt,
mm_processor_kwargs=self.mm_processor_kwargs,
reasoning_parser_obj=reasoning_parser_obj,
enable_processor_cache=self.enable_processor_cache,
)
return self.processor
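
InputPreprocessor simply stores the new enable_processor_cache flag and forwards it to whichever multimodal processor create_processor instantiates. A hedged usage sketch; the import path and the assumption that all other constructor arguments have defaults are mine, not part of this diff, and model_config is whatever config object the engine already builds.

from fastdeploy.input.preprocess import InputPreprocessor  # import path assumed

preprocessor = InputPreprocessor(
    model_config,                  # the engine's existing model config instance
    enable_processor_cache=True,   # forwarded to the Ernie / PaddleOCR / Qwen processors above
)
mm_processor = preprocessor.create_processor()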

View File

@@ -15,17 +15,23 @@
# limitations under the License.
"""
from typing import Any, Dict, List, Tuple, Union
import pickle
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import zmq
from paddleformers.transformers import AutoTokenizer
from PIL import Image
from fastdeploy.engine.request import ImagePosition
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.ernie4_5_vl_processor import read_video_decord
from fastdeploy.input.utils import IDS_TYPE_FLAG
from fastdeploy.multimodal.hasher import MultimodalHasher
from fastdeploy.utils import data_processor_logger
from .image_processor import ImageProcessor
from .process_video import read_frames, sample_frames
from .process_video import sample_frames
class DataProcessor:
@@ -49,8 +55,11 @@ class DataProcessor:
def __init__(
self,
model_path: str,
enable_processor_cache: bool = False,
video_min_frames: int = 4,
video_max_frames: int = 768,
video_target_frames: int = -1,
video_fps: int = -1,
tokens_per_second: int = 2,
tokenizer=None,
**kwargs,
@@ -67,6 +76,8 @@ class DataProcessor:
"""
self.min_frames = video_min_frames
self.max_frames = video_max_frames
self.target_frames = video_target_frames
self.fps = video_fps
# Initialize tokenizer with left padding and fast tokenizer
if tokenizer is None:
@@ -75,6 +86,7 @@ class DataProcessor:
else:
self.tokenizer = tokenizer
self.image_processor = ImageProcessor.from_pretrained(model_path) # Initialize image processor
self.enable_processor_cache = enable_processor_cache
# Convolution sizes for patch aggregation
self.spatial_conv_size = self.image_processor.merge_size
@@ -99,41 +111,7 @@ class DataProcessor:
"assistant": "Assistant: ",
}
def _pack_outputs(self, outputs):
"""
Pack and convert all output data into numpy arrays with appropriate types.
Args:
outputs (dict): Dictionary containing model outputs with keys:
- images: List of visual features
- grid_thw: List of spatial dimensions
- image_type_ids: List of content type indicators
- input_ids: List of token IDs
- token_type_ids: List of type identifiers
- position_ids: List of position embeddings
Returns:
dict: Processed outputs with all values converted to numpy arrays
"""
# Process visual outputs - stack if exists or set to None if empty
if not outputs["images"]:
outputs["images"] = None # No images case
outputs["grid_thw"] = None # No spatial dimensions
outputs["image_type_ids"] = None # No type IDs
else:
outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically
outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions
outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array
# Convert all outputs to numpy arrays with appropriate types
outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64
outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64
outputs["position_ids"] = np.concatenate(
outputs["position_ids"], axis=1, dtype=np.int64
) # Concatenate position IDs
return outputs
def text2ids(self, text, images=None, videos=None):
def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None):
"""
Convert text with image/video placeholders into model inputs.
@@ -141,6 +119,8 @@ class DataProcessor:
text: Input text with <|image@placeholder|> and <|video@placeholder|> markers
images: List of PIL Images corresponding to image placeholders
videos: List of video data corresponding to video placeholders
image_uuid: List of unique identifiers for each image, used for caching or hashing.
video_uuid: List of unique identifiers for each video, used for caching or hashing.
Returns:
Dict containing:
@@ -161,8 +141,10 @@ class DataProcessor:
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
"fps": [],
"mm_positions": [],
"mm_hashes": [],
}
# Define placeholders and their lengths
@@ -186,23 +168,30 @@ class DataProcessor:
break
if ed == image_pos:
outputs["pic_cnt"] += 1
self._add_image(images[image_idx], outputs)
image = images[image_idx]
uuid = image_uuid[image_idx] if image_uuid else None
if not isinstance(image, tuple):
self._add_image(image, outputs, uuid)
else:
self._add_processed_image(image, outputs, uuid)
image_idx += 1
st = ed + IMAGE_PLACEHOLDER_LEN
else:
item = videos[video_idx]
if isinstance(item, dict):
frames, meta = self._load_and_process_video(item["video"], item)
uuid = video_uuid[video_idx] if video_uuid else None
if not isinstance(item, tuple):
if isinstance(item, dict):
frames, meta = self._load_and_process_video(item["video"], item)
else:
frames, meta = self._load_and_process_video(item, {})
self._add_video(frames, meta, outputs, uuid)
else:
frames, meta = self._load_and_process_video(item, {})
outputs["video_cnt"] += 1
self._add_video(frames, meta, outputs)
# cached frames are already processed
self._add_processed_video(item, outputs, uuid)
video_idx += 1
st = ed + VIDEO_PLACEHOLDER_LEN
return self._pack_outputs(outputs)
return outputs
def request2ids(
self, request: Dict[str, Any], tgts: List[str] = None
@@ -220,74 +209,84 @@ class DataProcessor:
Dict with same structure as text2ids() output
"""
outputs = {
"input_ids": [],
"token_type_ids": [],
"position_ids": [],
"images": [],
"grid_thw": [],
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
}
# Parse and validate chat messages
messages = parse_chat_messages(request.get("messages"))
image_message_list = [] # Store visual content messages
mm_items = []
for msg in messages:
role = msg.get("role")
assert role in self.role_prefixes, f"Unsupported role: {role}"
# Normalize content to list format
content_items = msg.get("content")
if not isinstance(content_items, list):
content_items = [content_items]
content = msg.get("content")
if not isinstance(content, list):
content = [content]
# Collect all visual content items
for item in content_items:
if isinstance(item, dict) and item.get("type") in ["image", "video"]:
image_message_list.append(item)
for item in content:
if item.get("type") in ["image", "video"]:
mm_items.append(item)
raw_messages = request["messages"]
request["messages"] = messages
missing_hashes, missing_idx = [], []
for idx, item in enumerate(mm_items):
if not item.get("data"):
# raw data not provided, should be retrieved from processor cache
missing_hashes.append(item.get("uuid"))
missing_idx.append(idx)
prompt_token_ids = self.apply_chat_template(request)
if len(prompt_token_ids) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
request["messages"] = raw_messages
if len(missing_hashes) > 0 and not self.enable_processor_cache:
raise ValueError("Missing items cannot be retrieved without processor cache.")
vision_start_index = 0
vision_message_index = 0
for i in range(len(prompt_token_ids)):
if prompt_token_ids[i] == self.vision_start_id:
self._add_text(prompt_token_ids[vision_start_index : i + 1], outputs)
if self.enable_processor_cache:
context = zmq.Context()
dealer = context.socket(zmq.DEALER)
dealer.connect("ipc:///dev/shm/processor_cache.ipc")
vision_start_index = i + 1
image_message = image_message_list[vision_message_index]
missing_items = self.get_processor_cache(dealer, missing_hashes)
for idx in range(len(missing_items)):
if not missing_items[idx]:
raise ValueError(f"Missing item {idx} not found in processor cache")
mm_items[missing_idx[idx]]["data"] = missing_items[idx]
if image_message["type"] == "image":
img = image_message.get("image")
if img is None:
continue
outputs["pic_cnt"] += 1
self._add_image(img, outputs)
images, videos = [], []
image_uuid, video_uuid = [], []
for item in mm_items:
if item.get("type") == "image":
images.append(item["data"])
image_uuid.append(item["uuid"])
elif item.get("type") == "video":
videos.append(item["data"])
video_uuid.append(item["uuid"])
else:
raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
elif image_message["type"] == "video":
video_bytes = image_message.get("video")
if video_bytes is None:
continue
frames, meta = self._load_and_process_video(video_bytes, image_message)
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat template.")
outputs["video_cnt"] += 1
self._add_video(frames, meta, outputs)
chat_template_kwargs = request.get("chat_template_kwargs", {})
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
**chat_template_kwargs,
)
request["prompt_tokens"] = prompt
vision_message_index += 1
outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid)
self._add_text(prompt_token_ids[vision_start_index:], outputs)
return self._pack_outputs(outputs)
if self.enable_processor_cache:
missing_idx = set(missing_idx)
hashes_to_cache, items_to_cache = [], []
for idx in range(len(mm_items)):
if idx in missing_idx:
continue
meta = {}
t, h, w = outputs["grid_thw"][idx]
meta["thw"] = (t, h, w)
meta["fps"] = outputs["fps"][idx]
hashes_to_cache.append(outputs["mm_hashes"][idx])
items_to_cache.append((outputs["images"][idx], meta))
self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
return outputs
def _add_text(self, tokens, outputs: Dict) -> None:
"""
@@ -312,9 +311,9 @@ class DataProcessor:
outputs["input_ids"].extend(tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)
position_ids = self._compute_text_positions(outputs["cur_position"], num_tokens)
outputs["position_ids"].append(position_ids)
outputs["cur_position"] = position_ids.max() + 1
pos_ids = self._compute_text_positions(outputs["cur_position"], num_tokens)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
def _compute_text_positions(self, start_pos: int, num_tokens: int) -> np.ndarray:
"""
@@ -332,7 +331,7 @@ class DataProcessor:
position = text_index + start_pos
return position
def _add_image(self, img, outputs: Dict) -> None:
def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
"""
Add image data to model inputs dictionary.
@@ -349,20 +348,47 @@ class DataProcessor:
num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
grid_thw = ret["grid_thw"].tolist()
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
outputs["images"].append(ret["pixel_values"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(grid_thw)
outputs["image_type_ids"].append(0)
t, h, w = grid_thw
position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)
outputs["position_ids"].append(position_ids)
outputs["cur_position"] = position_ids.max() + 1
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
def _add_video(self, frames, meta: Dict, outputs: Dict) -> None:
outputs["fps"].append(0)
def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
img, meta = img_cache
num_tokens = img.shape[0] // self.image_processor.merge_size**2
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
_, h, w = meta["thw"]
pos_ids = self._compute_vision_positions(outputs["cur_position"], 1, h, w, 0)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
outputs["images"].append(img)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[1, h, w]]))
outputs["image_type_ids"].append(0)
outputs["fps"].append(0)
def _add_video(self, frames, meta: Dict, outputs: Dict, uuid: Optional[str]) -> None:
"""
Add video data to model inputs dictionary.
@@ -380,20 +406,49 @@ class DataProcessor:
num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
grid_thw = ret["grid_thw"].tolist()
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.video_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["images"].append(ret["pixel_values"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(grid_thw)
outputs["image_type_ids"].extend([1] * grid_thw[0])
fps = meta["fps"]
second_per_grid_t = self.temporal_conv_size / fps
t, h, w = grid_thw
position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
outputs["position_ids"].append(position_ids)
outputs["cur_position"] = position_ids.max() + 1
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
outputs["fps"].append(fps)
def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
frames, meta = frames_cache
num_tokens = frames.shape[0] // self.image_processor.merge_size**2
t, h, w = meta["thw"]
outputs["images"].append(frames)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[t, h, w]]))
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["image_type_ids"].extend([1] * t)
fps = meta["fps"]
second_per_grid_t = self.temporal_conv_size / fps
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
outputs["fps"].append(fps)
def _compute_vision_positions(
self, start_pos: int, t: int, h: int, w: int, second_per_grid_t: float
@@ -441,20 +496,20 @@ class DataProcessor:
- frames: Processed video frames as numpy array
- metadata: Updated video metadata dictionary
"""
frames, meta = read_frames(url)
reader, meta, _ = read_video_decord(url, save_to_disk=False)
# Apply frame sampling if fps or target_frames specified
fps = item.get("fps", None)
num_frames = item.get("target_frames", None)
fps = item.get("fps", self.fps)
num_frames = item.get("target_frames", self.target_frames)
if fps is not None or num_frames is not None:
frame_indices = list(range(meta["num_of_frame"]))
if fps > 0 or num_frames > 0:
# Get frame sampling constraints
min_frames = item.get("min_frames", self.min_frames)
max_frames = item.get("max_frames", self.max_frames)
# Sample frames according to specifications
frames = sample_frames(
video=frames,
frame_indices = sample_frames(
frame_factor=self.temporal_conv_size, # Ensure divisible by temporal patch size
min_frames=min_frames,
max_frames=max_frames,
@@ -464,42 +519,38 @@ class DataProcessor:
)
# Update metadata with new frame count and fps
meta["num_of_frame"] = frames.shape[0]
meta["num_of_frame"] = len(frame_indices)
if fps is not None:
meta["fps"] = fps # Use specified fps
meta["duration"] = frames.shape[0] / fps
meta["duration"] = len(frame_indices) / fps
else:
meta["fps"] = frames.shape[0] / meta["duration"] # Calculate fps from sampled frames
meta["fps"] = len(frame_indices) / meta["duration"] # Calculate fps from sampled frames
frames = []
for idx in frame_indices:
frame = reader[idx].asnumpy()
image = Image.fromarray(frame, "RGB")
frames.append(image)
frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
return frames, meta
def apply_chat_template(self, request):
def get_processor_cache(self, socket, mm_hashes: list[str]) -> list:
"""
Apply chat template to convert messages into token sequence.
Args:
request: Dictionary containing chat messages
Returns:
List of token IDs
Raises:
ValueError: If model doesn't support chat templates
Get cached items corresponding to the given hash values.
"""
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
req = pickle.dumps(mm_hashes)
socket.send_multipart([b"", req])
_, resp = socket.recv_multipart()
mm_items = pickle.loads(resp)
data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}")
raw_prompt = self.tokenizer.apply_chat_template(
request["messages"],
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
)
prompt_token_str = raw_prompt.replace(self.image_token, "").replace(self.video_token, "")
request["prompt_tokens"] = raw_prompt
return mm_items
tokens = self.tokenizer.tokenize(prompt_token_str)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
data_processor_logger.info(
f"req_id:{request.get('request_id', ''), } prompt: {raw_prompt} tokens: {tokens}, token_ids: {token_ids}"
)
return token_ids
def update_processor_cache(self, socket, mm_hashes: list[str], mm_items):
"""
update cache data
"""
req = pickle.dumps((mm_hashes, mm_items))
socket.send_multipart([b"", req])
data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}")
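
Throughout both processors, each image or video feature tensor is keyed either by a client-supplied uuid or by MultimodalHasher.hash_features(...) over the preprocessed pixel array. The real hasher lives in fastdeploy.multimodal.hasher and is not shown in this diff; a hedged sketch of what such a content hash could look like:

import hashlib
import numpy as np

def hash_features(features: np.ndarray) -> str:
    # Hash dtype and shape along with the raw bytes so arrays with identical bytes but
    # different layouts do not collide; the digest serves as the mm_hash cache key.
    digest = hashlib.sha256()
    digest.update(str(features.dtype).encode())
    digest.update(str(features.shape).encode())
    digest.update(np.ascontiguousarray(features).tobytes())
    return digest.hexdigest()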

View File

@@ -18,50 +18,9 @@ import math
from typing import Optional, Union
import numpy as np
from PIL import Image
from fastdeploy.input.ernie4_5_vl_processor import read_video_decord
def read_frames(video_path):
"""
Read and decode video frames from the given path
This function reads a video file and decodes it into individual RGB frames
using decord video reader. It also extracts video metadata including fps,
duration and frame count.
Args:
video_path (str): Path to the video file or bytes object containing video data
Returns:
tuple: A tuple containing:
frames (numpy.ndarray): Array of shape (num_frames, height, width, 3)
containing decoded RGB video frames
meta (dict): Dictionary containing video metadata:
- fps (float): Frames per second
- duration (float): Video duration in seconds
- num_of_frame (int): Total number of frames
- width (int): Frame width in pixels
- height (int): Frame height in pixels
Note:
- The function uses decord library for efficient video reading
- All frames are converted to RGB format regardless of input format
"""
reader, meta, _ = read_video_decord(video_path, save_to_disk=False)
frames = []
for i in range(meta["num_of_frame"]):
frame = reader[i].asnumpy()
image = Image.fromarray(frame, "RGB")
frames.append(image)
frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
return frames, meta
def sample_frames(
video: np.ndarray,
frame_factor: int,
min_frames: int,
max_frames: int,
@@ -73,7 +32,6 @@ def sample_frames(
Sample frames from video according to specified criteria.
Args:
video: Input video frames as numpy array
frame_factor: Ensure sampled frames are multiples of this factor
min_frames: Minimum number of frames to sample
max_frames: Maximum number of frames to sample
@@ -89,18 +47,15 @@ def sample_frames(
or if required metadata is missing,
or if requested frames exceed available frames
"""
if fps is not None and num_frames is not None:
if fps > 0 and num_frames > 0:
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
if fps is None and num_frames is None:
return video
total_num_frames = video.shape[0]
total_num_frames = metadata["num_of_frame"]
# If num_frames is not given but fps is, calculate num_frames from fps
if num_frames is not None:
if num_frames > 0:
num_frames = round(num_frames / frame_factor) * frame_factor
elif fps is not None:
elif fps > 0:
if metadata is None:
raise ValueError(
"Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
@@ -110,7 +65,6 @@ def sample_frames(
num_frames = total_num_frames / metadata["fps"] * fps
num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames)
num_frames = math.floor(num_frames / frame_factor) * frame_factor
if num_frames > total_num_frames:
raise ValueError(
f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
@@ -118,14 +72,11 @@ def sample_frames(
)
# Calculate frame indices based on sampling strategy
if num_frames is not None:
if num_frames > 0:
# Evenly spaced sampling for target frame count
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(np.int32)
else:
# Keep all frames if no sampling requested
indices = np.arange(0, total_num_frames).astype(np.int32)
# Apply frame selection
video = video[indices]
return video
return indices
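
sample_frames now operates purely on metadata and returns evenly spaced frame indices instead of a sliced frame array, so the caller decodes only the selected frames (see _load_and_process_video above). A small self-contained rerun of the sampling arithmetic with made-up numbers:

import math
import numpy as np

total_num_frames = 300        # e.g. a 10 s clip decoded at 30 fps
metadata_fps = 30.0
fps = 2                       # caller asks for 2 sampled frames per second
frame_factor = 2              # temporal conv size: keep the count a multiple of 2
min_frames, max_frames = 4, 768

num_frames = total_num_frames / metadata_fps * fps                                 # 20.0
num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames)   # 20.0
num_frames = math.floor(num_frames / frame_factor) * frame_factor                  # 20
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(np.int32)
# -> 20 evenly spaced indices: [0, 15, 30, ..., 285]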

View File

@@ -47,6 +47,7 @@ class QwenVLProcessor(TextProcessor):
mm_processor_kwargs=None,
reasoning_parser_obj=None,
tool_parser_obj=None,
enable_processor_cache=False,
):
"""
Initialize QwenVLProcessor instance.
@@ -65,6 +66,7 @@ class QwenVLProcessor(TextProcessor):
processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs)
self.processor = DataProcessor(
model_path=model_name_or_path,
enable_processor_cache=enable_processor_cache,
tokens_per_second=config.vision_config.tokens_per_second,
tokenizer=self.tokenizer,
**processor_kwargs,
@@ -271,7 +273,7 @@ class QwenVLProcessor(TextProcessor):
return request
def append_completion_tokens(self, outputs, completion_token_ids):
def append_completion_tokens(self, multimodal_inputs, completion_token_ids):
"""
Append completion tokens to existing outputs.
@@ -279,19 +281,14 @@ class QwenVLProcessor(TextProcessor):
outputs: Current model outputs
completion_token_ids: completion tokens to append
"""
out = {"input_ids": [], "token_type_ids": [], "position_ids": [], "cur_position": outputs["cur_position"]}
self.processor._add_text(completion_token_ids, out)
outputs["input_ids"] = np.concatenate(
[outputs["input_ids"], np.array(out["input_ids"], dtype=np.int64)], axis=0
)
outputs["token_type_ids"] = np.concatenate(
[outputs["token_type_ids"], np.array(out["token_type_ids"], dtype=np.int64)], axis=0
)
outputs["position_ids"] = np.concatenate(
[outputs["position_ids"], out["position_ids"][0]], axis=1, dtype=np.int64
)
outputs["cur_position"] = out["cur_position"]
num_tokens = len(completion_token_ids)
multimodal_inputs["input_ids"].extend(completion_token_ids)
multimodal_inputs["token_type_ids"].extend([0] * num_tokens)
pos_ids = self.processor._compute_text_positions(multimodal_inputs["cur_position"], num_tokens)
multimodal_inputs["position_ids"].append(pos_ids)
multimodal_inputs["cur_position"] += num_tokens
def pack_outputs(self, outputs):
"""
@@ -303,7 +300,24 @@ class QwenVLProcessor(TextProcessor):
Returns:
dict: Packed output dictionary with all required fields
"""
if not outputs["images"]:
outputs["images"] = None # No images case
outputs["grid_thw"] = None # No spatial dimensions
outputs["image_type_ids"] = None # No type IDs
else:
outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically
outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions
outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array
# Convert all outputs to numpy arrays with appropriate types
outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64
outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64
outputs["position_ids"] = np.concatenate(
outputs["position_ids"], axis=1, dtype=np.int64
) # Concatenate position ID
outputs["image_patch_id"] = self.processor.image_token_id
outputs["video_patch_id"] = self.processor.video_token_id
outputs["position_ids"] = outputs["position_ids"].transpose(1, 0)
return outputs
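
With this change, append_completion_tokens extends the still-list-valued processor outputs in place, and pack_outputs is the single point where everything is converted to numpy; in particular, the per-segment position ids are concatenated along axis=1 and then transposed so each token gets one (t, h, w) row. A tiny self-contained illustration of that shape handling, using fake segments:

import numpy as np

# two fake position-id segments, each shaped (3, num_tokens_in_segment)
segments = [np.zeros((3, 4), dtype=np.int64), np.ones((3, 2), dtype=np.int64)]
position_ids = np.concatenate(segments, axis=1, dtype=np.int64)  # (3, 6)
position_ids = position_ids.transpose(1, 0)                      # (6, 3): one row per token
assert position_ids.shape == (6, 3)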