[Feature] mm support prefix cache (#4134)

* support mm prefix caching

* update code

* fix mm_hashes

* support encoder cache

* add encoder cache

* update code

* update encoder cache

* fix features bug

* fix worker bug

* support processor cache, still needs optimization

* refactor multimodal data cache

* update code

* update code

* update v1 scheduler

* update code

* update code

* update codestyle

* support turn off processor cache and encoder cache

* update pre-commit

* fix code

* resolve review comments

* update code

* update code

* update test case

* set processor cache in GiB

* update test case

* support mm prefix caching for qwen model

* fix code style check

* update pre-commit

* fix unit test

* fix unit test

* add ci test case

* fix rescheduled bug

* change text_after_process to prompt_tokens

* fix unit test

* fix chat template

* change model path

* [EP] fix adapter bugs (#4572)

* Update expert_service.py

* Update common_engine.py

* Update expert_service.py

* fix v1 hang bug (#4573)

* fix import image_ops error on some platforms (#4559)

* [CLI] Update parameters in bench latency CLI tool and fix collect-env CLI tool (#4558)

* add collect-env

* del files

* [Graph Optimization] Add dy_runnable and introduce cudagraph_switch_threshold for cudagraph mode switching (#4578)

* add new branch for sot

* reorder

* fix batch bug

* [XPU] MoE uses a new operator (#4585)

* [XPU] MoE uses a new operator

* [XPU] MoE uses a new operator

* update response

* [Feature] Support Paddle-OCR (#4396)

* init

* update code

* fix code style & disable thinking

* adapt for common_engine.update_mm_requests_chunk_size

* use 3d rope

* use flash_attn_unpadded

* opt siglip

* update to be compatible with the latest codebase

* fix typo

* optimize OCR performance

* fix bug

* fix bug

* fix bug

* fix bug

* normalize name

* modify xpu rope

* revert logger

* fix bug

* fix bug

* fix bug

* support default_v1

* optimize performance

* fix bug

---------

Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com>
Co-authored-by: zhangyue66 <zhangyue66@baidu.com>

* [DataProcessor] add reasoning_tokens into usage info (#4520)

* add reasoning_tokens into usage info initial commit

* add unit tests

* modify unit test

* modify and add unit tests

* fix unit test

* move stream usage to processor

* modify processor

* modify test_logprobs

* modify test_logprobs.py

* modify stream reasoning tokens accumulation

* fix unit test

* perf: Optimize task queue communication from engine to worker (#4531)

* perf: Optimize task queue communication from engine to worker

* perf: get_tasks to numpy

* perf: get_tasks remove to_numpy

* fix: request & replace ENV

* remove test_e2w_perf.py

* fix code style

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>

* Clean up ports after processing results (#4587)

* [CI] Add /re-run command in PR comments to restart failed CI workflows (#4593)

* [Others] api server exits when worker process is dead (#3271)

* [fix] fix terminal hangs when worker process is dead

* [chore] change sleep time of monitor

* [chore] remove redundant comments

* update docs

---------

Co-authored-by: ApplEOFDiscord <wwy640130@163.com>
Co-authored-by: ApplEOFDiscord <31272106+ApplEOFDiscord@users.noreply.github.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Co-authored-by: yinwei <yinwei_hust@163.com>
Co-authored-by: JYChen <zoooo0820@qq.com>
Co-authored-by: qwes5s5 <45442318+qwes5s5@users.noreply.github.com>
Co-authored-by: Ryan <zihaohuang@aliyun.com>
Co-authored-by: yyssys <atyangshuang@foxmail.com>
Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com>
Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com>
Co-authored-by: zhangyue66 <zhangyue66@baidu.com>
Co-authored-by: kxz2002 <115912648+kxz2002@users.noreply.github.com>
Co-authored-by: SunLei <sunlei5788@gmail.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Zhang Yulong <35552275+ZhangYulongg@users.noreply.github.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: 李泳桦 <39643373+liyonghua0910@users.noreply.github.com>
Commit 8aab4e367f by kevin, 2025-10-27 17:39:51 +08:00, committed by GitHub
Parent: a4fb3d4ff0
40 changed files with 1741 additions and 545 deletions
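The diff below (process.py) keys cached multimodal inputs by mm_hashes: a caller-supplied uuid when one is given, otherwise a content hash of the preprocessed pixel values computed by MultimodalHasher.hash_features. As a rough sketch of the idea — not the actual MultimodalHasher implementation — a content hash over a feature array could look like:

import hashlib
import numpy as np

def hash_features(features: np.ndarray) -> str:
    # Hash the shape plus the raw bytes so identical preprocessed
    # images/videos map to the same mm_hash and can hit the
    # processor/prefix caches on later requests.
    h = hashlib.sha256()
    h.update(str(features.shape).encode("utf-8"))
    h.update(np.ascontiguousarray(features).tobytes())
    return h.hexdigest()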


@@ -18,16 +18,20 @@
""" process.py """
import copy
import os
+import pickle
from collections import defaultdict
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
+import zmq
from paddleformers.transformers.image_utils import ChannelDimension
from PIL import Image
+from fastdeploy.engine.request import ImagePosition
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.input.utils import IDS_TYPE_FLAG
+from fastdeploy.multimodal.hasher import MultimodalHasher
from fastdeploy.utils import data_processor_logger
from .image_preprocessor.image_preprocessor_adaptive import AdaptiveImageProcessor
@@ -84,6 +88,7 @@ class DataProcessor:
self,
tokenizer_name: str,
image_preprocessor_name: str,
+enable_processor_cache: bool = False,
spatial_conv_size: int = 2,
temporal_conv_size: int = 2,
image_min_pixels: int = 4 * 28 * 28,
@@ -102,6 +107,7 @@ class DataProcessor:
self._load_tokenizer()
self.tokenizer.ignored_index = -100
self.image_preprocessor = AdaptiveImageProcessor.from_pretrained(image_preprocessor_name)
+self.enable_processor_cache = enable_processor_cache
# Convolution sizes for patch aggregation
self.spatial_conv_size = spatial_conv_size
@@ -163,10 +169,18 @@ class DataProcessor:
"""Enable evaluation mode (doesn't produce labels)."""
self.is_training = False
-def text2ids(self, text, images=None, videos=None):
+def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None):
"""
Convert chat text into model inputs.
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
+Args:
+text (str): The chat text containing placeholders for images and videos.
+images (list, optional): List of images to be processed and inserted at image placeholders.
+videos (list, optional): List of videos to be processed and inserted at video placeholders.
+image_uuid (list, optional): List of unique identifiers for each image, used for caching or hashing.
+video_uuid (list, optional): List of unique identifiers for each video, used for caching or hashing.
+Returns:
+dict: A dictionary with keys input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels, etc.
"""
outputs = {
@@ -178,8 +192,9 @@ class DataProcessor:
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
"mm_positions": [],
"mm_hashes": [],
}
IMAGE_PLACEHOLDER = "<|image@placeholder|>"
@@ -199,17 +214,27 @@ class DataProcessor:
break
if ed == image_pos:
-self._add_image(images[image_idx], outputs)
+image = images[image_idx]
+uuid = image_uuid[image_idx] if image_uuid else None
+if not isinstance(image, tuple):
+self._add_image(image, outputs, uuid)
+else:
+# cached images are already processed
+self._add_processed_image(image, outputs, uuid)
image_idx += 1
st = ed + IMAGE_PLACEHOLDER_LEN
else:
item = videos[video_idx]
-if isinstance(item, dict):
-frames = self._load_and_process_video(item["video"], item)
-else:
-frames = self._load_and_process_video(item, {})
-self._add_video(frames, outputs)
+uuid = video_uuid[video_idx] if video_uuid else None
+if not isinstance(item, tuple):
+if isinstance(item, dict):
+frames = self._load_and_process_video(item["video"], item)
+else:
+frames = self._load_and_process_video(item, {})
+self._add_video(frames, outputs, uuid)
+else:
+# cached frames are already processed
+self._add_processed_video(item, outputs, uuid)
video_idx += 1
st = ed + VIDEO_PLACEHOLDER_LEN
@@ -223,66 +248,82 @@ class DataProcessor:
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
"""
-outputs = {
-"input_ids": [],
-"token_type_ids": [],
-"position_ids": [],
-"images": [],
-"grid_thw": [],
-"image_type_ids": [],
-"labels": [],
-"cur_position": 0,
-"pic_cnt": 0,
-"video_cnt": 0,
-}
messages = parse_chat_messages(request.get("messages"))
-image_message_list = []
+mm_items = []
for msg in messages:
role = msg.get("role")
assert role in self.role_prefixes, f"Unsupported role: {role}"
content_items = msg.get("content")
if not isinstance(content_items, list):
content_items = [content_items]
for item in content_items:
if isinstance(item, dict) and item.get("type") in [
"image",
"video",
]:
image_message_list.append(item)
content = msg.get("content")
if not isinstance(content, list):
content = [content]
for item in content:
if item.get("type") in ["image", "video"]:
mm_items.append(item)
+missing_hashes, missing_idx = [], []
+for idx, item in enumerate(mm_items):
+if not item.get("data"):
+# raw data not provided, should be retrieved from processor cache
+missing_hashes.append(item.get("uuid"))
+missing_idx.append(idx)
+if len(missing_hashes) > 0 and not self.enable_processor_cache:
+raise ValueError("Missing items cannot be retrieved without processor cache.")
+if self.enable_processor_cache:
+context = zmq.Context()
+dealer = context.socket(zmq.DEALER)
+dealer.connect("ipc:///dev/shm/processor_cache.ipc")
+missing_items = self.get_processor_cache(dealer, missing_hashes)
+for idx in range(len(missing_items)):
+if not missing_items[idx]:
+raise ValueError(f"Missing item {idx} not found in processor cache")
+mm_items[missing_idx[idx]]["data"] = missing_items[idx]
+images, videos = [], []
+image_uuid, video_uuid = [], []
+for item in mm_items:
+if item.get("type") == "image":
+images.append(item["data"])
+image_uuid.append(item["uuid"])
+elif item.get("type") == "video":
+videos.append(item["data"])
+video_uuid.append(item["uuid"])
+else:
+raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
+if self.tokenizer.chat_template is None:
+raise ValueError("This model does not support chat template.")
chat_template_kwargs = request.get("chat_template_kwargs", {})
-prompt_token_ids = self.apply_chat_template(request, **chat_template_kwargs)
-if len(prompt_token_ids) == 0:
-raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
-image_start_index = 0
-image_message_index = 0
-for i in range(len(prompt_token_ids)):
-if prompt_token_ids[i] in [
-self.image_start_id,
-self.video_start_id,
-]:
-self._add_text(prompt_token_ids[image_start_index : i + 1], outputs)
-image_start_index = i + 1
-image_message = image_message_list[image_message_index]
-if image_message["type"] == "image":
-img = image_message.get("image")
-if img is None:
-continue
-outputs["pic_cnt"] += 1
-self._add_image(img, outputs)
-elif image_message["type"] == "video":
-video_bytes = image_message.get("video")
-if video_bytes is None:
-continue
-frames = self._load_and_process_video(video_bytes, image_message)
-outputs["video_cnt"] += 1
-self._add_video(frames, outputs)
-image_message_index += 1
-self._add_text(prompt_token_ids[image_start_index:], outputs)
+prompt = self.tokenizer.apply_chat_template(
+request,
+tokenize=False,
+add_generation_prompt=request.get("add_generation_prompt", True),
+**chat_template_kwargs,
+)
+request["prompt_tokens"] = prompt
+outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid)
+if self.enable_processor_cache:
+missing_idx = set(missing_idx)
+hashes_to_cache, items_to_cache = [], []
+for idx in range(len(mm_items)):
+if idx in missing_idx:
+continue
+meta = {}
+t, h, w = outputs["grid_thw"][idx][0]
+meta["thw"] = (t, h, w)
+hashes_to_cache.append(outputs["mm_hashes"][idx])
+items_to_cache.append((outputs["images"][idx], meta))
+self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
if self.is_training:
assert tgts, "training must give tgt !"
assert tgts, "Training must give tgt"
self._extract_labels(outputs, tgts)
return outputs
def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
@@ -304,7 +345,7 @@ class DataProcessor:
outputs["position_ids"].append([start + i] * 3)
outputs["cur_position"] += len(tokens)
-def _add_image(self, img, outputs: Dict) -> None:
+def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
img.height,
img.width,
@@ -313,6 +354,7 @@ class DataProcessor:
)[1]
num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
@@ -330,10 +372,32 @@ class DataProcessor:
input_data_format=ChannelDimension.LAST,
)
outputs["images"].append(ret["pixel_values"])
+if not uuid:
+outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
+else:
+outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(ret["image_grid_thw"])
outputs["image_type_ids"].append(0)
-def _add_video(self, frames, outputs: Dict) -> None:
+def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
+img, meta = img_cache
+num_tokens = img.shape[0] // (self.spatial_conv_size**2)
+outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
+outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
+outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
+_, h, w = meta["thw"]
+pos_ids = self._compute_3d_positions(1, h, w, outputs["cur_position"])
+outputs["position_ids"].extend(pos_ids)
+outputs["cur_position"] = np.max(pos_ids) + 1
+outputs["images"].append(img)
+outputs["mm_hashes"].append(uuid)
+outputs["grid_thw"].append(np.array([[1, h, w]]))
+outputs["image_type_ids"].append(0)
+def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
frames[0].height,
frames[0].width,
@@ -354,9 +418,14 @@ class DataProcessor:
input_data_format=ChannelDimension.LAST,
)
outputs["images"].append(ret["pixel_values_videos"])
+if not uuid:
+outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values_videos"]))
+else:
+outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(ret["video_grid_thw"])
outputs["image_type_ids"].extend([1] * num_frames)
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
@@ -364,6 +433,24 @@ class DataProcessor:
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1
+def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
+frames, meta = frames_cache
+num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size)
+t, h, w = meta["thw"]
+outputs["images"].append(frames)
+outputs["mm_hashes"].append(uuid)
+outputs["grid_thw"].append(np.array([[t, h, w]]))
+outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
+outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
+outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
+outputs["image_type_ids"].extend([1] * t)
+pos_ids = self._compute_3d_positions(t, h, w, outputs["cur_position"])
+outputs["position_ids"].extend(pos_ids)
+outputs["cur_position"] = np.max(pos_ids) + 1
def _extract_labels(self, outputs: Dict, tgts: List[str]) -> None:
input_ids = copy.deepcopy(outputs["input_ids"])
labels = [self.tokenizer.ignored_index] * len(input_ids)
@@ -480,33 +567,22 @@ class DataProcessor:
break
self.tokenizer = Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path)
-def apply_chat_template(self, request, **kwargs):
+def get_processor_cache(self, socket, mm_hashes: list[str]) -> list:
"""
-Convert multi-turn messages into ID sequences.
-Args:
-messages: Either a request dict containing 'messages' field,
-or a list of message dicts directly
-Returns:
-List of token IDs as strings (converted from token objects)
+Get cached items corresponding to the given hash values.
"""
-if self.tokenizer.chat_template is None:
-raise ValueError("This model does not support chat_template.")
-prompt_token_template = self.tokenizer.apply_chat_template(
-request,
-tokenize=False,
-add_generation_prompt=request.get("add_generation_prompt", True),
-**kwargs,
-)
-prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace(
-"<|video@placeholder|>", ""
-)
-request["prompt_tokens"] = prompt_token_template
-tokens = self.tokenizer.tokenize(prompt_token_str)
-token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
-data_processor_logger.info(
-f"req_id:{request.get('request_id', ''), } tokens: {tokens}, token_ids: {token_ids}"
-)
-return token_ids
+req = pickle.dumps(mm_hashes)
+socket.send_multipart([b"", req])
+_, resp = socket.recv_multipart()
+mm_items = pickle.loads(resp)
+data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}")
+return mm_items
+def update_processor_cache(self, socket, mm_hashes: list[str], mm_items):
+"""
+Update cached data for the given hash values.
+"""
+req = pickle.dumps((mm_hashes, mm_items))
+socket.send_multipart([b"", req])
+data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}")