mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] mm support prefix cache (#4134)
* support mm prefix caching * update code * fix mm_hashes * support encoder cache * add encoder cache * update code * update encoder cache * fix features bug * fix worker bug * support processor cache, need to optimize yet * refactor multimodal data cache * update code * update code * update v1 scheduler * update code * update code * update codestyle * support turn off processor cache and encoder cache * update pre-commit * fix code * solve review * update code * update code * update test case * set processor cache in GiB * update test case * support mm prefix caching for qwen model * fix code style check * update pre-commit * fix unit test * fix unit test * add ci test case * fix rescheduled bug * change text_after_process to prompt_tokens * fix unit test * fix chat template * change model path * [EP] fix adapter bugs (#4572) * Update expert_service.py * Update common_engine.py * Update expert_service.py * fix v1 hang bug (#4573) * fix import image_ops error on some platforms (#4559) * [CLI]Update parameters in bench latecy cli tool and fix collect-env cli tool (#4558) * add collect-env * del files * [Graph Optimization] Add dy_runnable and introduce cudagraph_switch_threshold for cudagraph mode switching (#4578) * add new branch for sot * reorder * fix batch bug * [XPU]Moe uses a new operator (#4585) * [XPU]Moe uses a new operator * [XPU]Moe uses a new operator * update response * [Feature] Support Paddle-OCR (#4396) * init * update code * fix code style & disable thinking * adapt for common_engine.update_mm_requests_chunk_size * use 3d rope * use flash_attn_unpadded * opt siglip * update to be compatible with the latest codebase * fix typo * optim OCR performance * fix bug * fix bug * fix bug * fix bug * normlize name * modify xpu rope * revert logger * fix bug * fix bug * fix bug * support default_v1 * optim performance * fix bug --------- Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com> Co-authored-by: zhangyue66 <zhangyue66@baidu.com> * [DataProcessor] add reasoning_tokens into usage info (#4520) * add reasoning_tokens into usage info initial commit * add unit tests * modify unit test * modify and add unit tests * fix unit test * move steam usage to processor * modify processor * modify test_logprobs * modify test_logprobs.py * modify stream reasoning tokens accumulation * fix unit test * perf: Optimize task queue communication from engine to worker (#4531) * perf: Optimize task queue communication from engine to worker * perf: get_tasks to numpy * perf: get_tasks remove to_numpy * fix: request & replace ENV * remove test_e2w_perf.py * fix code style --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> * Clean up ports after processing results (#4587) * [CI] Add /re-run command in PR comments to restart failed CI workflows (#4593) * [Others] api server exits when worker process is dead (#3271) * [fix] fix terminal hangs when worker process is dead * [chore] change sleep time of monitor * [chore] remove redundant comments * update docs --------- Co-authored-by: ApplEOFDiscord <wwy640130@163.com> Co-authored-by: ApplEOFDiscord <31272106+ApplEOFDiscord@users.noreply.github.com> Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com> Co-authored-by: yinwei <yinwei_hust@163.com> Co-authored-by: JYChen <zoooo0820@qq.com> Co-authored-by: qwes5s5 <45442318+qwes5s5@users.noreply.github.com> Co-authored-by: Ryan <zihaohuang@aliyun.com> Co-authored-by: yyssys <atyangshuang@foxmail.com> Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com> Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com> Co-authored-by: zhangyue66 <zhangyue66@baidu.com> Co-authored-by: kxz2002 <115912648+kxz2002@users.noreply.github.com> Co-authored-by: SunLei <sunlei5788@gmail.com> Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Co-authored-by: Zhang Yulong <35552275+ZhangYulongg@users.noreply.github.com> Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Co-authored-by: 李泳桦 <39643373+liyonghua0910@users.noreply.github.com>
This commit is contained in:
@@ -18,7 +18,7 @@ from enum import Enum
|
||||
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
|
||||
logger = get_logger("prefix_cache_manager", "cache_manager.log")
|
||||
|
||||
|
||||
DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = {
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
|
||||
logger = get_logger("prefix_cache_manager", "cache_manager.log")
|
||||
|
||||
|
||||
class CacheMetrics:
|
||||
|
||||
163
fastdeploy/cache_manager/multimodal_cache_manager.py
Normal file
163
fastdeploy/cache_manager/multimodal_cache_manager.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.engine.request import ImagePosition
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("prefix_cache_manager", "cache_manager.log")
|
||||
|
||||
|
||||
class MultimodalLRUCache(ABC):
|
||||
"""
|
||||
General lru cache for multimodal data
|
||||
"""
|
||||
|
||||
def __init__(self, max_cache_size):
|
||||
self.cache = OrderedDict()
|
||||
self.current_cache_size = 0
|
||||
self.max_cache_size = max_cache_size
|
||||
|
||||
def apply_cache(self, mm_hashes: list[str], mm_items: list[Any]) -> list[str]:
|
||||
"""
|
||||
apply data cache, return evicted data
|
||||
"""
|
||||
assert len(mm_hashes) == len(mm_items), "mm_hashes and mm_items should have same length"
|
||||
|
||||
evicted_hashes = []
|
||||
for idx in range(len(mm_hashes)):
|
||||
if mm_hashes[idx] in self.cache:
|
||||
self.cache.move_to_end(mm_hashes[idx])
|
||||
else:
|
||||
item_size = self.get_item_size(mm_items[idx])
|
||||
if self.current_cache_size + item_size >= self.max_cache_size:
|
||||
if item_size > self.max_cache_size:
|
||||
# cannot be inserted even if we clear all cached data, skip it directly
|
||||
continue
|
||||
needed = item_size - (self.max_cache_size - self.current_cache_size)
|
||||
evicted_hashes.extend(self.evict_cache(needed))
|
||||
self.cache[mm_hashes[idx]] = mm_items[idx]
|
||||
self.current_cache_size += item_size
|
||||
|
||||
return evicted_hashes
|
||||
|
||||
def evict_cache(self, needed: int) -> list[str]:
|
||||
"""
|
||||
evict data cache with needed size
|
||||
"""
|
||||
reduced_size, evicted_hashes = 0, []
|
||||
while reduced_size < needed and len(self.cache):
|
||||
mm_hash, mm_item = self.cache.popitem(last=False)
|
||||
evicted_hashes.append(mm_hash)
|
||||
reduced_size += self.get_item_size(mm_item)
|
||||
self.current_cache_size -= self.get_item_size(mm_item)
|
||||
|
||||
return evicted_hashes
|
||||
|
||||
def get_cache(self, mm_hashes: list[str]) -> list[Any]:
|
||||
"""
|
||||
get cached data correspond to given hash values
|
||||
"""
|
||||
mm_items = []
|
||||
for mm_hash in mm_hashes:
|
||||
if mm_hash not in self.cache:
|
||||
mm_items.append(None)
|
||||
continue
|
||||
mm_items.append(self.cache[mm_hash])
|
||||
|
||||
return mm_items
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
clear all cached data
|
||||
"""
|
||||
evicted_hashes = list(self.cache.keys())
|
||||
self.cache.clear()
|
||||
self.current_cache_size = 0
|
||||
|
||||
return evicted_hashes
|
||||
|
||||
@abstractmethod
|
||||
def get_item_size(self, item: Any) -> int:
|
||||
raise NotImplementedError("Subclasses must define how to get size of an item")
|
||||
|
||||
|
||||
class EncoderCacheManager(MultimodalLRUCache):
|
||||
"""
|
||||
EncoderCacheManager is used to cache image features
|
||||
"""
|
||||
|
||||
def __init__(self, max_encoder_cache):
|
||||
super().__init__(max_encoder_cache)
|
||||
|
||||
def get_item_size(self, item: ImagePosition) -> int:
|
||||
return item.length
|
||||
|
||||
|
||||
class ProcessorCacheManager(MultimodalLRUCache):
|
||||
"""
|
||||
ProcessorCacheManager is used to cache processed data
|
||||
"""
|
||||
|
||||
def __init__(self, max_processor_cache):
|
||||
super().__init__(max_processor_cache)
|
||||
|
||||
self.context = zmq.Context()
|
||||
|
||||
self.router = self.context.socket(zmq.ROUTER)
|
||||
self.router.setsockopt(zmq.SNDHWM, int(envs.FD_ZMQ_SNDHWM))
|
||||
self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
|
||||
self.router.setsockopt(zmq.SNDTIMEO, -1)
|
||||
self.router.bind("ipc:///dev/shm/processor_cache.ipc")
|
||||
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.router, zmq.POLLIN)
|
||||
|
||||
self.handler_thread = threading.Thread(target=self.cache_request_handler, daemon=True)
|
||||
self.handler_thread.start()
|
||||
|
||||
def get_item_size(self, item: Tuple[np.ndarray, dict]) -> int:
|
||||
return item[0].nbytes
|
||||
|
||||
def cache_request_handler(self):
|
||||
try:
|
||||
while True:
|
||||
events = dict(self.poller.poll())
|
||||
|
||||
if self.router in events:
|
||||
client, _, content = self.router.recv_multipart()
|
||||
req = pickle.loads(content)
|
||||
|
||||
if isinstance(req, tuple):
|
||||
# apply cache request, in format of (mm_hashes, mm_items)
|
||||
self.apply_cache(req[0], req[1])
|
||||
logger.info(f"Apply processor cache of mm_hashes: {req[0]}")
|
||||
else:
|
||||
# get cache request
|
||||
resp = self.get_cache(req)
|
||||
logger.info(f"Get processor cache of mm_hashes: {req}")
|
||||
self.router.send_multipart([client, b"", pickle.dumps(resp)])
|
||||
except Exception as e:
|
||||
logger.error(f"Error happened while handling processor cache request: {e}")
|
||||
@@ -14,8 +14,10 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import heapq
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
@@ -35,7 +37,7 @@ from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTre
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
|
||||
logger = get_logger("prefix_cache_manager", "cache_manager.log")
|
||||
|
||||
|
||||
class PrefixCacheManager:
|
||||
@@ -575,31 +577,18 @@ class PrefixCacheManager:
|
||||
"""
|
||||
try:
|
||||
req_id = task.request_id
|
||||
block_tables = task.block_tables
|
||||
|
||||
last_node, num_cached_tokens = self.cache_info[req_id]
|
||||
if isinstance(task.prompt_token_ids, np.ndarray):
|
||||
prompt_token_ids = task.prompt_token_ids.tolist()
|
||||
else:
|
||||
prompt_token_ids = task.prompt_token_ids
|
||||
input_ids = prompt_token_ids + task.output_token_ids
|
||||
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
|
||||
left_input_ids = input_ids[num_cached_tokens:can_cache_computed_tokens]
|
||||
gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :]
|
||||
if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later
|
||||
self.leaf_req_map[last_node].remove(req_id)
|
||||
|
||||
with self.request_release_lock:
|
||||
current_time = time.time()
|
||||
leaf_node = self.build_path(
|
||||
req_id=req_id,
|
||||
current_time=current_time,
|
||||
input_ids=input_ids,
|
||||
left_input_ids=left_input_ids,
|
||||
gpu_block_ids=gpu_extra_block_ids,
|
||||
leaf_node = self.mm_build_path(
|
||||
request=task,
|
||||
num_computed_tokens=num_computed_tokens,
|
||||
block_size=block_size,
|
||||
last_node=last_node,
|
||||
reverved_dec_block_num=0,
|
||||
num_cached_tokens=num_cached_tokens,
|
||||
)
|
||||
self.req_leaf_map[req_id] = leaf_node
|
||||
self.leaf_req_map[leaf_node].add(req_id)
|
||||
@@ -636,10 +625,9 @@ class PrefixCacheManager:
|
||||
prompt_token_ids = task.prompt_token_ids.tolist()
|
||||
else:
|
||||
prompt_token_ids = task.prompt_token_ids
|
||||
input_ids = prompt_token_ids + task.output_token_ids
|
||||
req_id = task.request_id
|
||||
logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
|
||||
input_token_num = len(input_ids)
|
||||
input_token_num = len(prompt_token_ids + task.output_token_ids)
|
||||
common_block_ids = []
|
||||
# 1. match block
|
||||
(
|
||||
@@ -649,7 +637,7 @@ class PrefixCacheManager:
|
||||
match_block_node,
|
||||
gpu_match_token_num,
|
||||
cpu_match_token_num,
|
||||
) = self.match_block(req_id, input_ids, block_size)
|
||||
) = self.mm_match_block(task, block_size)
|
||||
|
||||
# update matched node info
|
||||
self._update_matched_node_info(req_id, match_block_node, current_time=time.time())
|
||||
@@ -1145,6 +1133,173 @@ class PrefixCacheManager:
|
||||
"""
|
||||
return hash(tuple(block))
|
||||
|
||||
def get_block_hash_extra_keys(self, request, start_idx, end_idx, mm_idx):
|
||||
"""
|
||||
Retrieves additional hash keys for block identification.
|
||||
|
||||
Args:
|
||||
request: The input request object containing the data to be processed.
|
||||
start_idx (int): The starting index of the block segment to hash.
|
||||
end_idx (int): The ending index of the block segment to hash.
|
||||
mm_idx: The multimodal index identifier for specialized content handling.
|
||||
|
||||
Returns:
|
||||
mm_idx: next multimodal index
|
||||
hash_keys: A list of additional hash keys
|
||||
"""
|
||||
hash_keys = []
|
||||
mm_inputs = request.multimodal_inputs
|
||||
if (
|
||||
mm_inputs is None
|
||||
or "mm_positions" not in mm_inputs
|
||||
or "mm_hashes" not in mm_inputs
|
||||
or len(mm_inputs["mm_positions"]) == 0
|
||||
):
|
||||
return mm_idx, hash_keys
|
||||
|
||||
assert start_idx < end_idx, f"start_idx {start_idx} >= end_idx {end_idx}"
|
||||
assert (
|
||||
start_idx >= 0 and start_idx < request.num_total_tokens
|
||||
), f"start_idx {start_idx} out of range {request.num_total_tokens}"
|
||||
assert (
|
||||
end_idx >= 0 and end_idx <= request.num_total_tokens
|
||||
), f"end_idx {end_idx} out of range {request.num_total_tokens}"
|
||||
assert len(mm_inputs["mm_positions"]) == len(
|
||||
mm_inputs["mm_hashes"]
|
||||
), f"mm_positions {len(mm_inputs['mm_positions'])} != mm_hashes {len(mm_inputs['mm_hashes'])}"
|
||||
assert mm_idx >= 0 and mm_idx < len(
|
||||
mm_inputs["mm_hashes"]
|
||||
), f"mm_idx {mm_idx} out of range {len(mm_inputs['mm_hashes'])}"
|
||||
|
||||
if mm_inputs["mm_positions"][-1].offset + mm_inputs["mm_positions"][-1].length < start_idx:
|
||||
# non images in current block
|
||||
return mm_idx, hash_keys
|
||||
|
||||
for img_idx in range(mm_idx, len(mm_inputs["mm_positions"])):
|
||||
image_offset = mm_inputs["mm_positions"][img_idx].offset
|
||||
image_length = mm_inputs["mm_positions"][img_idx].length
|
||||
|
||||
if image_offset + image_length < start_idx:
|
||||
# image before block
|
||||
continue
|
||||
elif image_offset >= end_idx:
|
||||
# image after block
|
||||
return img_idx, hash_keys
|
||||
elif image_offset + image_length > end_idx:
|
||||
hash_keys.append(mm_inputs["mm_hashes"][img_idx])
|
||||
return img_idx, hash_keys
|
||||
else:
|
||||
hash_keys.append(mm_inputs["mm_hashes"][img_idx])
|
||||
return len(mm_inputs["mm_positions"]) - 1, hash_keys
|
||||
|
||||
def hash_block_features(self, input_ids, extra_keys: list = []):
|
||||
"""
|
||||
calculate hash value of a block with additional keys
|
||||
|
||||
Args:
|
||||
input_ids: Input token IDs
|
||||
extra_keys: Additional keys for block identification
|
||||
"""
|
||||
return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()
|
||||
|
||||
def mm_match_block(self, request, block_size):
|
||||
"""
|
||||
Match and retrieve cached blocks for multimodal requests using a radix tree structure.
|
||||
|
||||
Args:
|
||||
request: The multimodal request object containing prompt and output token IDs.
|
||||
block_size (int): The size of each token block for matching and processing.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing:
|
||||
- match_gpu_block_ids (list): List of block IDs matched in GPU cache
|
||||
- match_cpu_block_ids (list): List of block IDs matched in CPU cache
|
||||
- swap_node_ids (list): List of node IDs scheduled for GPU-CPU swapping
|
||||
- current_match_node: The last matched node in the radix tree traversal
|
||||
- gpu_match_token_num (int): Total number of tokens matched in GPU cache
|
||||
- cpu_match_token_num (int): Total number of tokens matched in CPU cache
|
||||
"""
|
||||
if isinstance(request.prompt_token_ids, np.ndarray):
|
||||
prompt_token_ids = request.prompt_token_ids.tolist()
|
||||
else:
|
||||
prompt_token_ids = request.prompt_token_ids
|
||||
input_ids = prompt_token_ids + request.output_token_ids
|
||||
total_token_num = len(input_ids)
|
||||
current_match_node = self.radix_tree_root # 从根节点开始搜
|
||||
match_gpu_block_ids = []
|
||||
match_cpu_block_ids = []
|
||||
match_node_ids = []
|
||||
mm_idx = 0
|
||||
match_token_num = 0
|
||||
cpu_match_token_num = 0
|
||||
gpu_match_token_num = 0
|
||||
swap_node_ids = []
|
||||
matche_nodes = []
|
||||
has_modified_gpu_lru_leaf_heap = False
|
||||
has_modified_cpu_lru_leaf_heap = False
|
||||
|
||||
with self.cache_status_lock:
|
||||
while match_token_num < total_token_num:
|
||||
token_block = input_ids[match_token_num : match_token_num + block_size]
|
||||
token_num = len(token_block)
|
||||
if token_num != block_size:
|
||||
break
|
||||
mm_idx, extra_keys = self.get_block_hash_extra_keys(
|
||||
request=request,
|
||||
start_idx=match_token_num,
|
||||
end_idx=match_token_num + block_size,
|
||||
mm_idx=mm_idx,
|
||||
)
|
||||
hash_value = self.hash_block_features(token_block, extra_keys)
|
||||
if hash_value in current_match_node.children:
|
||||
child = current_match_node.children[hash_value]
|
||||
matche_nodes.append(child)
|
||||
match_node_ids.append(child.node_id)
|
||||
if child in self.gpu_lru_leaf_set:
|
||||
self.gpu_lru_leaf_set.remove(child)
|
||||
self.gpu_lru_leaf_heap.remove(child)
|
||||
has_modified_gpu_lru_leaf_heap = True
|
||||
elif child in self.cpu_lru_leaf_set:
|
||||
self.cpu_lru_leaf_set.remove(child)
|
||||
self.cpu_lru_leaf_heap.remove(child)
|
||||
has_modified_cpu_lru_leaf_heap = True
|
||||
if child.has_in_gpu:
|
||||
match_gpu_block_ids.append(child.block_id)
|
||||
gpu_match_token_num += block_size
|
||||
else:
|
||||
if child.cache_status == CacheStatus.SWAP2CPU:
|
||||
logger.info(
|
||||
f"match_block: req_id {request.request_id} matched node"
|
||||
+ f" {child.node_id} which is being SWAP2CPU"
|
||||
)
|
||||
child.cache_status = CacheStatus.GPU
|
||||
match_gpu_block_ids.append(child.block_id)
|
||||
gpu_match_token_num += block_size
|
||||
elif child.cache_status == CacheStatus.CPU:
|
||||
child.cache_status = CacheStatus.SWAP2GPU
|
||||
match_cpu_block_ids.append(child.block_id)
|
||||
cpu_match_token_num += block_size
|
||||
swap_node_ids.append(child.node_id)
|
||||
match_token_num = match_token_num + block_size
|
||||
current_match_node = child
|
||||
else:
|
||||
break
|
||||
|
||||
if has_modified_gpu_lru_leaf_heap:
|
||||
heapq.heapify(self.gpu_lru_leaf_heap)
|
||||
if has_modified_cpu_lru_leaf_heap:
|
||||
heapq.heapify(self.cpu_lru_leaf_heap)
|
||||
|
||||
logger.info(f"match_block: req_id {request.request_id} matched nodes: {match_node_ids}")
|
||||
return (
|
||||
match_gpu_block_ids,
|
||||
match_cpu_block_ids,
|
||||
swap_node_ids,
|
||||
current_match_node,
|
||||
gpu_match_token_num,
|
||||
cpu_match_token_num,
|
||||
)
|
||||
|
||||
def match_block(self, req_id, input_ids, block_size):
|
||||
"""
|
||||
Args:
|
||||
@@ -1241,6 +1396,86 @@ class PrefixCacheManager:
|
||||
node.req_id_set.add(req_id)
|
||||
node = node.parent
|
||||
|
||||
def mm_build_path(self, request, num_computed_tokens, block_size, last_node, num_cached_tokens):
|
||||
"""
|
||||
Constructs a caching path in radix tree for multimodal requests by processing computed tokens.
|
||||
|
||||
Args:
|
||||
request: The inference request object containing:
|
||||
- prompt_token_ids: Original input tokens (List[int] or np.ndarray)
|
||||
- output_token_ids: Generated tokens (List[int])
|
||||
- mm_positions: Optional image positions for multimodal content
|
||||
num_computed_tokens: Total tokens processed so far (cached + newly computed)
|
||||
block_size: Fixed size of token blocks (must match cache configuration)
|
||||
last_node: The deepest existing BlockNode in the radix tree for this request
|
||||
num_cached_tokens: Number of tokens already cached
|
||||
|
||||
Returns:
|
||||
BlockNode: The new deepest node in the constructed path
|
||||
"""
|
||||
if isinstance(request.prompt_token_ids, np.ndarray):
|
||||
prompt_token_ids = request.prompt_token_ids.tolist()
|
||||
else:
|
||||
prompt_token_ids = request.prompt_token_ids
|
||||
input_ids = prompt_token_ids + request.output_token_ids
|
||||
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
|
||||
if num_cached_tokens == can_cache_computed_tokens:
|
||||
return last_node
|
||||
|
||||
mm_idx = 0
|
||||
node = last_node
|
||||
unique_node_ids = []
|
||||
new_last_node = last_node
|
||||
has_unfilled_block = False
|
||||
current_time = time.time()
|
||||
|
||||
input_hash_value = self.hash_block_features(input_ids)
|
||||
gpu_block_ids = request.block_tables[num_cached_tokens // block_size :].copy()
|
||||
for i in range(num_cached_tokens, can_cache_computed_tokens, block_size):
|
||||
current_block = input_ids[i : i + block_size]
|
||||
current_block_size = len(current_block) # 最后一个block可能没填满
|
||||
if current_block_size != block_size:
|
||||
has_unfilled_block = True
|
||||
else:
|
||||
mm_idx, extra_keys = self.get_block_hash_extra_keys(
|
||||
request=request,
|
||||
start_idx=i,
|
||||
end_idx=i + block_size,
|
||||
mm_idx=mm_idx,
|
||||
)
|
||||
hash_value = self.hash_block_features(current_block, extra_keys)
|
||||
allocated_block_id = gpu_block_ids.pop(0)
|
||||
node_id = self.node_id_pool.pop()
|
||||
unique_node_ids.append(node_id)
|
||||
new_last_node = BlockNode(
|
||||
node_id,
|
||||
input_ids,
|
||||
input_hash_value,
|
||||
node.depth + 1,
|
||||
allocated_block_id,
|
||||
current_block_size,
|
||||
hash_value,
|
||||
current_time,
|
||||
parent=node,
|
||||
shared_count=1,
|
||||
reverved_dec_block_ids=[],
|
||||
)
|
||||
new_last_node.req_id_set.add(request.request_id)
|
||||
self.node_map[node_id] = new_last_node
|
||||
node.children[hash_value] = new_last_node
|
||||
node = new_last_node
|
||||
|
||||
reverved_dec_block_ids = []
|
||||
if has_unfilled_block is True:
|
||||
reverved_dec_block_ids.append(gpu_block_ids.pop(0))
|
||||
|
||||
if new_last_node == self.radix_tree_root:
|
||||
self.unfilled_req_block_map[request.request_id] = reverved_dec_block_ids
|
||||
else:
|
||||
new_last_node.reverved_dec_block_ids.extend(reverved_dec_block_ids)
|
||||
logger.info(f"build_path: allocate unique node ids {unique_node_ids} for req_id {request.request_id}")
|
||||
return new_last_node
|
||||
|
||||
def build_path(
|
||||
self,
|
||||
req_id,
|
||||
|
||||
@@ -1137,6 +1137,8 @@ class CacheConfig:
|
||||
enc_dec_block_num (int): Number of encoder-decoder blocks.
|
||||
prealloc_dec_block_slot_num_threshold (int): Number of token slot threadshold to allocate next blocks for decoding, used when ENABLE_V1_KVCACHE_SCHEDULER=1.
|
||||
enable_prefix_caching (bool): Enable prefix caching.
|
||||
max_encoder_cache(int): Maximum number of tokens in the encoder cache.
|
||||
max_processor_cache(int): Maximum number of bytes in the processor cache.
|
||||
"""
|
||||
self.block_size = 64
|
||||
self.gpu_memory_utilization = 0.9
|
||||
@@ -1157,6 +1159,8 @@ class CacheConfig:
|
||||
self.enable_ssd_cache = False
|
||||
self.cache_queue_port = None
|
||||
self.swap_space = None
|
||||
self.max_encoder_cache = None
|
||||
self.max_processor_cache = None
|
||||
for key, value in args.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
@@ -1440,7 +1444,7 @@ class FDConfig:
|
||||
self.max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
|
||||
if current_platform.is_xpu():
|
||||
self.max_prefill_batch = 1
|
||||
if self.model_config is not None and self.model_config.enable_mm:
|
||||
if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
self.max_prefill_batch = 1 # TODO:当前多模prefill阶段只支持并行度为1,待优化
|
||||
else:
|
||||
self.max_prefill_batch = self.scheduler_config.max_num_seqs
|
||||
@@ -1500,7 +1504,7 @@ class FDConfig:
|
||||
|
||||
self.cache_config.postprocess(self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_seqs)
|
||||
self.cache_config.max_block_num_per_seq = int(self.model_config.max_model_len // self.cache_config.block_size)
|
||||
if self.model_config is not None and self.model_config.enable_mm:
|
||||
if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
self.cache_config.enable_prefix_caching = False
|
||||
|
||||
if (
|
||||
@@ -1513,6 +1517,20 @@ class FDConfig:
|
||||
else:
|
||||
self.structured_outputs_config.guided_decoding_backend = "xgrammar"
|
||||
|
||||
if self.model_config.enable_mm:
|
||||
if self.cache_config.max_encoder_cache is None or self.cache_config.max_encoder_cache < 0:
|
||||
self.cache_config.max_encoder_cache = self.scheduler_config.max_num_batched_tokens
|
||||
elif self.cache_config.max_encoder_cache != 0:
|
||||
if self.cache_config.max_encoder_cache < self.scheduler_config.max_num_batched_tokens:
|
||||
logger.warning(
|
||||
f"max_encoder_cache{self.cache_config.max_encoder_cache} is less than "
|
||||
f"max_num_batched_tokens{self.scheduler_config.max_num_batched_tokens}, "
|
||||
f"set to max_num_batched_tokens."
|
||||
)
|
||||
self.cache_config.max_encoder_cache = self.scheduler_config.max_num_batched_tokens
|
||||
else:
|
||||
self.cache_config.max_encoder_cache = 0
|
||||
|
||||
# Adjustment GraphOptConfig
|
||||
if self.load_config is not None and self.load_config.dynamic_load_weight is True:
|
||||
self.graph_opt_config.graph_opt_level = 0
|
||||
|
||||
@@ -123,6 +123,14 @@ class EngineArgs:
|
||||
"""
|
||||
Limitation of numbers of multi-modal data.
|
||||
"""
|
||||
max_encoder_cache: int = -1
|
||||
"""
|
||||
Maximum number of tokens in the encoder cache.
|
||||
"""
|
||||
max_processor_cache: float = -1
|
||||
"""
|
||||
Maximum number of bytes(in GiB) in the processor cache.
|
||||
"""
|
||||
reasoning_parser: str = None
|
||||
"""
|
||||
specifies the reasoning parser to use for extracting reasoning content from the model output
|
||||
@@ -526,6 +534,18 @@ class EngineArgs:
|
||||
type=json.loads,
|
||||
help="Additional keyword arguments for the multi-modal processor.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--max-encoder-cache",
|
||||
default=EngineArgs.max_encoder_cache,
|
||||
type=int,
|
||||
help="Maximum encoder cache tokens(use 0 to disable).",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--max-processor-cache",
|
||||
default=EngineArgs.max_processor_cache,
|
||||
type=float,
|
||||
help="Maximum processor cache bytes(use 0 to disable).",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--enable-mm",
|
||||
action=DeprecatedOptionWarning,
|
||||
|
||||
@@ -634,13 +634,9 @@ class EngineService:
|
||||
int(self.resource_manager.available_batch()),
|
||||
self.cfg.max_prefill_batch,
|
||||
)
|
||||
if self.cfg.model_config.enable_mm:
|
||||
available_blocks = self.resource_manager.available_block_num()
|
||||
else:
|
||||
available_blocks = self.cfg.cache_config.max_block_num_per_seq
|
||||
|
||||
tasks = self.scheduler.get_requests(
|
||||
available_blocks=available_blocks,
|
||||
available_blocks=self.cfg.cache_config.max_block_num_per_seq,
|
||||
block_size=self.cfg.cache_config.block_size,
|
||||
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
|
||||
max_num_batched_tokens=self.cfg.model_config.max_model_len,
|
||||
|
||||
@@ -528,6 +528,7 @@ class LLMEngine:
|
||||
f" --load_choices {self.cfg.load_config.load_choices}"
|
||||
f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
|
||||
f" --ips {ips}"
|
||||
f" --max_encoder_cache {self.cfg.cache_config.max_encoder_cache}"
|
||||
f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"
|
||||
f" --runner {self.cfg.model_config.runner}"
|
||||
f" --convert {self.cfg.model_config.convert}"
|
||||
|
||||
@@ -46,6 +46,12 @@ class RequestType(Enum):
|
||||
EXTEND = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImagePosition:
|
||||
offset: int = 0
|
||||
length: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class Request:
|
||||
def __init__(
|
||||
|
||||
@@ -28,10 +28,21 @@ import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
|
||||
from fastdeploy.cache_manager.multimodal_cache_manager import (
|
||||
EncoderCacheManager,
|
||||
ProcessorCacheManager,
|
||||
)
|
||||
from fastdeploy.engine.request import (
|
||||
ImagePosition,
|
||||
Request,
|
||||
RequestOutput,
|
||||
RequestStatus,
|
||||
RequestType,
|
||||
)
|
||||
from fastdeploy.engine.resource_manager import ResourceManager
|
||||
from fastdeploy.inter_communicator import IPCSignal
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.multimodal.hasher import MultimodalHasher
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
@@ -175,6 +186,15 @@ class ResourceManagerV1(ResourceManager):
|
||||
|
||||
self.need_block_num_map = dict()
|
||||
|
||||
self.encoder_cache = None
|
||||
if config.model_config.enable_mm and config.cache_config.max_encoder_cache > 0:
|
||||
self.encoder_cache = EncoderCacheManager(config.cache_config.max_encoder_cache)
|
||||
|
||||
self.processor_cache = None
|
||||
if config.model_config.enable_mm and config.cache_config.max_processor_cache > 0:
|
||||
max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024)
|
||||
self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes)
|
||||
|
||||
def allocated_slots(self, request: Request):
|
||||
return len(request.block_tables) * self.config.cache_config.block_size
|
||||
|
||||
@@ -273,6 +293,44 @@ class ResourceManagerV1(ResourceManager):
|
||||
break
|
||||
return can_schedule
|
||||
|
||||
def _update_mm_hashes(self, request):
|
||||
if request.multimodal_inputs is None:
|
||||
return
|
||||
|
||||
inputs = request.multimodal_inputs
|
||||
if (
|
||||
inputs.get("images", None) is not None
|
||||
and inputs.get("image_patch_id", None) is not None
|
||||
and inputs.get("grid_thw", None) is not None
|
||||
and len(inputs["grid_thw"]) != 0
|
||||
):
|
||||
grid_thw = []
|
||||
new_mm_positions, new_mm_hashes = [], []
|
||||
image_st = 0
|
||||
for idx, one in enumerate(inputs["grid_thw"]):
|
||||
t, h, w = one[0], one[1], one[2]
|
||||
if t == 1:
|
||||
grid_thw.append(one)
|
||||
new_mm_positions.append(inputs["mm_positions"][idx])
|
||||
new_mm_hashes.append(inputs["mm_hashes"][idx])
|
||||
image_st += h * w
|
||||
else:
|
||||
grid_thw.extend([[2, h, w]] * (t // 2))
|
||||
token_st = inputs["mm_positions"][idx].offset
|
||||
for _ in range(t // 2):
|
||||
new_mm_positions.append(ImagePosition(token_st, h * w // 4))
|
||||
# videos are split into patches every 2 frames, need to rehash
|
||||
new_mm_hashes.append(
|
||||
MultimodalHasher.hash_features(inputs["images"][image_st : image_st + 2 * h * w])
|
||||
)
|
||||
image_st += 2 * h * w
|
||||
token_st += h * w // 4
|
||||
inputs["mm_positions"] = new_mm_positions
|
||||
inputs["mm_hashes"] = new_mm_hashes
|
||||
else:
|
||||
inputs["mm_positions"] = []
|
||||
inputs["mm_hashes"] = []
|
||||
|
||||
def _get_num_new_tokens(self, request, token_budget):
|
||||
# TODO: set condition to new _get_num_new_tokens
|
||||
num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
|
||||
@@ -333,11 +391,12 @@ class ResourceManagerV1(ResourceManager):
|
||||
|
||||
if request.multimodal_img_boundaries is None:
|
||||
grid_thw = []
|
||||
for one in inputs["grid_thw"]:
|
||||
if one[0] == 1:
|
||||
for idx, one in enumerate(inputs["grid_thw"]):
|
||||
t, h, w = one[0], one[1], one[2]
|
||||
if t == 1:
|
||||
grid_thw.append(one)
|
||||
else:
|
||||
grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2))
|
||||
grid_thw.extend([[2, h, w]] * (t // 2))
|
||||
|
||||
grid_thw = paddle.to_tensor(grid_thw, dtype="int64")
|
||||
if current_platform.is_xpu():
|
||||
@@ -398,6 +457,11 @@ class ResourceManagerV1(ResourceManager):
|
||||
request.image_start = np.sum(np.prod(grid_thw[: request.num_image_start], axis=1))
|
||||
request.image_end = np.sum(np.prod(grid_thw[: request.num_image_end], axis=1))
|
||||
|
||||
cur_mm_hashes = inputs["mm_hashes"][request.num_image_start : request.num_image_end]
|
||||
cur_mm_positions = inputs["mm_positions"][request.num_image_start : request.num_image_end]
|
||||
if self.encoder_cache:
|
||||
request.evict_mm_hashes = self.encoder_cache.apply_cache(cur_mm_hashes, cur_mm_positions)
|
||||
|
||||
# Compatible with scenarios without images and videos.
|
||||
return num_new_tokens
|
||||
|
||||
@@ -553,6 +617,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
break
|
||||
request = self.waiting[0]
|
||||
if request.status == RequestStatus.WAITING:
|
||||
self._update_mm_hashes(request)
|
||||
# Enable prefix caching
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
if (
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
@@ -29,6 +28,7 @@ from openai.types.chat import (
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_content_part_image_param import ImageURL
|
||||
from typing_extensions import Required, TypeAlias, TypedDict
|
||||
|
||||
from fastdeploy.multimodal.image import ImageMediaIO
|
||||
@@ -36,6 +36,17 @@ from fastdeploy.multimodal.video import VideoMediaIO
|
||||
from fastdeploy.utils import api_server_logger
|
||||
|
||||
|
||||
class CustomChatCompletionContentPartImageParam(TypedDict, total=False):
|
||||
"""Custom Image URL object"""
|
||||
|
||||
type: Required[Literal["image_url"]]
|
||||
"""The type of the content part."""
|
||||
|
||||
image_url: Optional[ImageURL]
|
||||
|
||||
uuid: Optional[str]
|
||||
|
||||
|
||||
class VideoURL(TypedDict, total=False):
|
||||
"""Video URL object"""
|
||||
|
||||
@@ -46,14 +57,17 @@ class VideoURL(TypedDict, total=False):
|
||||
class CustomChatCompletionContentPartVideoParam(TypedDict, total=False):
|
||||
"""Custom Video URL object"""
|
||||
|
||||
video_url: Required[VideoURL]
|
||||
|
||||
type: Required[Literal["video_url"]]
|
||||
"""The type of the content type."""
|
||||
"""The type of the content part."""
|
||||
|
||||
video_url: Optional[VideoURL]
|
||||
|
||||
uuid: Optional[str]
|
||||
|
||||
|
||||
CustomChatCompletionContentPartParam: TypeAlias = Union[
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
CustomChatCompletionContentPartImageParam,
|
||||
CustomChatCompletionContentPartVideoParam,
|
||||
]
|
||||
|
||||
@@ -77,7 +91,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
|
||||
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, CustomChatCompletionMessageParam]
|
||||
|
||||
|
||||
class MultiModalPartParser:
|
||||
class MultimodalPartParser:
|
||||
"""Multi Modal Part parser"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -139,32 +153,46 @@ def parse_content_part(mm_parser, part):
|
||||
return part
|
||||
|
||||
if part_type == "image_url":
|
||||
content = part.get("image_url", {}).get("url", None)
|
||||
image = mm_parser.parse_image(content)
|
||||
parsed = deepcopy(part)
|
||||
del parsed["image_url"]["url"]
|
||||
parsed["image"] = image
|
||||
parsed["type"] = "image"
|
||||
return parsed
|
||||
if not part.get("image_url", None) and not part.get("uuid", None):
|
||||
raise ValueError("Both image_url and uuid are missing")
|
||||
|
||||
if part.get("image_url", None):
|
||||
url = part["image_url"]["url"]
|
||||
image = mm_parser.parse_image(url)
|
||||
else:
|
||||
image = None
|
||||
|
||||
parsed = {}
|
||||
parsed["type"] = "image"
|
||||
parsed["data"] = image
|
||||
parsed["uuid"] = part.get("uuid", None)
|
||||
|
||||
return parsed
|
||||
if part_type == "video_url":
|
||||
content = part.get("video_url", {}).get("url", None)
|
||||
video = mm_parser.parse_video(content)
|
||||
parsed = deepcopy(part)
|
||||
del parsed["video_url"]["url"]
|
||||
parsed["video"] = video
|
||||
if not part.get("video_url", None) and not part.get("uuid", None):
|
||||
raise ValueError("Both video_url and uuid are missing")
|
||||
|
||||
if part.get("video_url", None):
|
||||
url = part["video_url"]["url"]
|
||||
video = mm_parser.parse_video(url)
|
||||
else:
|
||||
video = None
|
||||
|
||||
parsed = {}
|
||||
parsed["type"] = "video"
|
||||
parsed["data"] = video
|
||||
parsed["uuid"] = part.get("uuid", None)
|
||||
|
||||
return parsed
|
||||
|
||||
raise ValueError(f"Unknown content part type: {part_type}")
|
||||
|
||||
|
||||
# TODO async
|
||||
# def parse_chat_messages(messages: List[ChatCompletionMessageParam]):
|
||||
def parse_chat_messages(messages):
|
||||
def parse_chat_messages(messages: List[ChatCompletionMessageParam]):
|
||||
"""Parse chat messages to [dict]"""
|
||||
|
||||
mm_parser = MultiModalPartParser()
|
||||
mm_parser = MultimodalPartParser()
|
||||
|
||||
conversation = []
|
||||
for message in messages:
|
||||
|
||||
@@ -68,16 +68,19 @@ class EngineClient:
|
||||
tool_parser=None,
|
||||
enable_prefix_caching=None,
|
||||
splitwise_role=None,
|
||||
max_processor_cache=0,
|
||||
):
|
||||
model_config = ModelConfig({"model": model_name_or_path})
|
||||
self.enable_mm = model_config.enable_mm
|
||||
enable_processor_cache = self.enable_mm and max_processor_cache > 0
|
||||
input_processor = InputPreprocessor(
|
||||
model_config,
|
||||
reasoning_parser,
|
||||
limit_mm_per_prompt,
|
||||
mm_processor_kwargs,
|
||||
tool_parser,
|
||||
enable_processor_cache,
|
||||
)
|
||||
self.enable_mm = model_config.enable_mm
|
||||
self.enable_logprob = enable_logprob
|
||||
self.reasoning_parser = reasoning_parser
|
||||
self.data_processor = input_processor.create_processor()
|
||||
|
||||
@@ -193,6 +193,7 @@ async def lifespan(app: FastAPI):
|
||||
tool_parser=args.tool_call_parser,
|
||||
enable_prefix_caching=args.enable_prefix_caching,
|
||||
splitwise_role=args.splitwise_role,
|
||||
max_processor_cache=args.max_processor_cache,
|
||||
)
|
||||
await engine_client.connection_manager.initialize()
|
||||
app.state.dynamic_load_weight = args.dynamic_load_weight
|
||||
|
||||
@@ -470,6 +470,8 @@ class CompletionRequest(BaseModel):
|
||||
max_streaming_response_tokens: Optional[int] = None
|
||||
return_token_ids: Optional[bool] = None
|
||||
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None
|
||||
|
||||
mm_hashes: Optional[list] = None
|
||||
# doc: end-completion-extra-params
|
||||
|
||||
def to_dict_for_infer(self, request_id=None, prompt=None):
|
||||
@@ -527,6 +529,9 @@ class CompletionRequest(BaseModel):
|
||||
if item is not None:
|
||||
req_dict[key] = item
|
||||
|
||||
if self.mm_hashes is not None and len(self.mm_hashes) > 0:
|
||||
req_dict["mm_hashes"] = self.mm_hashes
|
||||
|
||||
return req_dict
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -553,6 +558,9 @@ class CompletionRequest(BaseModel):
|
||||
"('guided_json', 'guided_regex', 'guided_choice', 'guided_grammar')."
|
||||
)
|
||||
|
||||
if data.get("mm_hashes", None):
|
||||
assert isinstance(data["mm_hashes"], list), "`mm_hashes` must be a list."
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@@ -618,6 +626,8 @@ class ChatCompletionRequest(BaseModel):
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
max_streaming_response_tokens: Optional[int] = None
|
||||
disable_chat_template: Optional[bool] = False
|
||||
|
||||
mm_hashes: Optional[list] = None
|
||||
completion_token_ids: Optional[List[int]] = None
|
||||
# doc: end-chat-completion-extra-params
|
||||
|
||||
@@ -694,6 +704,9 @@ class ChatCompletionRequest(BaseModel):
|
||||
if item is not None:
|
||||
req_dict[key] = item
|
||||
|
||||
if self.mm_hashes is not None and len(self.mm_hashes) > 0:
|
||||
req_dict["mm_hashes"] = self.mm_hashes
|
||||
|
||||
return req_dict
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -721,6 +734,9 @@ class ChatCompletionRequest(BaseModel):
|
||||
"('guided_json', 'guided_regex', 'guided_choice', 'guided_grammar', 'structural_tag')."
|
||||
)
|
||||
|
||||
if data.get("mm_hashes", None):
|
||||
assert isinstance(data["mm_hashes"], list), "`mm_hashes` must be a list."
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
@@ -37,6 +37,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
||||
mm_processor_kwargs=None,
|
||||
reasoning_parser_obj=None,
|
||||
tool_parser_obj=None,
|
||||
enable_processor_cache=False,
|
||||
):
|
||||
data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
|
||||
tokenizer_path = model_name_or_path
|
||||
@@ -46,6 +47,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
||||
self.ernie4_5_processor = DataProcessor(
|
||||
tokenizer_name=tokenizer_path,
|
||||
image_preprocessor_name=preprocessor_path,
|
||||
enable_processor_cache=enable_processor_cache,
|
||||
**processor_kwargs,
|
||||
)
|
||||
self.ernie4_5_processor.eval()
|
||||
|
||||
@@ -18,16 +18,20 @@
|
||||
""" process.py """
|
||||
import copy
|
||||
import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
from paddleformers.transformers.image_utils import ChannelDimension
|
||||
from PIL import Image
|
||||
|
||||
from fastdeploy.engine.request import ImagePosition
|
||||
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
|
||||
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
|
||||
from fastdeploy.input.utils import IDS_TYPE_FLAG
|
||||
from fastdeploy.multimodal.hasher import MultimodalHasher
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
|
||||
from .image_preprocessor.image_preprocessor_adaptive import AdaptiveImageProcessor
|
||||
@@ -84,6 +88,7 @@ class DataProcessor:
|
||||
self,
|
||||
tokenizer_name: str,
|
||||
image_preprocessor_name: str,
|
||||
enable_processor_cache: bool = False,
|
||||
spatial_conv_size: int = 2,
|
||||
temporal_conv_size: int = 2,
|
||||
image_min_pixels: int = 4 * 28 * 28,
|
||||
@@ -102,6 +107,7 @@ class DataProcessor:
|
||||
self._load_tokenizer()
|
||||
self.tokenizer.ignored_index = -100
|
||||
self.image_preprocessor = AdaptiveImageProcessor.from_pretrained(image_preprocessor_name)
|
||||
self.enable_processor_cache = enable_processor_cache
|
||||
|
||||
# Convolution sizes for patch aggregation
|
||||
self.spatial_conv_size = spatial_conv_size
|
||||
@@ -163,10 +169,18 @@ class DataProcessor:
|
||||
"""Enable evaluation mode (doesn't produce labels)."""
|
||||
self.is_training = False
|
||||
|
||||
def text2ids(self, text, images=None, videos=None):
|
||||
def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None):
|
||||
"""
|
||||
Convert chat text into model inputs.
|
||||
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
|
||||
|
||||
Args:
|
||||
text (str): The chat text containing placeholders for images and videos.
|
||||
images (list, optional): List of images to be processed and inserted at image placeholders.
|
||||
videos (list, optional): List of videos to be processed and inserted at video placeholders.
|
||||
image_uuid (list, optional): List of unique identifiers for each image, used for caching or hashing.
|
||||
video_uuid (list, optional): List of unique identifiers for each video, used for caching or hashing.
|
||||
Returns:
|
||||
dict: A dictionary with keys input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels, etc.
|
||||
"""
|
||||
|
||||
outputs = {
|
||||
@@ -178,8 +192,9 @@ class DataProcessor:
|
||||
"image_type_ids": [],
|
||||
"labels": [],
|
||||
"cur_position": 0,
|
||||
"pic_cnt": 0,
|
||||
"video_cnt": 0,
|
||||
"mm_positions": [],
|
||||
"mm_hashes": [],
|
||||
}
|
||||
|
||||
IMAGE_PLACEHOLDER = "<|image@placeholder|>"
|
||||
@@ -199,17 +214,27 @@ class DataProcessor:
|
||||
break
|
||||
|
||||
if ed == image_pos:
|
||||
self._add_image(images[image_idx], outputs)
|
||||
image = images[image_idx]
|
||||
uuid = image_uuid[image_idx] if image_uuid else None
|
||||
if not isinstance(image, tuple):
|
||||
self._add_image(image, outputs, uuid)
|
||||
else:
|
||||
# cached images are already processed
|
||||
self._add_processed_image(image, outputs, uuid)
|
||||
image_idx += 1
|
||||
st = ed + IMAGE_PLACEHOLDER_LEN
|
||||
else:
|
||||
item = videos[video_idx]
|
||||
if isinstance(item, dict):
|
||||
frames = self._load_and_process_video(item["video"], item)
|
||||
uuid = video_uuid[video_idx] if video_uuid else None
|
||||
if not isinstance(item, tuple):
|
||||
if isinstance(item, dict):
|
||||
frames = self._load_and_process_video(item["video"], item)
|
||||
else:
|
||||
frames = self._load_and_process_video(item, {})
|
||||
self._add_video(frames, outputs, uuid)
|
||||
else:
|
||||
frames = self._load_and_process_video(item, {})
|
||||
|
||||
self._add_video(frames, outputs)
|
||||
# cached frames are already processed
|
||||
self._add_processed_video(item, outputs, uuid)
|
||||
video_idx += 1
|
||||
st = ed + VIDEO_PLACEHOLDER_LEN
|
||||
|
||||
@@ -223,66 +248,82 @@ class DataProcessor:
|
||||
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
|
||||
"""
|
||||
|
||||
outputs = {
|
||||
"input_ids": [],
|
||||
"token_type_ids": [],
|
||||
"position_ids": [],
|
||||
"images": [],
|
||||
"grid_thw": [],
|
||||
"image_type_ids": [],
|
||||
"labels": [],
|
||||
"cur_position": 0,
|
||||
"pic_cnt": 0,
|
||||
"video_cnt": 0,
|
||||
}
|
||||
|
||||
messages = parse_chat_messages(request.get("messages"))
|
||||
image_message_list = []
|
||||
mm_items = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
assert role in self.role_prefixes, f"Unsupported role: {role}"
|
||||
content_items = msg.get("content")
|
||||
if not isinstance(content_items, list):
|
||||
content_items = [content_items]
|
||||
for item in content_items:
|
||||
if isinstance(item, dict) and item.get("type") in [
|
||||
"image",
|
||||
"video",
|
||||
]:
|
||||
image_message_list.append(item)
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
for item in content:
|
||||
if item.get("type") in ["image", "video"]:
|
||||
mm_items.append(item)
|
||||
|
||||
missing_hashes, missing_idx = [], []
|
||||
for idx, item in enumerate(mm_items):
|
||||
if not item.get("data"):
|
||||
# raw data not provided, should be retrieved from processor cache
|
||||
missing_hashes.append(item.get("uuid"))
|
||||
missing_idx.append(idx)
|
||||
|
||||
if len(missing_hashes) > 0 and not self.enable_processor_cache:
|
||||
raise ValueError("Missing items cannot be retrieved without processor cache.")
|
||||
|
||||
if self.enable_processor_cache:
|
||||
context = zmq.Context()
|
||||
dealer = context.socket(zmq.DEALER)
|
||||
dealer.connect("ipc:///dev/shm/processor_cache.ipc")
|
||||
|
||||
missing_items = self.get_processor_cache(dealer, missing_hashes)
|
||||
for idx in range(len(missing_items)):
|
||||
if not missing_items[idx]:
|
||||
raise ValueError(f"Missing item {idx} not found in processor cache")
|
||||
mm_items[missing_idx[idx]]["data"] = missing_items[idx]
|
||||
|
||||
images, videos = [], []
|
||||
image_uuid, video_uuid = [], []
|
||||
for item in mm_items:
|
||||
if item.get("type") == "image":
|
||||
images.append(item["data"])
|
||||
image_uuid.append(item["uuid"])
|
||||
elif item.get("type") == "video":
|
||||
videos.append(item["data"])
|
||||
video_uuid.append(item["uuid"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
|
||||
|
||||
if self.tokenizer.chat_template is None:
|
||||
raise ValueError("This model does not support chat template.")
|
||||
|
||||
chat_template_kwargs = request.get("chat_template_kwargs", {})
|
||||
prompt_token_ids = self.apply_chat_template(request, **chat_template_kwargs)
|
||||
if len(prompt_token_ids) == 0:
|
||||
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
|
||||
image_start_index = 0
|
||||
image_message_index = 0
|
||||
for i in range(len(prompt_token_ids)):
|
||||
if prompt_token_ids[i] in [
|
||||
self.image_start_id,
|
||||
self.video_start_id,
|
||||
]:
|
||||
self._add_text(prompt_token_ids[image_start_index : i + 1], outputs)
|
||||
image_start_index = i + 1
|
||||
image_message = image_message_list[image_message_index]
|
||||
if image_message["type"] == "image":
|
||||
img = image_message.get("image")
|
||||
if img is None:
|
||||
continue
|
||||
outputs["pic_cnt"] += 1
|
||||
self._add_image(img, outputs)
|
||||
elif image_message["type"] == "video":
|
||||
video_bytes = image_message.get("video")
|
||||
if video_bytes is None:
|
||||
continue
|
||||
frames = self._load_and_process_video(video_bytes, image_message)
|
||||
outputs["video_cnt"] += 1
|
||||
self._add_video(frames, outputs)
|
||||
image_message_index += 1
|
||||
self._add_text(prompt_token_ids[image_start_index:], outputs)
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
request,
|
||||
tokenize=False,
|
||||
add_generation_prompt=request.get("add_generation_prompt", True),
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
request["prompt_tokens"] = prompt
|
||||
|
||||
outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid)
|
||||
|
||||
if self.enable_processor_cache:
|
||||
missing_idx = set(missing_idx)
|
||||
hashes_to_cache, items_to_cache = [], []
|
||||
for idx in range(len(mm_items)):
|
||||
if idx in missing_idx:
|
||||
continue
|
||||
meta = {}
|
||||
t, h, w = outputs["grid_thw"][idx][0]
|
||||
meta["thw"] = (t, h, w)
|
||||
hashes_to_cache.append(outputs["mm_hashes"][idx])
|
||||
items_to_cache.append((outputs["images"][idx], meta))
|
||||
self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
|
||||
|
||||
if self.is_training:
|
||||
assert tgts, "training must give tgt !"
|
||||
assert tgts, "Training must give tgt"
|
||||
self._extract_labels(outputs, tgts)
|
||||
|
||||
return outputs
|
||||
|
||||
def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
|
||||
@@ -304,7 +345,7 @@ class DataProcessor:
|
||||
outputs["position_ids"].append([start + i] * 3)
|
||||
outputs["cur_position"] += len(tokens)
|
||||
|
||||
def _add_image(self, img, outputs: Dict) -> None:
|
||||
def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
|
||||
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
|
||||
img.height,
|
||||
img.width,
|
||||
@@ -313,6 +354,7 @@ class DataProcessor:
|
||||
)[1]
|
||||
num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)
|
||||
|
||||
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
|
||||
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
|
||||
|
||||
@@ -330,10 +372,32 @@ class DataProcessor:
|
||||
input_data_format=ChannelDimension.LAST,
|
||||
)
|
||||
outputs["images"].append(ret["pixel_values"])
|
||||
if not uuid:
|
||||
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
|
||||
else:
|
||||
outputs["mm_hashes"].append(uuid)
|
||||
outputs["grid_thw"].append(ret["image_grid_thw"])
|
||||
outputs["image_type_ids"].append(0)
|
||||
|
||||
def _add_video(self, frames, outputs: Dict) -> None:
|
||||
def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
|
||||
img, meta = img_cache
|
||||
num_tokens = img.shape[0] // (self.spatial_conv_size**2)
|
||||
|
||||
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
|
||||
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
|
||||
|
||||
_, h, w = meta["thw"]
|
||||
pos_ids = self._compute_3d_positions(1, h, w, outputs["cur_position"])
|
||||
outputs["position_ids"].extend(pos_ids)
|
||||
outputs["cur_position"] = np.max(pos_ids) + 1
|
||||
|
||||
outputs["images"].append(img)
|
||||
outputs["mm_hashes"].append(uuid)
|
||||
outputs["grid_thw"].append(np.array([[1, h, w]]))
|
||||
outputs["image_type_ids"].append(0)
|
||||
|
||||
def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
|
||||
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
|
||||
frames[0].height,
|
||||
frames[0].width,
|
||||
@@ -354,9 +418,14 @@ class DataProcessor:
|
||||
input_data_format=ChannelDimension.LAST,
|
||||
)
|
||||
outputs["images"].append(ret["pixel_values_videos"])
|
||||
if not uuid:
|
||||
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values_videos"]))
|
||||
else:
|
||||
outputs["mm_hashes"].append(uuid)
|
||||
outputs["grid_thw"].append(ret["video_grid_thw"])
|
||||
outputs["image_type_ids"].extend([1] * num_frames)
|
||||
|
||||
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
|
||||
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
|
||||
|
||||
@@ -364,6 +433,24 @@ class DataProcessor:
|
||||
outputs["position_ids"].extend(pos_ids)
|
||||
outputs["cur_position"] = np.max(pos_ids) + 1
|
||||
|
||||
def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
|
||||
frames, meta = frames_cache
|
||||
num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size)
|
||||
|
||||
t, h, w = meta["thw"]
|
||||
outputs["images"].append(frames)
|
||||
outputs["mm_hashes"].append(uuid)
|
||||
outputs["grid_thw"].append(np.array([[t, h, w]]))
|
||||
|
||||
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
|
||||
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
|
||||
outputs["image_type_ids"].extend([1] * t)
|
||||
|
||||
pos_ids = self._compute_3d_positions(t, h, w, outputs["cur_position"])
|
||||
outputs["position_ids"].extend(pos_ids)
|
||||
outputs["cur_position"] = np.max(pos_ids) + 1
|
||||
|
||||
def _extract_labels(self, outputs: Dict, tgts: List[str]) -> None:
|
||||
input_ids = copy.deepcopy(outputs["input_ids"])
|
||||
labels = [self.tokenizer.ignored_index] * len(input_ids)
|
||||
@@ -480,33 +567,22 @@ class DataProcessor:
|
||||
break
|
||||
self.tokenizer = Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path)
|
||||
|
||||
def apply_chat_template(self, request, **kwargs):
|
||||
def get_processor_cache(self, socket, mm_hashes: list[str]) -> list:
|
||||
"""
|
||||
Convert multi-turn messages into ID sequences.
|
||||
|
||||
Args:
|
||||
messages: Either a request dict containing 'messages' field,
|
||||
or a list of message dicts directly
|
||||
|
||||
Returns:
|
||||
List of token IDs as strings (converted from token objects)
|
||||
get cache correspond to given hash values
|
||||
"""
|
||||
if self.tokenizer.chat_template is None:
|
||||
raise ValueError("This model does not support chat_template.")
|
||||
req = pickle.dumps(mm_hashes)
|
||||
socket.send_multipart([b"", req])
|
||||
_, resp = socket.recv_multipart()
|
||||
mm_items = pickle.loads(resp)
|
||||
data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}")
|
||||
|
||||
prompt_token_template = self.tokenizer.apply_chat_template(
|
||||
request,
|
||||
tokenize=False,
|
||||
add_generation_prompt=request.get("add_generation_prompt", True),
|
||||
**kwargs,
|
||||
)
|
||||
prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace(
|
||||
"<|video@placeholder|>", ""
|
||||
)
|
||||
request["prompt_tokens"] = prompt_token_template
|
||||
tokens = self.tokenizer.tokenize(prompt_token_str)
|
||||
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
data_processor_logger.info(
|
||||
f"req_id:{request.get('request_id', ''), } tokens: {tokens}, token_ids: {token_ids}"
|
||||
)
|
||||
return token_ids
|
||||
return mm_items
|
||||
|
||||
def update_processor_cache(self, socket, mm_hashes: list[str], mm_items):
|
||||
"""
|
||||
update cache data
|
||||
"""
|
||||
req = pickle.dumps((mm_hashes, mm_items))
|
||||
socket.send_multipart([b"", req])
|
||||
data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}")
|
||||
|
||||
@@ -46,6 +46,7 @@ class InputPreprocessor:
|
||||
limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
tool_parser: str = None,
|
||||
enable_processor_cache: bool = False,
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
self.model_name_or_path = self.model_config.model
|
||||
@@ -53,6 +54,7 @@ class InputPreprocessor:
|
||||
self.limit_mm_per_prompt = limit_mm_per_prompt
|
||||
self.mm_processor_kwargs = mm_processor_kwargs
|
||||
self.tool_parser = tool_parser
|
||||
self.enable_processor_cache = enable_processor_cache
|
||||
|
||||
def create_processor(self):
|
||||
reasoning_parser_obj = None
|
||||
@@ -104,6 +106,19 @@ class InputPreprocessor:
|
||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||
reasoning_parser_obj=reasoning_parser_obj,
|
||||
tool_parser_obj=tool_parser_obj,
|
||||
enable_processor_cache=self.enable_processor_cache,
|
||||
)
|
||||
elif "PaddleOCRVL" in architecture:
|
||||
from fastdeploy.input.paddleocr_vl_processor import (
|
||||
PaddleOCRVLProcessor,
|
||||
)
|
||||
|
||||
self.processor = PaddleOCRVLProcessor(
|
||||
config=self.model_config,
|
||||
model_name_or_path=self.model_name_or_path,
|
||||
limit_mm_per_prompt=self.limit_mm_per_prompt,
|
||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||
reasoning_parser_obj=reasoning_parser_obj,
|
||||
)
|
||||
elif "PaddleOCRVL" in architecture:
|
||||
from fastdeploy.input.paddleocr_vl_processor import (
|
||||
@@ -126,5 +141,6 @@ class InputPreprocessor:
|
||||
limit_mm_per_prompt=self.limit_mm_per_prompt,
|
||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||
reasoning_parser_obj=reasoning_parser_obj,
|
||||
enable_processor_cache=self.enable_processor_cache,
|
||||
)
|
||||
return self.processor
|
||||
|
||||
@@ -15,17 +15,23 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
import pickle
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
from paddleformers.transformers import AutoTokenizer
|
||||
from PIL import Image
|
||||
|
||||
from fastdeploy.engine.request import ImagePosition
|
||||
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
|
||||
from fastdeploy.input.ernie4_5_vl_processor import read_video_decord
|
||||
from fastdeploy.input.utils import IDS_TYPE_FLAG
|
||||
from fastdeploy.multimodal.hasher import MultimodalHasher
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
|
||||
from .image_processor import ImageProcessor
|
||||
from .process_video import read_frames, sample_frames
|
||||
from .process_video import sample_frames
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
@@ -49,8 +55,11 @@ class DataProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
enable_processor_cache: bool = False,
|
||||
video_min_frames: int = 4,
|
||||
video_max_frames: int = 768,
|
||||
video_target_frames: int = -1,
|
||||
video_fps: int = -1,
|
||||
tokens_per_second: int = 2,
|
||||
tokenizer=None,
|
||||
**kwargs,
|
||||
@@ -67,6 +76,8 @@ class DataProcessor:
|
||||
"""
|
||||
self.min_frames = video_min_frames
|
||||
self.max_frames = video_max_frames
|
||||
self.target_frames = video_target_frames
|
||||
self.fps = video_fps
|
||||
|
||||
# Initialize tokenizer with left padding and fast tokenizer
|
||||
if tokenizer is None:
|
||||
@@ -75,6 +86,7 @@ class DataProcessor:
|
||||
else:
|
||||
self.tokenizer = tokenizer
|
||||
self.image_processor = ImageProcessor.from_pretrained(model_path) # Initialize image processor
|
||||
self.enable_processor_cache = enable_processor_cache
|
||||
|
||||
# Convolution sizes for patch aggregation
|
||||
self.spatial_conv_size = self.image_processor.merge_size
|
||||
@@ -99,41 +111,7 @@ class DataProcessor:
|
||||
"assistant": "Assistant: ",
|
||||
}
|
||||
|
||||
def _pack_outputs(self, outputs):
|
||||
"""
|
||||
Pack and convert all output data into numpy arrays with appropriate types.
|
||||
|
||||
Args:
|
||||
outputs (dict): Dictionary containing model outputs with keys:
|
||||
- images: List of visual features
|
||||
- grid_thw: List of spatial dimensions
|
||||
- image_type_ids: List of content type indicators
|
||||
- input_ids: List of token IDs
|
||||
- token_type_ids: List of type identifiers
|
||||
- position_ids: List of position embeddings
|
||||
|
||||
Returns:
|
||||
dict: Processed outputs with all values converted to numpy arrays
|
||||
"""
|
||||
# Process visual outputs - stack if exists or set to None if empty
|
||||
if not outputs["images"]:
|
||||
outputs["images"] = None # No images case
|
||||
outputs["grid_thw"] = None # No spatial dimensions
|
||||
outputs["image_type_ids"] = None # No type IDs
|
||||
else:
|
||||
outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically
|
||||
outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions
|
||||
outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array
|
||||
|
||||
# Convert all outputs to numpy arrays with appropriate types
|
||||
outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64
|
||||
outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64
|
||||
outputs["position_ids"] = np.concatenate(
|
||||
outputs["position_ids"], axis=1, dtype=np.int64
|
||||
) # Concatenate position IDs
|
||||
return outputs
|
||||
|
||||
def text2ids(self, text, images=None, videos=None):
|
||||
def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None):
|
||||
"""
|
||||
Convert text with image/video placeholders into model inputs.
|
||||
|
||||
@@ -141,6 +119,8 @@ class DataProcessor:
|
||||
text: Input text with <|image@placeholder|> and <|video@placeholder|> markers
|
||||
images: List of PIL Images corresponding to image placeholders
|
||||
videos: List of video data corresponding to video placeholders
|
||||
image_uuid: List of unique identifiers for each image, used for caching or hashing.
|
||||
video_uuid: List of unique identifiers for each video, used for caching or hashing.
|
||||
|
||||
Returns:
|
||||
Dict containing:
|
||||
@@ -161,8 +141,10 @@ class DataProcessor:
|
||||
"image_type_ids": [],
|
||||
"labels": [],
|
||||
"cur_position": 0,
|
||||
"pic_cnt": 0,
|
||||
"video_cnt": 0,
|
||||
"fps": [],
|
||||
"mm_positions": [],
|
||||
"mm_hashes": [],
|
||||
}
|
||||
|
||||
# Define placeholders and their lengths
|
||||
@@ -186,23 +168,30 @@ class DataProcessor:
|
||||
break
|
||||
|
||||
if ed == image_pos:
|
||||
outputs["pic_cnt"] += 1
|
||||
self._add_image(images[image_idx], outputs)
|
||||
image = images[image_idx]
|
||||
uuid = image_uuid[image_idx] if image_uuid else None
|
||||
if not isinstance(image, tuple):
|
||||
self._add_image(image, outputs, uuid)
|
||||
else:
|
||||
self._add_processed_image(image, outputs, uuid)
|
||||
image_idx += 1
|
||||
st = ed + IMAGE_PLACEHOLDER_LEN
|
||||
else:
|
||||
item = videos[video_idx]
|
||||
if isinstance(item, dict):
|
||||
frames, meta = self._load_and_process_video(item["video"], item)
|
||||
uuid = video_uuid[video_idx] if video_uuid else None
|
||||
if not isinstance(item, tuple):
|
||||
if isinstance(item, dict):
|
||||
frames, meta = self._load_and_process_video(item["video"], item)
|
||||
else:
|
||||
frames, meta = self._load_and_process_video(item, {})
|
||||
self._add_video(frames, meta, outputs, uuid)
|
||||
else:
|
||||
frames, meta = self._load_and_process_video(item, {})
|
||||
|
||||
outputs["video_cnt"] += 1
|
||||
self._add_video(frames, meta, outputs)
|
||||
# cached frames are already processed
|
||||
self._add_processed_video(item, outputs, uuid)
|
||||
video_idx += 1
|
||||
st = ed + VIDEO_PLACEHOLDER_LEN
|
||||
|
||||
return self._pack_outputs(outputs)
|
||||
return outputs
|
||||
|
||||
def request2ids(
|
||||
self, request: Dict[str, Any], tgts: List[str] = None
|
||||
@@ -220,74 +209,84 @@ class DataProcessor:
|
||||
Dict with same structure as text2ids() output
|
||||
"""
|
||||
|
||||
outputs = {
|
||||
"input_ids": [],
|
||||
"token_type_ids": [],
|
||||
"position_ids": [],
|
||||
"images": [],
|
||||
"grid_thw": [],
|
||||
"image_type_ids": [],
|
||||
"labels": [],
|
||||
"cur_position": 0,
|
||||
"pic_cnt": 0,
|
||||
"video_cnt": 0,
|
||||
}
|
||||
|
||||
# Parse and validate chat messages
|
||||
messages = parse_chat_messages(request.get("messages"))
|
||||
image_message_list = [] # Store visual content messages
|
||||
|
||||
mm_items = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
assert role in self.role_prefixes, f"Unsupported role: {role}"
|
||||
|
||||
# Normalize content to list format
|
||||
content_items = msg.get("content")
|
||||
if not isinstance(content_items, list):
|
||||
content_items = [content_items]
|
||||
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
# Collect all visual content items
|
||||
for item in content_items:
|
||||
if isinstance(item, dict) and item.get("type") in ["image", "video"]:
|
||||
image_message_list.append(item)
|
||||
for item in content:
|
||||
if item.get("type") in ["image", "video"]:
|
||||
mm_items.append(item)
|
||||
|
||||
raw_messages = request["messages"]
|
||||
request["messages"] = messages
|
||||
missing_hashes, missing_idx = [], []
|
||||
for idx, item in enumerate(mm_items):
|
||||
if not item.get("data"):
|
||||
# raw data not provided, should be retrieved from processor cache
|
||||
missing_hashes.append(item.get("uuid"))
|
||||
missing_idx.append(idx)
|
||||
|
||||
prompt_token_ids = self.apply_chat_template(request)
|
||||
if len(prompt_token_ids) == 0:
|
||||
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
|
||||
request["messages"] = raw_messages
|
||||
if len(missing_hashes) > 0 and not self.enable_processor_cache:
|
||||
raise ValueError("Missing items cannot be retrieved without processor cache.")
|
||||
|
||||
vision_start_index = 0
|
||||
vision_message_index = 0
|
||||
for i in range(len(prompt_token_ids)):
|
||||
if prompt_token_ids[i] == self.vision_start_id:
|
||||
self._add_text(prompt_token_ids[vision_start_index : i + 1], outputs)
|
||||
if self.enable_processor_cache:
|
||||
context = zmq.Context()
|
||||
dealer = context.socket(zmq.DEALER)
|
||||
dealer.connect("ipc:///dev/shm/processor_cache.ipc")
|
||||
|
||||
vision_start_index = i + 1
|
||||
image_message = image_message_list[vision_message_index]
|
||||
missing_items = self.get_processor_cache(dealer, missing_hashes)
|
||||
for idx in range(len(missing_items)):
|
||||
if not missing_items[idx]:
|
||||
raise ValueError(f"Missing item {idx} not found in processor cache")
|
||||
mm_items[missing_idx[idx]]["data"] = missing_items[idx]
|
||||
|
||||
if image_message["type"] == "image":
|
||||
img = image_message.get("image")
|
||||
if img is None:
|
||||
continue
|
||||
outputs["pic_cnt"] += 1
|
||||
self._add_image(img, outputs)
|
||||
images, videos = [], []
|
||||
image_uuid, video_uuid = [], []
|
||||
for item in mm_items:
|
||||
if item.get("type") == "image":
|
||||
images.append(item["data"])
|
||||
image_uuid.append(item["uuid"])
|
||||
elif item.get("type") == "video":
|
||||
videos.append(item["data"])
|
||||
video_uuid.append(item["uuid"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
|
||||
|
||||
elif image_message["type"] == "video":
|
||||
video_bytes = image_message.get("video")
|
||||
if video_bytes is None:
|
||||
continue
|
||||
frames, meta = self._load_and_process_video(video_bytes, image_message)
|
||||
if self.tokenizer.chat_template is None:
|
||||
raise ValueError("This model does not support chat template.")
|
||||
|
||||
outputs["video_cnt"] += 1
|
||||
self._add_video(frames, meta, outputs)
|
||||
chat_template_kwargs = request.get("chat_template_kwargs", {})
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=request.get("add_generation_prompt", True),
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
request["prompt_tokens"] = prompt
|
||||
|
||||
vision_message_index += 1
|
||||
outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid)
|
||||
|
||||
self._add_text(prompt_token_ids[vision_start_index:], outputs)
|
||||
return self._pack_outputs(outputs)
|
||||
if self.enable_processor_cache:
|
||||
missing_idx = set(missing_idx)
|
||||
hashes_to_cache, items_to_cache = [], []
|
||||
for idx in range(len(mm_items)):
|
||||
if idx in missing_idx:
|
||||
continue
|
||||
meta = {}
|
||||
t, h, w = outputs["grid_thw"][idx]
|
||||
meta["thw"] = (t, h, w)
|
||||
meta["fps"] = outputs["fps"][idx]
|
||||
hashes_to_cache.append(outputs["mm_hashes"][idx])
|
||||
items_to_cache.append((outputs["images"][idx], meta))
|
||||
self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
|
||||
|
||||
return outputs
|
||||
|
||||
def _add_text(self, tokens, outputs: Dict) -> None:
|
||||
"""
|
||||
@@ -312,9 +311,9 @@ class DataProcessor:
|
||||
outputs["input_ids"].extend(tokens)
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)
|
||||
|
||||
position_ids = self._compute_text_positions(outputs["cur_position"], num_tokens)
|
||||
outputs["position_ids"].append(position_ids)
|
||||
outputs["cur_position"] = position_ids.max() + 1
|
||||
pos_ids = self._compute_text_positions(outputs["cur_position"], num_tokens)
|
||||
outputs["position_ids"].append(pos_ids)
|
||||
outputs["cur_position"] = pos_ids.max() + 1
|
||||
|
||||
def _compute_text_positions(self, start_pos: int, num_tokens: int) -> np.ndarray:
|
||||
"""
|
||||
@@ -332,7 +331,7 @@ class DataProcessor:
|
||||
position = text_index + start_pos
|
||||
return position
|
||||
|
||||
def _add_image(self, img, outputs: Dict) -> None:
|
||||
def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
|
||||
"""
|
||||
Add image data to model inputs dictionary.
|
||||
|
||||
@@ -349,20 +348,47 @@ class DataProcessor:
|
||||
num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
|
||||
grid_thw = ret["grid_thw"].tolist()
|
||||
|
||||
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
|
||||
outputs["input_ids"].extend([self.image_token_id] * num_tokens)
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
|
||||
|
||||
outputs["images"].append(ret["pixel_values"])
|
||||
if not uuid:
|
||||
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
|
||||
else:
|
||||
outputs["mm_hashes"].append(uuid)
|
||||
outputs["grid_thw"].append(grid_thw)
|
||||
outputs["image_type_ids"].append(0)
|
||||
|
||||
t, h, w = grid_thw
|
||||
position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)
|
||||
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)
|
||||
|
||||
outputs["position_ids"].append(position_ids)
|
||||
outputs["cur_position"] = position_ids.max() + 1
|
||||
outputs["position_ids"].append(pos_ids)
|
||||
outputs["cur_position"] = pos_ids.max() + 1
|
||||
|
||||
def _add_video(self, frames, meta: Dict, outputs: Dict) -> None:
|
||||
outputs["fps"].append(0)
|
||||
|
||||
def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
|
||||
img, meta = img_cache
|
||||
num_tokens = img.shape[0] // self.image_processor.merge_size**2
|
||||
|
||||
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
|
||||
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
|
||||
|
||||
_, h, w = meta["thw"]
|
||||
pos_ids = self._compute_vision_positions(outputs["cur_position"], 1, h, w, 0)
|
||||
outputs["position_ids"].append(pos_ids)
|
||||
outputs["cur_position"] = pos_ids.max() + 1
|
||||
|
||||
outputs["images"].append(img)
|
||||
outputs["mm_hashes"].append(uuid)
|
||||
outputs["grid_thw"].append(np.array([[1, h, w]]))
|
||||
outputs["image_type_ids"].append(0)
|
||||
|
||||
outputs["fps"].append(0)
|
||||
|
||||
def _add_video(self, frames, meta: Dict, outputs: Dict, uuid: Optional[str]) -> None:
|
||||
"""
|
||||
Add video data to model inputs dictionary.
|
||||
|
||||
@@ -380,20 +406,49 @@ class DataProcessor:
|
||||
num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
|
||||
grid_thw = ret["grid_thw"].tolist()
|
||||
|
||||
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
|
||||
outputs["input_ids"].extend([self.video_token_id] * num_tokens)
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
|
||||
|
||||
outputs["images"].append(ret["pixel_values"])
|
||||
if not uuid:
|
||||
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
|
||||
else:
|
||||
outputs["mm_hashes"].append(uuid)
|
||||
outputs["grid_thw"].append(grid_thw)
|
||||
outputs["image_type_ids"].extend([1] * grid_thw[0])
|
||||
|
||||
fps = meta["fps"]
|
||||
second_per_grid_t = self.temporal_conv_size / fps
|
||||
t, h, w = grid_thw
|
||||
position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
|
||||
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
|
||||
|
||||
outputs["position_ids"].append(position_ids)
|
||||
outputs["cur_position"] = position_ids.max() + 1
|
||||
outputs["position_ids"].append(pos_ids)
|
||||
outputs["cur_position"] = pos_ids.max() + 1
|
||||
|
||||
outputs["fps"].append(fps)
|
||||
|
||||
def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
|
||||
frames, meta = frames_cache
|
||||
num_tokens = frames.shape[0] // self.image_processor.merge_size**2
|
||||
|
||||
t, h, w = meta["thw"]
|
||||
outputs["images"].append(frames)
|
||||
outputs["mm_hashes"].append(uuid)
|
||||
outputs["grid_thw"].append(np.array([[t, h, w]]))
|
||||
|
||||
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
|
||||
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
|
||||
outputs["image_type_ids"].extend([1] * t)
|
||||
|
||||
fps = meta["fps"]
|
||||
second_per_grid_t = self.temporal_conv_size / fps
|
||||
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
|
||||
outputs["position_ids"].append(pos_ids)
|
||||
outputs["cur_position"] = pos_ids.max() + 1
|
||||
|
||||
outputs["fps"].append(fps)
|
||||
|
||||
def _compute_vision_positions(
|
||||
self, start_pos: int, t: int, h: int, w: int, second_per_grid_t: float
|
||||
@@ -441,20 +496,20 @@ class DataProcessor:
|
||||
- frames: Processed video frames as numpy array
|
||||
- metadata: Updated video metadata dictionary
|
||||
"""
|
||||
frames, meta = read_frames(url)
|
||||
reader, meta, _ = read_video_decord(url, save_to_disk=False)
|
||||
|
||||
# Apply frame sampling if fps or target_frames specified
|
||||
fps = item.get("fps", None)
|
||||
num_frames = item.get("target_frames", None)
|
||||
fps = item.get("fps", self.fps)
|
||||
num_frames = item.get("target_frames", self.target_frames)
|
||||
|
||||
if fps is not None or num_frames is not None:
|
||||
frame_indices = list(range(meta["num_of_frame"]))
|
||||
if fps > 0 or num_frames > 0:
|
||||
# Get frame sampling constraints
|
||||
min_frames = item.get("min_frames", self.min_frames)
|
||||
max_frames = item.get("max_frames", self.max_frames)
|
||||
|
||||
# Sample frames according to specifications
|
||||
frames = sample_frames(
|
||||
video=frames,
|
||||
frame_indices = sample_frames(
|
||||
frame_factor=self.temporal_conv_size, # Ensure divisible by temporal patch size
|
||||
min_frames=min_frames,
|
||||
max_frames=max_frames,
|
||||
@@ -464,42 +519,38 @@ class DataProcessor:
|
||||
)
|
||||
|
||||
# Update metadata with new frame count and fps
|
||||
meta["num_of_frame"] = frames.shape[0]
|
||||
meta["num_of_frame"] = len(frame_indices)
|
||||
if fps is not None:
|
||||
meta["fps"] = fps # Use specified fps
|
||||
meta["duration"] = frames.shape[0] / fps
|
||||
meta["duration"] = len(frame_indices) / fps
|
||||
else:
|
||||
meta["fps"] = frames.shape[0] / meta["duration"] # Calculate fps from sampled frames
|
||||
meta["fps"] = len(frame_indices) / meta["duration"] # Calculate fps from sampled frames
|
||||
|
||||
frames = []
|
||||
for idx in frame_indices:
|
||||
frame = reader[idx].asnumpy()
|
||||
image = Image.fromarray(frame, "RGB")
|
||||
frames.append(image)
|
||||
frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
|
||||
|
||||
return frames, meta
|
||||
|
||||
def apply_chat_template(self, request):
|
||||
def get_processor_cache(self, socket, mm_hashes: list[str]) -> list:
|
||||
"""
|
||||
Apply chat template to convert messages into token sequence.
|
||||
|
||||
Args:
|
||||
request: Dictionary containing chat messages
|
||||
|
||||
Returns:
|
||||
List of token IDs
|
||||
|
||||
Raises:
|
||||
ValueError: If model doesn't support chat templates
|
||||
get cache correspond to given hash values
|
||||
"""
|
||||
if self.tokenizer.chat_template is None:
|
||||
raise ValueError("This model does not support chat_template.")
|
||||
req = pickle.dumps(mm_hashes)
|
||||
socket.send_multipart([b"", req])
|
||||
_, resp = socket.recv_multipart()
|
||||
mm_items = pickle.loads(resp)
|
||||
data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}")
|
||||
|
||||
raw_prompt = self.tokenizer.apply_chat_template(
|
||||
request["messages"],
|
||||
tokenize=False,
|
||||
add_generation_prompt=request.get("add_generation_prompt", True),
|
||||
)
|
||||
prompt_token_str = raw_prompt.replace(self.image_token, "").replace(self.video_token, "")
|
||||
request["prompt_tokens"] = raw_prompt
|
||||
return mm_items
|
||||
|
||||
tokens = self.tokenizer.tokenize(prompt_token_str)
|
||||
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
data_processor_logger.info(
|
||||
f"req_id:{request.get('request_id', ''), } prompt: {raw_prompt} tokens: {tokens}, token_ids: {token_ids}"
|
||||
)
|
||||
return token_ids
|
||||
def update_processor_cache(self, socket, mm_hashes: list[str], mm_items):
|
||||
"""
|
||||
update cache data
|
||||
"""
|
||||
req = pickle.dumps((mm_hashes, mm_items))
|
||||
socket.send_multipart([b"", req])
|
||||
data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}")
|
||||
|
||||
@@ -18,50 +18,9 @@ import math
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from fastdeploy.input.ernie4_5_vl_processor import read_video_decord
|
||||
|
||||
|
||||
def read_frames(video_path):
|
||||
"""
|
||||
Read and decode video frames from the given path
|
||||
|
||||
This function reads a video file and decodes it into individual RGB frames
|
||||
using decord video reader. It also extracts video metadata including fps,
|
||||
duration and frame count.
|
||||
|
||||
Args:
|
||||
video_path (str): Path to the video file or bytes object containing video data
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing:
|
||||
frames (numpy.ndarray): Array of shape (num_frames, height, width, 3)
|
||||
containing decoded RGB video frames
|
||||
meta (dict): Dictionary containing video metadata:
|
||||
- fps (float): Frames per second
|
||||
- duration (float): Video duration in seconds
|
||||
- num_of_frame (int): Total number of frames
|
||||
- width (int): Frame width in pixels
|
||||
- height (int): Frame height in pixels
|
||||
|
||||
Note:
|
||||
- The function uses decord library for efficient video reading
|
||||
- All frames are converted to RGB format regardless of input format
|
||||
"""
|
||||
reader, meta, _ = read_video_decord(video_path, save_to_disk=False)
|
||||
|
||||
frames = []
|
||||
for i in range(meta["num_of_frame"]):
|
||||
frame = reader[i].asnumpy()
|
||||
image = Image.fromarray(frame, "RGB")
|
||||
frames.append(image)
|
||||
frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
|
||||
return frames, meta
|
||||
|
||||
|
||||
def sample_frames(
|
||||
video: np.ndarray,
|
||||
frame_factor: int,
|
||||
min_frames: int,
|
||||
max_frames: int,
|
||||
@@ -73,7 +32,6 @@ def sample_frames(
|
||||
Sample frames from video according to specified criteria.
|
||||
|
||||
Args:
|
||||
video: Input video frames as numpy array
|
||||
frame_factor: Ensure sampled frames are multiples of this factor
|
||||
min_frames: Minimum number of frames to sample
|
||||
max_frames: Maximum number of frames to sample
|
||||
@@ -89,18 +47,15 @@ def sample_frames(
|
||||
or if required metadata is missing,
|
||||
or if requested frames exceed available frames
|
||||
"""
|
||||
if fps is not None and num_frames is not None:
|
||||
if fps > 0 and num_frames > 0:
|
||||
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
|
||||
|
||||
if fps is None and num_frames is None:
|
||||
return video
|
||||
|
||||
total_num_frames = video.shape[0]
|
||||
total_num_frames = metadata["num_of_frame"]
|
||||
|
||||
# If num_frames is not given but fps is, calculate num_frames from fps
|
||||
if num_frames is not None:
|
||||
if num_frames > 0:
|
||||
num_frames = round(num_frames / frame_factor) * frame_factor
|
||||
elif fps is not None:
|
||||
elif fps > 0:
|
||||
if metadata is None:
|
||||
raise ValueError(
|
||||
"Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
|
||||
@@ -110,7 +65,6 @@ def sample_frames(
|
||||
num_frames = total_num_frames / metadata["fps"] * fps
|
||||
num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames)
|
||||
num_frames = math.floor(num_frames / frame_factor) * frame_factor
|
||||
|
||||
if num_frames > total_num_frames:
|
||||
raise ValueError(
|
||||
f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
|
||||
@@ -118,14 +72,11 @@ def sample_frames(
|
||||
)
|
||||
|
||||
# Calculate frame indices based on sampling strategy
|
||||
if num_frames is not None:
|
||||
if num_frames > 0:
|
||||
# Evenly spaced sampling for target frame count
|
||||
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(np.int32)
|
||||
else:
|
||||
# Keep all frames if no sampling requested
|
||||
indices = np.arange(0, total_num_frames).astype(np.int32)
|
||||
|
||||
# Apply frame selection
|
||||
video = video[indices]
|
||||
|
||||
return video
|
||||
return indices
|
||||
|
||||
@@ -47,6 +47,7 @@ class QwenVLProcessor(TextProcessor):
|
||||
mm_processor_kwargs=None,
|
||||
reasoning_parser_obj=None,
|
||||
tool_parser_obj=None,
|
||||
enable_processor_cache=False,
|
||||
):
|
||||
"""
|
||||
Initialize QwenVLProcessor instance.
|
||||
@@ -65,6 +66,7 @@ class QwenVLProcessor(TextProcessor):
|
||||
processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs)
|
||||
self.processor = DataProcessor(
|
||||
model_path=model_name_or_path,
|
||||
enable_processor_cache=enable_processor_cache,
|
||||
tokens_per_second=config.vision_config.tokens_per_second,
|
||||
tokenizer=self.tokenizer,
|
||||
**processor_kwargs,
|
||||
@@ -271,7 +273,7 @@ class QwenVLProcessor(TextProcessor):
|
||||
|
||||
return request
|
||||
|
||||
def append_completion_tokens(self, outputs, completion_token_ids):
|
||||
def append_completion_tokens(self, multimodal_inputs, completion_token_ids):
|
||||
"""
|
||||
Append completion tokens to existing outputs.
|
||||
|
||||
@@ -279,19 +281,14 @@ class QwenVLProcessor(TextProcessor):
|
||||
outputs: Current model outputs
|
||||
completion_token_ids: completion tokens to append
|
||||
"""
|
||||
out = {"input_ids": [], "token_type_ids": [], "position_ids": [], "cur_position": outputs["cur_position"]}
|
||||
self.processor._add_text(completion_token_ids, out)
|
||||
|
||||
outputs["input_ids"] = np.concatenate(
|
||||
[outputs["input_ids"], np.array(out["input_ids"], dtype=np.int64)], axis=0
|
||||
)
|
||||
outputs["token_type_ids"] = np.concatenate(
|
||||
[outputs["token_type_ids"], np.array(out["token_type_ids"], dtype=np.int64)], axis=0
|
||||
)
|
||||
outputs["position_ids"] = np.concatenate(
|
||||
[outputs["position_ids"], out["position_ids"][0]], axis=1, dtype=np.int64
|
||||
)
|
||||
outputs["cur_position"] = out["cur_position"]
|
||||
num_tokens = len(completion_token_ids)
|
||||
multimodal_inputs["input_ids"].extend(completion_token_ids)
|
||||
multimodal_inputs["token_type_ids"].extend([0] * num_tokens)
|
||||
|
||||
pos_ids = self.processor._compute_text_positions(multimodal_inputs["cur_position"], num_tokens)
|
||||
multimodal_inputs["position_ids"].append(pos_ids)
|
||||
multimodal_inputs["cur_position"] += num_tokens
|
||||
|
||||
def pack_outputs(self, outputs):
|
||||
"""
|
||||
@@ -303,7 +300,24 @@ class QwenVLProcessor(TextProcessor):
|
||||
Returns:
|
||||
dict: Packed output dictionary with all required fields
|
||||
"""
|
||||
if not outputs["images"]:
|
||||
outputs["images"] = None # No images case
|
||||
outputs["grid_thw"] = None # No spatial dimensions
|
||||
outputs["image_type_ids"] = None # No type IDs
|
||||
else:
|
||||
outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically
|
||||
outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions
|
||||
outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array
|
||||
|
||||
# Convert all outputs to numpy arrays with appropriate types
|
||||
outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64
|
||||
outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64
|
||||
outputs["position_ids"] = np.concatenate(
|
||||
outputs["position_ids"], axis=1, dtype=np.int64
|
||||
) # Concatenate position ID
|
||||
|
||||
outputs["image_patch_id"] = self.processor.image_token_id
|
||||
outputs["video_patch_id"] = self.processor.video_token_id
|
||||
outputs["position_ids"] = outputs["position_ids"].transpose(1, 0)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
|
||||
35
fastdeploy/multimodal/hasher.py
Normal file
35
fastdeploy/multimodal/hasher.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
|
||||
|
||||
class MultimodalHasher:
|
||||
|
||||
@classmethod
|
||||
def hash_features(cls, obj: object) -> str:
|
||||
if isinstance(obj, np.ndarray):
|
||||
return hashlib.sha256((obj.tobytes())).hexdigest()
|
||||
|
||||
data_processor_logger.warning(
|
||||
f"Unsupported type for hashing features: {type(obj)}" + ", use pickle for serialization"
|
||||
)
|
||||
return hashlib.sha256((pickle.dumps(obj))).hexdigest()
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
|
||||
36
fastdeploy/multimodal/registry.py
Normal file
36
fastdeploy/multimodal/registry.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
|
||||
class MultimodalRegistry:
|
||||
"""
|
||||
A registry for multimodal models
|
||||
"""
|
||||
|
||||
mm_models: set[str] = {
|
||||
"Ernie4_5_VLMoeForConditionalGeneration",
|
||||
"Ernie5MoeForCausalLM",
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"Ernie5ForCausalLM",
|
||||
"Ernie4_5_VLMoeForProcessRewardModel",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def contains_model(cls, name: str) -> bool:
|
||||
"""
|
||||
Check if the given name exists in registry.
|
||||
"""
|
||||
return name in cls.mm_models
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
|
||||
@@ -24,13 +24,12 @@ from typing import Dict, List, Optional, Tuple
|
||||
import crcmod
|
||||
from redis import ConnectionPool
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
from fastdeploy.scheduler import utils
|
||||
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
|
||||
from fastdeploy.scheduler.storage import AdaptedRedis
|
||||
from fastdeploy.scheduler.workers import Task, Workers
|
||||
from fastdeploy.utils import scheduler_logger
|
||||
from fastdeploy.utils import envs, scheduler_logger
|
||||
|
||||
|
||||
class GlobalScheduler:
|
||||
@@ -534,32 +533,33 @@ class GlobalScheduler:
|
||||
continue
|
||||
|
||||
request: ScheduledRequest = ScheduledRequest.unserialize(serialized_request)
|
||||
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
|
||||
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
|
||||
|
||||
current_prefill_tokens += request.prompt_tokens_ids_len
|
||||
required_total_blocks += required_input_blocks + reserved_output_blocks
|
||||
current_prefill_tokens += request.prompt_tokens_ids_len
|
||||
required_total_blocks += required_input_blocks + reserved_output_blocks
|
||||
|
||||
if required_total_blocks > available_blocks:
|
||||
remaining_request.append((request_queue_name, serialized_request))
|
||||
continue
|
||||
if required_total_blocks > available_blocks:
|
||||
remaining_request.append((request_queue_name, serialized_request))
|
||||
continue
|
||||
|
||||
if not envs.FD_ENABLE_MAX_PREFILL:
|
||||
if self.enable_chunked_prefill:
|
||||
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
|
||||
long_partial_requests += 1
|
||||
if long_partial_requests > self.max_long_partial_prefills:
|
||||
if not envs.FD_ENABLE_MAX_PREFILL:
|
||||
if self.enable_chunked_prefill:
|
||||
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
|
||||
long_partial_requests += 1
|
||||
if long_partial_requests > self.max_long_partial_prefills:
|
||||
remaining_request.append((request_queue_name, serialized_request))
|
||||
continue
|
||||
else:
|
||||
short_partial_requests += 1
|
||||
|
||||
if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
|
||||
remaining_request.append((request_queue_name, serialized_request))
|
||||
continue
|
||||
else:
|
||||
short_partial_requests += 1
|
||||
|
||||
if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
|
||||
remaining_request.append((request_queue_name, serialized_request))
|
||||
continue
|
||||
else:
|
||||
if current_prefill_tokens > max_num_batched_tokens:
|
||||
remaining_request.append((request_queue_name, serialized_request))
|
||||
continue
|
||||
if current_prefill_tokens > max_num_batched_tokens:
|
||||
remaining_request.append((request_queue_name, serialized_request))
|
||||
continue
|
||||
|
||||
scheduled_requests.append(request)
|
||||
|
||||
|
||||
@@ -18,10 +18,9 @@ import threading
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
|
||||
from fastdeploy.utils import scheduler_logger
|
||||
from fastdeploy.utils import envs, scheduler_logger
|
||||
|
||||
|
||||
class LocalScheduler:
|
||||
@@ -247,35 +246,40 @@ class LocalScheduler:
|
||||
self.wait_request_timeout,
|
||||
)
|
||||
|
||||
required_total_blocks = 0
|
||||
current_prefill_tokens = 0
|
||||
requests: List[Request] = []
|
||||
long_partial_requests, short_partial_requests = 0, 0
|
||||
for request_id in batch_ids:
|
||||
request = self.requests[request_id]
|
||||
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
|
||||
current_prefill_tokens += request.prompt_tokens_ids_len
|
||||
required_total_blocks += required_input_blocks + reserved_output_blocks
|
||||
if required_total_blocks > available_blocks:
|
||||
break
|
||||
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
required_total_blocks = 0
|
||||
current_prefill_tokens = 0
|
||||
long_partial_requests, short_partial_requests = 0, 0
|
||||
for request_id in batch_ids:
|
||||
request = self.requests[request_id]
|
||||
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
|
||||
current_prefill_tokens += request.prompt_tokens_ids_len
|
||||
required_total_blocks += required_input_blocks + reserved_output_blocks
|
||||
if required_total_blocks > available_blocks:
|
||||
break
|
||||
|
||||
if not envs.FD_ENABLE_MAX_PREFILL:
|
||||
if self.enable_chunked_prefill:
|
||||
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
|
||||
# 长请求
|
||||
long_partial_requests += 1
|
||||
if long_partial_requests > self.max_long_partial_prefills:
|
||||
if not envs.FD_ENABLE_MAX_PREFILL:
|
||||
if self.enable_chunked_prefill:
|
||||
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
|
||||
# 长请求
|
||||
long_partial_requests += 1
|
||||
if long_partial_requests > self.max_long_partial_prefills:
|
||||
break
|
||||
else:
|
||||
short_partial_requests += 1
|
||||
|
||||
if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
|
||||
break
|
||||
else:
|
||||
short_partial_requests += 1
|
||||
if current_prefill_tokens > max_num_batched_tokens:
|
||||
break
|
||||
requests.append(request.raw)
|
||||
else:
|
||||
for request_id in batch_ids:
|
||||
request = self.requests[request_id]
|
||||
requests.append(request.raw)
|
||||
|
||||
if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
|
||||
break
|
||||
else:
|
||||
if current_prefill_tokens > max_num_batched_tokens:
|
||||
break
|
||||
|
||||
requests.append(request.raw)
|
||||
self.ids_read_cursor += len(requests)
|
||||
|
||||
if len(batch_ids) > 0 and len(requests) == 0:
|
||||
|
||||
@@ -137,6 +137,11 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
"fused_gemm_epilogue",
|
||||
]
|
||||
|
||||
if self.cache_config.max_encoder_cache > 0:
|
||||
self.encoder_cache: dict[str, paddle.Tensor] = {}
|
||||
else:
|
||||
self.encoder_cache = None
|
||||
|
||||
# Sampler
|
||||
if not self.speculative_decoding:
|
||||
self.sampler = Sampler(fd_config)
|
||||
@@ -297,6 +302,185 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
schemata_key,
|
||||
)
|
||||
|
||||
def get_chunked_inputs(self, req: Request):
|
||||
"""
|
||||
Get inputs in current chunk
|
||||
"""
|
||||
prefill_start_index = req.prefill_start_index
|
||||
prefill_end_index = req.prefill_end_index
|
||||
inputs = req.multimodal_inputs
|
||||
input_ids = inputs["input_ids"][prefill_start_index:prefill_end_index]
|
||||
token_type_ids = inputs["token_type_ids"][prefill_start_index:prefill_end_index]
|
||||
image_type_ids = inputs["image_type_ids"][req.image_type_ids_start : req.image_type_ids_end]
|
||||
images = inputs["images"][req.image_start : req.image_end]
|
||||
grid_thw = inputs["grid_thw"][req.num_image_start : req.num_image_end]
|
||||
mm_hashes = inputs["mm_hashes"][req.num_image_start : req.num_image_end]
|
||||
|
||||
return (
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
image_type_ids,
|
||||
images,
|
||||
grid_thw,
|
||||
mm_hashes,
|
||||
)
|
||||
|
||||
def batch_uncached_inputs(self, req: Request):
|
||||
"""
|
||||
Batch uncached multimodal inputs
|
||||
"""
|
||||
(input_ids, token_type_ids, image_type_ids, images, grid_thw, mm_hashes) = self.get_chunked_inputs(req)
|
||||
|
||||
image_type_ids_size = grid_thw[:, 0]
|
||||
image_type_ids_split = np.cumsum(image_type_ids_size)[:-1]
|
||||
image_type_ids_lst = np.array_split(image_type_ids, image_type_ids_split, axis=0)
|
||||
|
||||
images_size = np.prod(grid_thw, axis=1)
|
||||
images_split = np.cumsum(images_size)[:-1]
|
||||
images_lst = np.array_split(images, images_split, axis=0)
|
||||
|
||||
assert len(image_type_ids_lst) == len(
|
||||
mm_hashes
|
||||
), f"image_type_ids_lst length {len(image_type_ids_lst)} != mm_hashes length {len(mm_hashes)}"
|
||||
assert len(images_lst) == len(
|
||||
mm_hashes
|
||||
), f"images_lst length {len(images_lst)} != mm_hashes length {len(mm_hashes)}"
|
||||
|
||||
uncached_image_type_ids = []
|
||||
uncached_images = []
|
||||
uncached_grid_thw = []
|
||||
uncached_mm_hashes = []
|
||||
for i, mm_hash in enumerate(mm_hashes):
|
||||
if mm_hash in self.encoder_cache:
|
||||
continue
|
||||
uncached_image_type_ids.append(image_type_ids_lst[i])
|
||||
uncached_images.append(images_lst[i])
|
||||
uncached_grid_thw.append(grid_thw[i])
|
||||
uncached_mm_hashes.append(mm_hash)
|
||||
|
||||
uncached_input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64)
|
||||
uncached_token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
|
||||
if len(uncached_mm_hashes) > 0:
|
||||
uncached_image_type_ids = paddle.to_tensor(np.hstack(uncached_image_type_ids), dtype=paddle.int64)
|
||||
uncached_images = paddle.to_tensor(
|
||||
np.vstack(uncached_images), dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
|
||||
)
|
||||
uncached_grid_thw = paddle.to_tensor(uncached_grid_thw, dtype=paddle.int64)
|
||||
|
||||
return (
|
||||
uncached_input_ids,
|
||||
uncached_token_type_ids,
|
||||
uncached_image_type_ids,
|
||||
uncached_images,
|
||||
uncached_grid_thw,
|
||||
uncached_mm_hashes,
|
||||
)
|
||||
|
||||
def scatter_and_cache_features(self, image_features, inputs):
|
||||
"""
|
||||
Split batched image features and cache them
|
||||
"""
|
||||
merge_size = 2
|
||||
grid_thw = inputs["grid_thw"]
|
||||
mm_hashes = inputs["mm_hashes"]
|
||||
image_features_size = (paddle.prod(grid_thw[:, 1:], axis=1) // (merge_size**2)).tolist()
|
||||
image_features_lst = paddle.split(image_features, image_features_size, axis=0)
|
||||
|
||||
assert len(image_features_lst) == len(
|
||||
mm_hashes
|
||||
), f"image_features_lst length {len(image_features_lst)} != mm_hashes length {len(mm_hashes)}"
|
||||
for i, mm_hash in enumerate(mm_hashes):
|
||||
self.encoder_cache[mm_hash] = image_features_lst[i].cpu()
|
||||
|
||||
def _apply_mm_inputs(self, request: Request, multi_vision_inputs: dict, rope_3d_position_ids: dict):
|
||||
"""
|
||||
Apply multimodal inputs to share_inputs
|
||||
- add image_features, extract and cache vision features from model
|
||||
- add rope_emb, rotate position embeddings
|
||||
"""
|
||||
if self.encoder_cache:
|
||||
evict_mm_hashes = request.get("evict_mm_hashes", None)
|
||||
if evict_mm_hashes:
|
||||
for mm_hash in evict_mm_hashes:
|
||||
self.encoder_cache.pop(mm_hash, None)
|
||||
|
||||
inputs = request.multimodal_inputs
|
||||
if request.with_image:
|
||||
if envs.FD_ENABLE_MAX_PREFILL:
|
||||
multi_vision_inputs["images_lst"].append(
|
||||
inputs["images"][request.image_start : request.image_end].cuda()
|
||||
)
|
||||
multi_vision_inputs["grid_thw_lst"].extend(
|
||||
inputs["grid_thw"][request.num_image_start : request.num_image_end]
|
||||
)
|
||||
multi_vision_inputs["cu_seqlens"].extend(
|
||||
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
|
||||
)
|
||||
multi_vision_inputs["vit_position_ids_lst"].extend(
|
||||
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
|
||||
)
|
||||
else:
|
||||
vision_inputs = inputs
|
||||
if self.encoder_cache:
|
||||
(
|
||||
vision_inputs["input_ids"],
|
||||
vision_inputs["token_type_ids"],
|
||||
vision_inputs["image_type_ids"],
|
||||
vision_inputs["images"],
|
||||
vision_inputs["grid_thw"],
|
||||
vision_inputs["mm_hashes"],
|
||||
) = self.batch_uncached_inputs(request)
|
||||
if len(vision_inputs["mm_hashes"]) > 0:
|
||||
# uncached multimodal inputs exist
|
||||
image_features = self.extract_vision_features(vision_inputs)
|
||||
self.scatter_and_cache_features(image_features, vision_inputs)
|
||||
|
||||
full_image_features_lst = []
|
||||
for mm_hash in inputs["mm_hashes"][request.num_image_start : request.num_image_end]:
|
||||
feature = self.encoder_cache[mm_hash].cuda()
|
||||
full_image_features_lst.append(feature)
|
||||
image_features = paddle.concat(full_image_features_lst, axis=0)
|
||||
else:
|
||||
(
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
image_type_ids,
|
||||
images,
|
||||
grid_thw,
|
||||
mm_hashes,
|
||||
) = self.get_chunked_inputs(request)
|
||||
vision_inputs["input_ids"] = paddle.to_tensor(input_ids, dtype=paddle.int64)
|
||||
vision_inputs["token_type_ids"] = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
|
||||
vision_inputs["image_type_ids"] = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
|
||||
vision_inputs["images"] = paddle.to_tensor(
|
||||
images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
|
||||
)
|
||||
vision_inputs["grid_thw"] = paddle.to_tensor(grid_thw, dtype=paddle.int64)
|
||||
vision_inputs["mm_hashes"] = mm_hashes
|
||||
|
||||
image_features = self.extract_vision_features(vision_inputs)
|
||||
|
||||
# part of the first image may be already cached
|
||||
if "ernie" in self.model_config.model_type:
|
||||
actual_image_token_num = paddle.sum(vision_inputs["input_ids"] == self.model_config.im_patch_id)
|
||||
elif "qwen" in self.model_config.model_type:
|
||||
actual_image_token_num = paddle.sum(
|
||||
vision_inputs["input_ids"] == vision_inputs["image_patch_id"]
|
||||
) + paddle.sum(vision_inputs["input_ids"] == vision_inputs["video_patch_id"])
|
||||
else:
|
||||
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
|
||||
self.share_inputs["image_features"] = image_features[-actual_image_token_num:]
|
||||
else:
|
||||
self.share_inputs["image_features"] = None
|
||||
|
||||
position_ids = request.multimodal_inputs["position_ids"]
|
||||
rope_3d_position_ids["position_ids_idx"].append(request.idx)
|
||||
rope_3d_position_ids["position_ids_lst"].append(position_ids)
|
||||
rope_3d_position_ids["position_ids_offset"].append(
|
||||
position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
|
||||
)
|
||||
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
|
||||
|
||||
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None):
|
||||
"""
|
||||
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
@@ -326,51 +510,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
prefill_end_index = request.prefill_end_index
|
||||
length = prefill_end_index - prefill_start_index
|
||||
if self.enable_mm:
|
||||
inputs = request.multimodal_inputs
|
||||
if request.with_image:
|
||||
if envs.FD_ENABLE_MAX_PREFILL:
|
||||
multi_vision_inputs["images_lst"].append(
|
||||
inputs["images"][request.image_start : request.image_end].cuda()
|
||||
)
|
||||
multi_vision_inputs["grid_thw_lst"].extend(
|
||||
inputs["grid_thw"][request.num_image_start : request.num_image_end]
|
||||
)
|
||||
multi_vision_inputs["cu_seqlens"].extend(
|
||||
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
|
||||
)
|
||||
multi_vision_inputs["vit_position_ids_lst"].extend(
|
||||
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
|
||||
)
|
||||
else:
|
||||
vision_inputs = {}
|
||||
vision_inputs["input_ids"] = paddle.to_tensor(
|
||||
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
|
||||
)
|
||||
vision_inputs["token_type_ids"] = paddle.to_tensor(
|
||||
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
|
||||
)
|
||||
vision_inputs["image_type_ids"] = paddle.to_tensor(
|
||||
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
|
||||
dtype=paddle.int64,
|
||||
)
|
||||
vision_inputs["images"] = paddle.to_tensor(
|
||||
inputs["images"][request.image_start : request.image_end],
|
||||
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
|
||||
)
|
||||
vision_inputs["grid_thw"] = paddle.to_tensor(
|
||||
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
|
||||
)
|
||||
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
|
||||
else:
|
||||
self.share_inputs["image_features"] = None
|
||||
|
||||
position_ids = request.multimodal_inputs["position_ids"]
|
||||
rope_3d_position_ids["position_ids_idx"].append(idx)
|
||||
rope_3d_position_ids["position_ids_lst"].append(position_ids)
|
||||
rope_3d_position_ids["position_ids_offset"].append(
|
||||
position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
|
||||
)
|
||||
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
|
||||
self._apply_mm_inputs(request, multi_vision_inputs, rope_3d_position_ids)
|
||||
|
||||
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
|
||||
# Enable thinking
|
||||
|
||||
@@ -653,6 +653,12 @@ def parse_args():
|
||||
help="Flag to specify dtype of lm_head as FP32",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max_encoder_cache",
|
||||
type=int,
|
||||
help="Maximum encoder cache tokens(use 0 to disable).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cache-transfer-protocol",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user