[Feature] mm support prefix cache (#4134)

* support mm prefix caching

* update code

* fix mm_hashes

* support encoder cache

* add encoder cache

* update code

* update encoder cache

* fix features bug

* fix worker bug

* support processor cache (still needs optimization)

* refactor multimodal data cache

* update code

* update code

* update v1 scheduler

* update code

* update code

* update codestyle

* support turning off the processor cache and encoder cache

* update pre-commit

* fix code

* solve review

* update code

* update code

* update test case

* set processor cache in GiB

* update test case

* support mm prefix caching for qwen model

* fix code style check

* update pre-commit

* fix unit test

* fix unit test

* add ci test case

* fix rescheduled bug

* change text_after_process to prompt_tokens

* fix unit test

* fix chat template

* change model path

* [EP] fix adapter bugs (#4572)

* Update expert_service.py

* Update common_engine.py

* Update expert_service.py

* fix v1 hang bug (#4573)

* fix import image_ops error on some platforms (#4559)

* [CLI]Update parameters in bench latency cli tool and fix collect-env cli tool (#4558)

* add collect-env

* del files

* [Graph Optimization] Add dy_runnable and introduce cudagraph_switch_threshold for cudagraph mode switching (#4578)

* add new branch for sot

* reorder

* fix batch bug

* [XPU]Moe uses a new operator (#4585)

* [XPU]Moe uses a new operator

* [XPU]Moe uses a new operator

* update response

* [Feature] Support Paddle-OCR (#4396)

* init

* update code

* fix code style & disable thinking

* adapt for common_engine.update_mm_requests_chunk_size

* use 3d rope

* use flash_attn_unpadded

* opt siglip

* update to be compatible with the latest codebase

* fix typo

* optim OCR performance

* fix bug

* fix bug

* fix bug

* fix bug

* normalize name

* modify xpu rope

* revert logger

* fix bug

* fix bug

* fix bug

* support default_v1

* optim performance

* fix bug

---------

Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com>
Co-authored-by: zhangyue66 <zhangyue66@baidu.com>

* [DataProcessor] add reasoning_tokens into usage info (#4520)

* add reasoning_tokens into usage info initial commit

* add unit tests

* modify unit test

* modify and add unit tests

* fix unit test

* move stream usage to processor

* modify processor

* modify test_logprobs

* modify test_logprobs.py

* modify stream reasoning tokens accumulation

* fix unit test

* perf: Optimize task queue communication from engine to worker (#4531)

* perf: Optimize task queue communication from engine to worker

* perf: get_tasks to numpy

* perf: get_tasks remove to_numpy

* fix: request & replace ENV

* remove test_e2w_perf.py

* fix code style

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>

* Clean up ports after processing results (#4587)

* [CI] Add /re-run command in PR comments to restart failed CI workflows (#4593)

* [Others] api server exits when worker process is dead (#3271)

* [fix] fix terminal hangs when worker process is dead

* [chore] change sleep time of monitor

* [chore] remove redundant comments

* update docs

---------

Co-authored-by: ApplEOFDiscord <wwy640130@163.com>
Co-authored-by: ApplEOFDiscord <31272106+ApplEOFDiscord@users.noreply.github.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Co-authored-by: yinwei <yinwei_hust@163.com>
Co-authored-by: JYChen <zoooo0820@qq.com>
Co-authored-by: qwes5s5 <45442318+qwes5s5@users.noreply.github.com>
Co-authored-by: Ryan <zihaohuang@aliyun.com>
Co-authored-by: yyssys <atyangshuang@foxmail.com>
Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com>
Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com>
Co-authored-by: zhangyue66 <zhangyue66@baidu.com>
Co-authored-by: kxz2002 <115912648+kxz2002@users.noreply.github.com>
Co-authored-by: SunLei <sunlei5788@gmail.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Zhang Yulong <35552275+ZhangYulongg@users.noreply.github.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: 李泳桦 <39643373+liyonghua0910@users.noreply.github.com>
Authored by kevin on 2025-10-27 17:39:51 +08:00, committed by GitHub
parent a4fb3d4ff0 · commit 8aab4e367f
40 changed files with 1741 additions and 545 deletions

View File

@@ -55,6 +55,8 @@ When using FastDeploy to deploy models (including offline inference and service
| ```tool_call_parser``` | `str` | Specify the function call parser to be used for extracting function call content from the model's output. |
| ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to be registered, so as to register parsers that are not in the code repository. The code format within these parsers must adhere to the format used in the code repository. |
| ```load_choices``` | `str` | By default, the "default" loader is used for weight loading. To load Torch weights or enable weight acceleration, "default_v1" must be used.|
| ```max_encoder_cache``` | `int` | Maximum number of tokens in the encoder cache (use 0 to disable). |
| ```max_processor_cache``` | `float` | Maximum size of the processor cache in GiB (use 0 to disable). |
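As a hedged usage sketch of the two new options in offline inference (the model id is a placeholder, and it is assumed that the `LLM` entry point forwards these keyword arguments to the engine configuration):

```python
# A minimal sketch, assuming `from fastdeploy import LLM` accepts engine
# arguments as keyword arguments; the model id below is a placeholder.
from fastdeploy import LLM

llm = LLM(
    model="baidu/ERNIE-4.5-VL-28B-A3B-Paddle",  # placeholder multimodal model
    max_encoder_cache=16384,  # cache up to 16384 encoder tokens; 0 disables the encoder cache
    max_processor_cache=4,    # cache up to 4 GiB of preprocessed inputs; 0 disables the processor cache
)
```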
## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?

View File

@@ -13,7 +13,7 @@ By default, logs are stored in the `log` directory under the execution path. To
* `fastdeploy.log` : Records configuration information during instance startup, as well as request and response details during runtime.
* `workerlog.*` : Tracks model loading progress and inference operator errors. Each GPU card has a corresponding file.
* `worker_process.log` : Logs engine inference data for each iteration.
* `prefix_cache_manager.log` : Records KV Cache logical index allocation for each request and cache hit status.
* `cache_manager.log` : Records KV Cache logical index allocation for each request and cache hit status.
* `launch_worker.log` : Logs model startup information and error messages.
* `gpu_worker.log` : Records KV Cache block count information during profiling.
* `gpu_model_runner.log` : Contains model details and loading time.

View File

@@ -53,6 +53,8 @@
| ```tool_call_parser``` | `str` | 指定要使用的 function call 解析器,以便从模型输出中抽取 function call 内容|
| ```tool_parser_plugin``` | `str` | 指定要注册的 tool parser 文件路径,以便注册不在代码库中的 parser;parser 中代码格式需遵循代码库中格式|
| ```load_choices``` | `str` | 默认使用 "default" loader 进行权重加载;加载 torch 权重/权重加速需开启 "default_v1"|
| ```max_encoder_cache``` | `int` | 编码器缓存的最大 token 数,使用 0 表示禁用。 |
| ```max_processor_cache``` | `float` | 处理器缓存的最大容量(以 GiB 为单位),使用 0 表示禁用。 |
## 1. KVCache分配与```num_gpu_blocks_override```、```block_size```的关系?

View File

@@ -13,7 +13,7 @@ FastDeploy 在部署过程中,会产生如下日志文件,各日志含义说
* `fastdeploy.log` : 记录当前实例启动的各个 config 的信息,运行中记录用户请求的 request 及 response 信息
* `workerlog.*` : 记录模型启动加载进度及推理算子报错信息,每个卡对应一个文件
* `worker_process.log` : 记录引擎每一轮推理的数据
* `prefix_cache_manager.log` : 记录每一个请求分配 KV Cache 的逻辑索引,以及当前请求的命中情况
* `cache_manager.log` : 记录每一个请求分配 KV Cache 的逻辑索引,以及当前请求的命中情况
* `launch_worker.log` : 记录模型启动信息及报错信息
* `gpu_worker.log` : 记录 profile 时计算 KV Cache block 数目的信息
* `gpu_model_runner.log` : 当前的模型信息及加载时间

View File

@@ -18,7 +18,7 @@ from enum import Enum
from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
logger = get_logger("prefix_cache_manager", "cache_manager.log")
DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = {

View File

@@ -17,7 +17,7 @@
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
logger = get_logger("prefix_cache_manager", "cache_manager.log")
class CacheMetrics:

View File

@@ -0,0 +1,163 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import pickle
import threading
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Tuple
import numpy as np
import zmq
from fastdeploy import envs
from fastdeploy.engine.request import ImagePosition
from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "cache_manager.log")
class MultimodalLRUCache(ABC):
"""
General LRU cache for multimodal data
"""
def __init__(self, max_cache_size):
self.cache = OrderedDict()
self.current_cache_size = 0
self.max_cache_size = max_cache_size
def apply_cache(self, mm_hashes: list[str], mm_items: list[Any]) -> list[str]:
"""
insert items into the data cache and return the hashes of evicted entries
"""
assert len(mm_hashes) == len(mm_items), "mm_hashes and mm_items should have same length"
evicted_hashes = []
for idx in range(len(mm_hashes)):
if mm_hashes[idx] in self.cache:
self.cache.move_to_end(mm_hashes[idx])
else:
item_size = self.get_item_size(mm_items[idx])
if self.current_cache_size + item_size >= self.max_cache_size:
if item_size > self.max_cache_size:
# cannot be inserted even if we clear all cached data, skip it directly
continue
needed = item_size - (self.max_cache_size - self.current_cache_size)
evicted_hashes.extend(self.evict_cache(needed))
self.cache[mm_hashes[idx]] = mm_items[idx]
self.current_cache_size += item_size
return evicted_hashes
def evict_cache(self, needed: int) -> list[str]:
"""
evict cached items until at least `needed` size has been freed
"""
reduced_size, evicted_hashes = 0, []
while reduced_size < needed and len(self.cache):
mm_hash, mm_item = self.cache.popitem(last=False)
evicted_hashes.append(mm_hash)
reduced_size += self.get_item_size(mm_item)
self.current_cache_size -= self.get_item_size(mm_item)
return evicted_hashes
def get_cache(self, mm_hashes: list[str]) -> list[Any]:
"""
get cached data corresponding to the given hash values
"""
mm_items = []
for mm_hash in mm_hashes:
if mm_hash not in self.cache:
mm_items.append(None)
continue
mm_items.append(self.cache[mm_hash])
return mm_items
def clear_cache(self):
"""
clear all cached data
"""
evicted_hashes = list(self.cache.keys())
self.cache.clear()
self.current_cache_size = 0
return evicted_hashes
@abstractmethod
def get_item_size(self, item: Any) -> int:
raise NotImplementedError("Subclasses must define how to get size of an item")
class EncoderCacheManager(MultimodalLRUCache):
"""
EncoderCacheManager is used to cache image features
"""
def __init__(self, max_encoder_cache):
super().__init__(max_encoder_cache)
def get_item_size(self, item: ImagePosition) -> int:
return item.length
class ProcessorCacheManager(MultimodalLRUCache):
"""
ProcessorCacheManager is used to cache processed data
"""
def __init__(self, max_processor_cache):
super().__init__(max_processor_cache)
self.context = zmq.Context()
self.router = self.context.socket(zmq.ROUTER)
self.router.setsockopt(zmq.SNDHWM, int(envs.FD_ZMQ_SNDHWM))
self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router.setsockopt(zmq.SNDTIMEO, -1)
self.router.bind("ipc:///dev/shm/processor_cache.ipc")
self.poller = zmq.Poller()
self.poller.register(self.router, zmq.POLLIN)
self.handler_thread = threading.Thread(target=self.cache_request_handler, daemon=True)
self.handler_thread.start()
def get_item_size(self, item: Tuple[np.ndarray, dict]) -> int:
return item[0].nbytes
def cache_request_handler(self):
try:
while True:
events = dict(self.poller.poll())
if self.router in events:
client, _, content = self.router.recv_multipart()
req = pickle.loads(content)
if isinstance(req, tuple):
# apply-cache request, in the format (mm_hashes, mm_items)
self.apply_cache(req[0], req[1])
logger.info(f"Apply processor cache of mm_hashes: {req[0]}")
else:
# get cache request
resp = self.get_cache(req)
logger.info(f"Get processor cache of mm_hashes: {req}")
self.router.send_multipart([client, b"", pickle.dumps(resp)])
except Exception as e:
logger.error(f"Error happened while handling processor cache request: {e}")

View File

@@ -14,8 +14,10 @@
# limitations under the License.
"""
import hashlib
import heapq
import os
import pickle
import subprocess
import sys
import threading
@@ -35,7 +37,7 @@ from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTre
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
logger = get_logger("prefix_cache_manager", "cache_manager.log")
class PrefixCacheManager:
@@ -575,31 +577,18 @@ class PrefixCacheManager:
"""
try:
req_id = task.request_id
block_tables = task.block_tables
last_node, num_cached_tokens = self.cache_info[req_id]
if isinstance(task.prompt_token_ids, np.ndarray):
prompt_token_ids = task.prompt_token_ids.tolist()
else:
prompt_token_ids = task.prompt_token_ids
input_ids = prompt_token_ids + task.output_token_ids
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
left_input_ids = input_ids[num_cached_tokens:can_cache_computed_tokens]
gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :]
if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later
self.leaf_req_map[last_node].remove(req_id)
with self.request_release_lock:
current_time = time.time()
leaf_node = self.build_path(
req_id=req_id,
current_time=current_time,
input_ids=input_ids,
left_input_ids=left_input_ids,
gpu_block_ids=gpu_extra_block_ids,
leaf_node = self.mm_build_path(
request=task,
num_computed_tokens=num_computed_tokens,
block_size=block_size,
last_node=last_node,
reverved_dec_block_num=0,
num_cached_tokens=num_cached_tokens,
)
self.req_leaf_map[req_id] = leaf_node
self.leaf_req_map[leaf_node].add(req_id)
@@ -636,10 +625,9 @@ class PrefixCacheManager:
prompt_token_ids = task.prompt_token_ids.tolist()
else:
prompt_token_ids = task.prompt_token_ids
input_ids = prompt_token_ids + task.output_token_ids
req_id = task.request_id
logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
input_token_num = len(input_ids)
input_token_num = len(prompt_token_ids + task.output_token_ids)
common_block_ids = []
# 1. match block
(
@@ -649,7 +637,7 @@ class PrefixCacheManager:
match_block_node,
gpu_match_token_num,
cpu_match_token_num,
) = self.match_block(req_id, input_ids, block_size)
) = self.mm_match_block(task, block_size)
# update matched node info
self._update_matched_node_info(req_id, match_block_node, current_time=time.time())
@@ -1145,6 +1133,173 @@ class PrefixCacheManager:
"""
return hash(tuple(block))
def get_block_hash_extra_keys(self, request, start_idx, end_idx, mm_idx):
"""
Retrieves additional hash keys for block identification.
Args:
request: The input request object containing the data to be processed.
start_idx (int): The starting index of the block segment to hash.
end_idx (int): The ending index of the block segment to hash.
mm_idx: The multimodal index identifier for specialized content handling.
Returns:
mm_idx: next multimodal index
hash_keys: A list of additional hash keys
"""
hash_keys = []
mm_inputs = request.multimodal_inputs
if (
mm_inputs is None
or "mm_positions" not in mm_inputs
or "mm_hashes" not in mm_inputs
or len(mm_inputs["mm_positions"]) == 0
):
return mm_idx, hash_keys
assert start_idx < end_idx, f"start_idx {start_idx} >= end_idx {end_idx}"
assert (
start_idx >= 0 and start_idx < request.num_total_tokens
), f"start_idx {start_idx} out of range {request.num_total_tokens}"
assert (
end_idx >= 0 and end_idx <= request.num_total_tokens
), f"end_idx {end_idx} out of range {request.num_total_tokens}"
assert len(mm_inputs["mm_positions"]) == len(
mm_inputs["mm_hashes"]
), f"mm_positions {len(mm_inputs['mm_positions'])} != mm_hashes {len(mm_inputs['mm_hashes'])}"
assert mm_idx >= 0 and mm_idx < len(
mm_inputs["mm_hashes"]
), f"mm_idx {mm_idx} out of range {len(mm_inputs['mm_hashes'])}"
if mm_inputs["mm_positions"][-1].offset + mm_inputs["mm_positions"][-1].length < start_idx:
# no images in the current block
return mm_idx, hash_keys
for img_idx in range(mm_idx, len(mm_inputs["mm_positions"])):
image_offset = mm_inputs["mm_positions"][img_idx].offset
image_length = mm_inputs["mm_positions"][img_idx].length
if image_offset + image_length < start_idx:
# image before block
continue
elif image_offset >= end_idx:
# image after block
return img_idx, hash_keys
elif image_offset + image_length > end_idx:
hash_keys.append(mm_inputs["mm_hashes"][img_idx])
return img_idx, hash_keys
else:
hash_keys.append(mm_inputs["mm_hashes"][img_idx])
return len(mm_inputs["mm_positions"]) - 1, hash_keys
def hash_block_features(self, input_ids, extra_keys: list = []):
"""
calculate hash value of a block with additional keys
Args:
input_ids: Input token IDs
extra_keys: Additional keys for block identification
"""
return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()
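As a standalone illustration of this hashing scheme, the same token block maps to different prefix-cache keys once multimodal extra keys are attached (the image hash below is made up):

```python
import hashlib
import pickle

def hash_block_features(input_ids, extra_keys=[]):
    # same scheme as above: hash of (token ids, extra multimodal keys)
    return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()

block = list(range(64))                                    # one 64-token block
text_only_key = hash_block_features(block)
with_image_key = hash_block_features(block, ["mm-hash-of-image-0"])
assert text_only_key != with_image_key                     # distinct cache entries
```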
def mm_match_block(self, request, block_size):
"""
Match and retrieve cached blocks for multimodal requests using a radix tree structure.
Args:
request: The multimodal request object containing prompt and output token IDs.
block_size (int): The size of each token block for matching and processing.
Returns:
tuple: A tuple containing:
- match_gpu_block_ids (list): List of block IDs matched in GPU cache
- match_cpu_block_ids (list): List of block IDs matched in CPU cache
- swap_node_ids (list): List of node IDs scheduled for GPU-CPU swapping
- current_match_node: The last matched node in the radix tree traversal
- gpu_match_token_num (int): Total number of tokens matched in GPU cache
- cpu_match_token_num (int): Total number of tokens matched in CPU cache
"""
if isinstance(request.prompt_token_ids, np.ndarray):
prompt_token_ids = request.prompt_token_ids.tolist()
else:
prompt_token_ids = request.prompt_token_ids
input_ids = prompt_token_ids + request.output_token_ids
total_token_num = len(input_ids)
current_match_node = self.radix_tree_root  # start matching from the radix tree root
match_gpu_block_ids = []
match_cpu_block_ids = []
match_node_ids = []
mm_idx = 0
match_token_num = 0
cpu_match_token_num = 0
gpu_match_token_num = 0
swap_node_ids = []
matche_nodes = []
has_modified_gpu_lru_leaf_heap = False
has_modified_cpu_lru_leaf_heap = False
with self.cache_status_lock:
while match_token_num < total_token_num:
token_block = input_ids[match_token_num : match_token_num + block_size]
token_num = len(token_block)
if token_num != block_size:
break
mm_idx, extra_keys = self.get_block_hash_extra_keys(
request=request,
start_idx=match_token_num,
end_idx=match_token_num + block_size,
mm_idx=mm_idx,
)
hash_value = self.hash_block_features(token_block, extra_keys)
if hash_value in current_match_node.children:
child = current_match_node.children[hash_value]
matche_nodes.append(child)
match_node_ids.append(child.node_id)
if child in self.gpu_lru_leaf_set:
self.gpu_lru_leaf_set.remove(child)
self.gpu_lru_leaf_heap.remove(child)
has_modified_gpu_lru_leaf_heap = True
elif child in self.cpu_lru_leaf_set:
self.cpu_lru_leaf_set.remove(child)
self.cpu_lru_leaf_heap.remove(child)
has_modified_cpu_lru_leaf_heap = True
if child.has_in_gpu:
match_gpu_block_ids.append(child.block_id)
gpu_match_token_num += block_size
else:
if child.cache_status == CacheStatus.SWAP2CPU:
logger.info(
f"match_block: req_id {request.request_id} matched node"
+ f" {child.node_id} which is being SWAP2CPU"
)
child.cache_status = CacheStatus.GPU
match_gpu_block_ids.append(child.block_id)
gpu_match_token_num += block_size
elif child.cache_status == CacheStatus.CPU:
child.cache_status = CacheStatus.SWAP2GPU
match_cpu_block_ids.append(child.block_id)
cpu_match_token_num += block_size
swap_node_ids.append(child.node_id)
match_token_num = match_token_num + block_size
current_match_node = child
else:
break
if has_modified_gpu_lru_leaf_heap:
heapq.heapify(self.gpu_lru_leaf_heap)
if has_modified_cpu_lru_leaf_heap:
heapq.heapify(self.cpu_lru_leaf_heap)
logger.info(f"match_block: req_id {request.request_id} matched nodes: {match_node_ids}")
return (
match_gpu_block_ids,
match_cpu_block_ids,
swap_node_ids,
current_match_node,
gpu_match_token_num,
cpu_match_token_num,
)
def match_block(self, req_id, input_ids, block_size):
"""
Args:
@@ -1241,6 +1396,86 @@ class PrefixCacheManager:
node.req_id_set.add(req_id)
node = node.parent
def mm_build_path(self, request, num_computed_tokens, block_size, last_node, num_cached_tokens):
"""
Constructs a caching path in radix tree for multimodal requests by processing computed tokens.
Args:
request: The inference request object containing:
- prompt_token_ids: Original input tokens (List[int] or np.ndarray)
- output_token_ids: Generated tokens (List[int])
- mm_positions: Optional image positions for multimodal content
num_computed_tokens: Total tokens processed so far (cached + newly computed)
block_size: Fixed size of token blocks (must match cache configuration)
last_node: The deepest existing BlockNode in the radix tree for this request
num_cached_tokens: Number of tokens already cached
Returns:
BlockNode: The new deepest node in the constructed path
"""
if isinstance(request.prompt_token_ids, np.ndarray):
prompt_token_ids = request.prompt_token_ids.tolist()
else:
prompt_token_ids = request.prompt_token_ids
input_ids = prompt_token_ids + request.output_token_ids
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
if num_cached_tokens == can_cache_computed_tokens:
return last_node
mm_idx = 0
node = last_node
unique_node_ids = []
new_last_node = last_node
has_unfilled_block = False
current_time = time.time()
input_hash_value = self.hash_block_features(input_ids)
gpu_block_ids = request.block_tables[num_cached_tokens // block_size :].copy()
for i in range(num_cached_tokens, can_cache_computed_tokens, block_size):
current_block = input_ids[i : i + block_size]
current_block_size = len(current_block)  # the last block may not be full
if current_block_size != block_size:
has_unfilled_block = True
else:
mm_idx, extra_keys = self.get_block_hash_extra_keys(
request=request,
start_idx=i,
end_idx=i + block_size,
mm_idx=mm_idx,
)
hash_value = self.hash_block_features(current_block, extra_keys)
allocated_block_id = gpu_block_ids.pop(0)
node_id = self.node_id_pool.pop()
unique_node_ids.append(node_id)
new_last_node = BlockNode(
node_id,
input_ids,
input_hash_value,
node.depth + 1,
allocated_block_id,
current_block_size,
hash_value,
current_time,
parent=node,
shared_count=1,
reverved_dec_block_ids=[],
)
new_last_node.req_id_set.add(request.request_id)
self.node_map[node_id] = new_last_node
node.children[hash_value] = new_last_node
node = new_last_node
reverved_dec_block_ids = []
if has_unfilled_block is True:
reverved_dec_block_ids.append(gpu_block_ids.pop(0))
if new_last_node == self.radix_tree_root:
self.unfilled_req_block_map[request.request_id] = reverved_dec_block_ids
else:
new_last_node.reverved_dec_block_ids.extend(reverved_dec_block_ids)
logger.info(f"build_path: allocate unique node ids {unique_node_ids} for req_id {request.request_id}")
return new_last_node
def build_path(
self,
req_id,

View File

@@ -1137,6 +1137,8 @@ class CacheConfig:
enc_dec_block_num (int): Number of encoder-decoder blocks.
prealloc_dec_block_slot_num_threshold (int): Token slot threshold for allocating the next blocks during decoding, used when ENABLE_V1_KVCACHE_SCHEDULER=1.
enable_prefix_caching (bool): Enable prefix caching.
max_encoder_cache (int): Maximum number of tokens in the encoder cache.
max_processor_cache (float): Maximum size of the processor cache in GiB.
"""
self.block_size = 64
self.gpu_memory_utilization = 0.9
@@ -1157,6 +1159,8 @@ class CacheConfig:
self.enable_ssd_cache = False
self.cache_queue_port = None
self.swap_space = None
self.max_encoder_cache = None
self.max_processor_cache = None
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
@@ -1440,7 +1444,7 @@ class FDConfig:
self.max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
if current_platform.is_xpu():
self.max_prefill_batch = 1
if self.model_config is not None and self.model_config.enable_mm:
if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.max_prefill_batch = 1  # TODO: multimodal prefill currently only supports a parallelism of 1; to be optimized
else:
self.max_prefill_batch = self.scheduler_config.max_num_seqs
@@ -1500,7 +1504,7 @@ class FDConfig:
self.cache_config.postprocess(self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_seqs)
self.cache_config.max_block_num_per_seq = int(self.model_config.max_model_len // self.cache_config.block_size)
if self.model_config is not None and self.model_config.enable_mm:
if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.cache_config.enable_prefix_caching = False
if (
@@ -1513,6 +1517,20 @@ class FDConfig:
else:
self.structured_outputs_config.guided_decoding_backend = "xgrammar"
if self.model_config.enable_mm:
if self.cache_config.max_encoder_cache is None or self.cache_config.max_encoder_cache < 0:
self.cache_config.max_encoder_cache = self.scheduler_config.max_num_batched_tokens
elif self.cache_config.max_encoder_cache != 0:
if self.cache_config.max_encoder_cache < self.scheduler_config.max_num_batched_tokens:
logger.warning(
f"max_encoder_cache{self.cache_config.max_encoder_cache} is less than "
f"max_num_batched_tokens{self.scheduler_config.max_num_batched_tokens}, "
f"set to max_num_batched_tokens."
)
self.cache_config.max_encoder_cache = self.scheduler_config.max_num_batched_tokens
else:
self.cache_config.max_encoder_cache = 0
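A small worked example of the sizing rules above; the numbers are illustrative and only the arithmetic mirrors the configuration logic:

```python
max_num_batched_tokens = 8192

# A negative or missing max_encoder_cache falls back to max_num_batched_tokens,
# a positive value below it is raised to that floor, and 0 keeps the cache off.
max_encoder_cache = 4096  # user-provided, too small
if max_encoder_cache != 0:
    max_encoder_cache = max(max_encoder_cache, max_num_batched_tokens)
assert max_encoder_cache == 8192

# The processor cache budget is configured in GiB and converted to bytes.
max_processor_cache_gib = 4
max_processor_cache_in_bytes = int(max_processor_cache_gib * 1024 * 1024 * 1024)
assert max_processor_cache_in_bytes == 4 * 1024**3
```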
# Adjustment GraphOptConfig
if self.load_config is not None and self.load_config.dynamic_load_weight is True:
self.graph_opt_config.graph_opt_level = 0

View File

@@ -123,6 +123,14 @@ class EngineArgs:
"""
Limitation of numbers of multi-modal data.
"""
max_encoder_cache: int = -1
"""
Maximum number of tokens in the encoder cache.
"""
max_processor_cache: float = -1
"""
Maximum size of the processor cache in GiB.
"""
reasoning_parser: str = None
"""
specifies the reasoning parser to use for extracting reasoning content from the model output
@@ -526,6 +534,18 @@ class EngineArgs:
type=json.loads,
help="Additional keyword arguments for the multi-modal processor.",
)
model_group.add_argument(
"--max-encoder-cache",
default=EngineArgs.max_encoder_cache,
type=int,
help="Maximum encoder cache tokens(use 0 to disable).",
)
model_group.add_argument(
"--max-processor-cache",
default=EngineArgs.max_processor_cache,
type=float,
help="Maximum processor cache bytes(use 0 to disable).",
)
model_group.add_argument(
"--enable-mm",
action=DeprecatedOptionWarning,

View File

@@ -634,13 +634,9 @@ class EngineService:
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch,
)
if self.cfg.model_config.enable_mm:
available_blocks = self.resource_manager.available_block_num()
else:
available_blocks = self.cfg.cache_config.max_block_num_per_seq
tasks = self.scheduler.get_requests(
available_blocks=available_blocks,
available_blocks=self.cfg.cache_config.max_block_num_per_seq,
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
max_num_batched_tokens=self.cfg.model_config.max_model_len,

View File

@@ -528,6 +528,7 @@ class LLMEngine:
f" --load_choices {self.cfg.load_config.load_choices}"
f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
f" --ips {ips}"
f" --max_encoder_cache {self.cfg.cache_config.max_encoder_cache}"
f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"
f" --runner {self.cfg.model_config.runner}"
f" --convert {self.cfg.model_config.convert}"

View File

@@ -46,6 +46,12 @@ class RequestType(Enum):
EXTEND = 3
@dataclass
class ImagePosition:
offset: int = 0
length: int = 0
@dataclass
class Request:
def __init__(

View File

@@ -28,10 +28,21 @@ import numpy as np
import paddle
from fastdeploy import envs
from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
from fastdeploy.cache_manager.multimodal_cache_manager import (
EncoderCacheManager,
ProcessorCacheManager,
)
from fastdeploy.engine.request import (
ImagePosition,
Request,
RequestOutput,
RequestStatus,
RequestType,
)
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.multimodal.hasher import MultimodalHasher
from fastdeploy.platforms import current_platform
from fastdeploy.utils import llm_logger
@@ -175,6 +186,15 @@ class ResourceManagerV1(ResourceManager):
self.need_block_num_map = dict()
self.encoder_cache = None
if config.model_config.enable_mm and config.cache_config.max_encoder_cache > 0:
self.encoder_cache = EncoderCacheManager(config.cache_config.max_encoder_cache)
self.processor_cache = None
if config.model_config.enable_mm and config.cache_config.max_processor_cache > 0:
max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024)
self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes)
def allocated_slots(self, request: Request):
return len(request.block_tables) * self.config.cache_config.block_size
@@ -273,6 +293,44 @@ class ResourceManagerV1(ResourceManager):
break
return can_schedule
def _update_mm_hashes(self, request):
if request.multimodal_inputs is None:
return
inputs = request.multimodal_inputs
if (
inputs.get("images", None) is not None
and inputs.get("image_patch_id", None) is not None
and inputs.get("grid_thw", None) is not None
and len(inputs["grid_thw"]) != 0
):
grid_thw = []
new_mm_positions, new_mm_hashes = [], []
image_st = 0
for idx, one in enumerate(inputs["grid_thw"]):
t, h, w = one[0], one[1], one[2]
if t == 1:
grid_thw.append(one)
new_mm_positions.append(inputs["mm_positions"][idx])
new_mm_hashes.append(inputs["mm_hashes"][idx])
image_st += h * w
else:
grid_thw.extend([[2, h, w]] * (t // 2))
token_st = inputs["mm_positions"][idx].offset
for _ in range(t // 2):
new_mm_positions.append(ImagePosition(token_st, h * w // 4))
# videos are split into 2-frame groups, so each group needs to be rehashed
new_mm_hashes.append(
MultimodalHasher.hash_features(inputs["images"][image_st : image_st + 2 * h * w])
)
image_st += 2 * h * w
token_st += h * w // 4
inputs["mm_positions"] = new_mm_positions
inputs["mm_hashes"] = new_mm_hashes
else:
inputs["mm_positions"] = []
inputs["mm_hashes"] = []
def _get_num_new_tokens(self, request, token_budget):
# TODO: set condition to new _get_num_new_tokens
num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
@@ -333,11 +391,12 @@ class ResourceManagerV1(ResourceManager):
if request.multimodal_img_boundaries is None:
grid_thw = []
for one in inputs["grid_thw"]:
if one[0] == 1:
for idx, one in enumerate(inputs["grid_thw"]):
t, h, w = one[0], one[1], one[2]
if t == 1:
grid_thw.append(one)
else:
grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2))
grid_thw.extend([[2, h, w]] * (t // 2))
grid_thw = paddle.to_tensor(grid_thw, dtype="int64")
if current_platform.is_xpu():
@@ -398,6 +457,11 @@ class ResourceManagerV1(ResourceManager):
request.image_start = np.sum(np.prod(grid_thw[: request.num_image_start], axis=1))
request.image_end = np.sum(np.prod(grid_thw[: request.num_image_end], axis=1))
cur_mm_hashes = inputs["mm_hashes"][request.num_image_start : request.num_image_end]
cur_mm_positions = inputs["mm_positions"][request.num_image_start : request.num_image_end]
if self.encoder_cache:
request.evict_mm_hashes = self.encoder_cache.apply_cache(cur_mm_hashes, cur_mm_positions)
# Compatible with scenarios without images and videos.
return num_new_tokens
@@ -553,6 +617,7 @@ class ResourceManagerV1(ResourceManager):
break
request = self.waiting[0]
if request.status == RequestStatus.WAITING:
self._update_mm_hashes(request)
# Enable prefix caching
if self.config.cache_config.enable_prefix_caching:
if (

View File

@@ -17,7 +17,6 @@
import os
import time
import uuid
from copy import deepcopy
from pathlib import Path
from typing import List, Literal, Optional, Union
from urllib.parse import urlparse
@@ -29,6 +28,7 @@ from openai.types.chat import (
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_image_param import ImageURL
from typing_extensions import Required, TypeAlias, TypedDict
from fastdeploy.multimodal.image import ImageMediaIO
@@ -36,6 +36,17 @@ from fastdeploy.multimodal.video import VideoMediaIO
from fastdeploy.utils import api_server_logger
class CustomChatCompletionContentPartImageParam(TypedDict, total=False):
"""Custom Image URL object"""
type: Required[Literal["image_url"]]
"""The type of the content part."""
image_url: Optional[ImageURL]
uuid: Optional[str]
class VideoURL(TypedDict, total=False):
"""Video URL object"""
@@ -46,14 +57,17 @@ class VideoURL(TypedDict, total=False):
class CustomChatCompletionContentPartVideoParam(TypedDict, total=False):
"""Custom Video URL object"""
video_url: Required[VideoURL]
type: Required[Literal["video_url"]]
"""The type of the content type."""
"""The type of the content part."""
video_url: Optional[VideoURL]
uuid: Optional[str]
CustomChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam,
CustomChatCompletionContentPartImageParam,
CustomChatCompletionContentPartVideoParam,
]
@@ -77,7 +91,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, CustomChatCompletionMessageParam]
class MultiModalPartParser:
class MultimodalPartParser:
"""Multi Modal Part parser"""
def __init__(self):
@@ -139,32 +153,46 @@ def parse_content_part(mm_parser, part):
return part
if part_type == "image_url":
content = part.get("image_url", {}).get("url", None)
image = mm_parser.parse_image(content)
parsed = deepcopy(part)
del parsed["image_url"]["url"]
parsed["image"] = image
parsed["type"] = "image"
return parsed
if not part.get("image_url", None) and not part.get("uuid", None):
raise ValueError("Both image_url and uuid are missing")
if part.get("image_url", None):
url = part["image_url"]["url"]
image = mm_parser.parse_image(url)
else:
image = None
parsed = {}
parsed["type"] = "image"
parsed["data"] = image
parsed["uuid"] = part.get("uuid", None)
return parsed
if part_type == "video_url":
content = part.get("video_url", {}).get("url", None)
video = mm_parser.parse_video(content)
parsed = deepcopy(part)
del parsed["video_url"]["url"]
parsed["video"] = video
if not part.get("video_url", None) and not part.get("uuid", None):
raise ValueError("Both video_url and uuid are missing")
if part.get("video_url", None):
url = part["video_url"]["url"]
video = mm_parser.parse_video(url)
else:
video = None
parsed = {}
parsed["type"] = "video"
parsed["data"] = video
parsed["uuid"] = part.get("uuid", None)
return parsed
raise ValueError(f"Unknown content part type: {part_type}")
# TODO async
# def parse_chat_messages(messages: List[ChatCompletionMessageParam]):
def parse_chat_messages(messages):
def parse_chat_messages(messages: List[ChatCompletionMessageParam]):
"""Parse chat messages to [dict]"""
mm_parser = MultiModalPartParser()
mm_parser = MultimodalPartParser()
conversation = []
for message in messages:
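A hedged request sketch of the new `uuid` field on multimodal content parts (URLs and uuids are placeholders): with the processor cache enabled, a later request can refer to a previously sent image by uuid alone.

```python
# First request: raw image plus a caller-chosen uuid; the processed features
# are cached under that uuid on the server side.
first_request_content = [
    {"type": "text", "text": "Describe this image."},
    {
        "type": "image_url",
        "image_url": {"url": "https://example.com/chart.png"},
        "uuid": "chart-2025-10",
    },
]

# Later request: uuid only. parse_content_part records data=None and the
# processor retrieves the cached, already-preprocessed features by uuid.
followup_request_content = [
    {"type": "text", "text": "Now summarise it in one sentence."},
    {"type": "image_url", "uuid": "chart-2025-10"},
]
```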

View File

@@ -68,16 +68,19 @@ class EngineClient:
tool_parser=None,
enable_prefix_caching=None,
splitwise_role=None,
max_processor_cache=0,
):
model_config = ModelConfig({"model": model_name_or_path})
self.enable_mm = model_config.enable_mm
enable_processor_cache = self.enable_mm and max_processor_cache > 0
input_processor = InputPreprocessor(
model_config,
reasoning_parser,
limit_mm_per_prompt,
mm_processor_kwargs,
tool_parser,
enable_processor_cache,
)
self.enable_mm = model_config.enable_mm
self.enable_logprob = enable_logprob
self.reasoning_parser = reasoning_parser
self.data_processor = input_processor.create_processor()

View File

@@ -193,6 +193,7 @@ async def lifespan(app: FastAPI):
tool_parser=args.tool_call_parser,
enable_prefix_caching=args.enable_prefix_caching,
splitwise_role=args.splitwise_role,
max_processor_cache=args.max_processor_cache,
)
await engine_client.connection_manager.initialize()
app.state.dynamic_load_weight = args.dynamic_load_weight

View File

@@ -470,6 +470,8 @@ class CompletionRequest(BaseModel):
max_streaming_response_tokens: Optional[int] = None
return_token_ids: Optional[bool] = None
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None
mm_hashes: Optional[list] = None
# doc: end-completion-extra-params
def to_dict_for_infer(self, request_id=None, prompt=None):
@@ -527,6 +529,9 @@ class CompletionRequest(BaseModel):
if item is not None:
req_dict[key] = item
if self.mm_hashes is not None and len(self.mm_hashes) > 0:
req_dict["mm_hashes"] = self.mm_hashes
return req_dict
@model_validator(mode="before")
@@ -553,6 +558,9 @@ class CompletionRequest(BaseModel):
"('guided_json', 'guided_regex', 'guided_choice', 'guided_grammar')."
)
if data.get("mm_hashes", None):
assert isinstance(data["mm_hashes"], list), "`mm_hashes` must be a list."
return data
@@ -618,6 +626,8 @@ class ChatCompletionRequest(BaseModel):
prompt_token_ids: Optional[List[int]] = None
max_streaming_response_tokens: Optional[int] = None
disable_chat_template: Optional[bool] = False
mm_hashes: Optional[list] = None
completion_token_ids: Optional[List[int]] = None
# doc: end-chat-completion-extra-params
@@ -694,6 +704,9 @@ class ChatCompletionRequest(BaseModel):
if item is not None:
req_dict[key] = item
if self.mm_hashes is not None and len(self.mm_hashes) > 0:
req_dict["mm_hashes"] = self.mm_hashes
return req_dict
@model_validator(mode="before")
@@ -721,6 +734,9 @@ class ChatCompletionRequest(BaseModel):
"('guided_json', 'guided_regex', 'guided_choice', 'guided_grammar', 'structural_tag')."
)
if data.get("mm_hashes", None):
assert isinstance(data["mm_hashes"], list), "`mm_hashes` must be a list."
return data
@model_validator(mode="before")

View File

@@ -37,6 +37,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
mm_processor_kwargs=None,
reasoning_parser_obj=None,
tool_parser_obj=None,
enable_processor_cache=False,
):
data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
tokenizer_path = model_name_or_path
@@ -46,6 +47,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
self.ernie4_5_processor = DataProcessor(
tokenizer_name=tokenizer_path,
image_preprocessor_name=preprocessor_path,
enable_processor_cache=enable_processor_cache,
**processor_kwargs,
)
self.ernie4_5_processor.eval()

View File

@@ -18,16 +18,20 @@
""" process.py """
import copy
import os
import pickle
from collections import defaultdict
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import zmq
from paddleformers.transformers.image_utils import ChannelDimension
from PIL import Image
from fastdeploy.engine.request import ImagePosition
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.input.utils import IDS_TYPE_FLAG
from fastdeploy.multimodal.hasher import MultimodalHasher
from fastdeploy.utils import data_processor_logger
from .image_preprocessor.image_preprocessor_adaptive import AdaptiveImageProcessor
@@ -84,6 +88,7 @@ class DataProcessor:
self,
tokenizer_name: str,
image_preprocessor_name: str,
enable_processor_cache: bool = False,
spatial_conv_size: int = 2,
temporal_conv_size: int = 2,
image_min_pixels: int = 4 * 28 * 28,
@@ -102,6 +107,7 @@ class DataProcessor:
self._load_tokenizer()
self.tokenizer.ignored_index = -100
self.image_preprocessor = AdaptiveImageProcessor.from_pretrained(image_preprocessor_name)
self.enable_processor_cache = enable_processor_cache
# Convolution sizes for patch aggregation
self.spatial_conv_size = spatial_conv_size
@@ -163,10 +169,18 @@ class DataProcessor:
"""Enable evaluation mode (doesn't produce labels)."""
self.is_training = False
def text2ids(self, text, images=None, videos=None):
def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None):
"""
Convert chat text into model inputs.
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
Args:
text (str): The chat text containing placeholders for images and videos.
images (list, optional): List of images to be processed and inserted at image placeholders.
videos (list, optional): List of videos to be processed and inserted at video placeholders.
image_uuid (list, optional): List of unique identifiers for each image, used for caching or hashing.
video_uuid (list, optional): List of unique identifiers for each video, used for caching or hashing.
Returns:
dict: A dictionary with keys input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels, etc.
"""
outputs = {
@@ -178,8 +192,9 @@ class DataProcessor:
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
"mm_positions": [],
"mm_hashes": [],
}
IMAGE_PLACEHOLDER = "<|image@placeholder|>"
@@ -199,17 +214,27 @@ class DataProcessor:
break
if ed == image_pos:
self._add_image(images[image_idx], outputs)
image = images[image_idx]
uuid = image_uuid[image_idx] if image_uuid else None
if not isinstance(image, tuple):
self._add_image(image, outputs, uuid)
else:
# cached images are already processed
self._add_processed_image(image, outputs, uuid)
image_idx += 1
st = ed + IMAGE_PLACEHOLDER_LEN
else:
item = videos[video_idx]
if isinstance(item, dict):
frames = self._load_and_process_video(item["video"], item)
uuid = video_uuid[video_idx] if video_uuid else None
if not isinstance(item, tuple):
if isinstance(item, dict):
frames = self._load_and_process_video(item["video"], item)
else:
frames = self._load_and_process_video(item, {})
self._add_video(frames, outputs, uuid)
else:
frames = self._load_and_process_video(item, {})
self._add_video(frames, outputs)
# cached frames are already processed
self._add_processed_video(item, outputs, uuid)
video_idx += 1
st = ed + VIDEO_PLACEHOLDER_LEN
@@ -223,66 +248,82 @@ class DataProcessor:
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
"""
outputs = {
"input_ids": [],
"token_type_ids": [],
"position_ids": [],
"images": [],
"grid_thw": [],
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
}
messages = parse_chat_messages(request.get("messages"))
image_message_list = []
mm_items = []
for msg in messages:
role = msg.get("role")
assert role in self.role_prefixes, f"Unsupported role: {role}"
content_items = msg.get("content")
if not isinstance(content_items, list):
content_items = [content_items]
for item in content_items:
if isinstance(item, dict) and item.get("type") in [
"image",
"video",
]:
image_message_list.append(item)
content = msg.get("content")
if not isinstance(content, list):
content = [content]
for item in content:
if item.get("type") in ["image", "video"]:
mm_items.append(item)
missing_hashes, missing_idx = [], []
for idx, item in enumerate(mm_items):
if not item.get("data"):
# raw data not provided, should be retrieved from processor cache
missing_hashes.append(item.get("uuid"))
missing_idx.append(idx)
if len(missing_hashes) > 0 and not self.enable_processor_cache:
raise ValueError("Missing items cannot be retrieved without processor cache.")
if self.enable_processor_cache:
context = zmq.Context()
dealer = context.socket(zmq.DEALER)
dealer.connect("ipc:///dev/shm/processor_cache.ipc")
missing_items = self.get_processor_cache(dealer, missing_hashes)
for idx in range(len(missing_items)):
if not missing_items[idx]:
raise ValueError(f"Missing item {idx} not found in processor cache")
mm_items[missing_idx[idx]]["data"] = missing_items[idx]
images, videos = [], []
image_uuid, video_uuid = [], []
for item in mm_items:
if item.get("type") == "image":
images.append(item["data"])
image_uuid.append(item["uuid"])
elif item.get("type") == "video":
videos.append(item["data"])
video_uuid.append(item["uuid"])
else:
raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat template.")
chat_template_kwargs = request.get("chat_template_kwargs", {})
prompt_token_ids = self.apply_chat_template(request, **chat_template_kwargs)
if len(prompt_token_ids) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
image_start_index = 0
image_message_index = 0
for i in range(len(prompt_token_ids)):
if prompt_token_ids[i] in [
self.image_start_id,
self.video_start_id,
]:
self._add_text(prompt_token_ids[image_start_index : i + 1], outputs)
image_start_index = i + 1
image_message = image_message_list[image_message_index]
if image_message["type"] == "image":
img = image_message.get("image")
if img is None:
continue
outputs["pic_cnt"] += 1
self._add_image(img, outputs)
elif image_message["type"] == "video":
video_bytes = image_message.get("video")
if video_bytes is None:
continue
frames = self._load_and_process_video(video_bytes, image_message)
outputs["video_cnt"] += 1
self._add_video(frames, outputs)
image_message_index += 1
self._add_text(prompt_token_ids[image_start_index:], outputs)
prompt = self.tokenizer.apply_chat_template(
request,
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
**chat_template_kwargs,
)
request["prompt_tokens"] = prompt
outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid)
if self.enable_processor_cache:
missing_idx = set(missing_idx)
hashes_to_cache, items_to_cache = [], []
for idx in range(len(mm_items)):
if idx in missing_idx:
continue
meta = {}
t, h, w = outputs["grid_thw"][idx][0]
meta["thw"] = (t, h, w)
hashes_to_cache.append(outputs["mm_hashes"][idx])
items_to_cache.append((outputs["images"][idx], meta))
self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
if self.is_training:
assert tgts, "training must give tgt !"
assert tgts, "Training must give tgt"
self._extract_labels(outputs, tgts)
return outputs
def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
@@ -304,7 +345,7 @@ class DataProcessor:
outputs["position_ids"].append([start + i] * 3)
outputs["cur_position"] += len(tokens)
def _add_image(self, img, outputs: Dict) -> None:
def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
img.height,
img.width,
@@ -313,6 +354,7 @@ class DataProcessor:
)[1]
num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
@@ -330,10 +372,32 @@ class DataProcessor:
input_data_format=ChannelDimension.LAST,
)
outputs["images"].append(ret["pixel_values"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(ret["image_grid_thw"])
outputs["image_type_ids"].append(0)
def _add_video(self, frames, outputs: Dict) -> None:
def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
img, meta = img_cache
num_tokens = img.shape[0] // (self.spatial_conv_size**2)
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
_, h, w = meta["thw"]
pos_ids = self._compute_3d_positions(1, h, w, outputs["cur_position"])
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1
outputs["images"].append(img)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[1, h, w]]))
outputs["image_type_ids"].append(0)
def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
frames[0].height,
frames[0].width,
@@ -354,9 +418,14 @@ class DataProcessor:
input_data_format=ChannelDimension.LAST,
)
outputs["images"].append(ret["pixel_values_videos"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values_videos"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(ret["video_grid_thw"])
outputs["image_type_ids"].extend([1] * num_frames)
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
@@ -364,6 +433,24 @@ class DataProcessor:
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1
def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
frames, meta = frames_cache
num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size)
t, h, w = meta["thw"]
outputs["images"].append(frames)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[t, h, w]]))
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["image_type_ids"].extend([1] * t)
pos_ids = self._compute_3d_positions(t, h, w, outputs["cur_position"])
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1
def _extract_labels(self, outputs: Dict, tgts: List[str]) -> None:
input_ids = copy.deepcopy(outputs["input_ids"])
labels = [self.tokenizer.ignored_index] * len(input_ids)
@@ -480,33 +567,22 @@ class DataProcessor:
break
self.tokenizer = Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path)
def apply_chat_template(self, request, **kwargs):
def get_processor_cache(self, socket, mm_hashes: list[str]) -> list:
"""
Convert multi-turn messages into ID sequences.
Args:
messages: Either a request dict containing 'messages' field,
or a list of message dicts directly
Returns:
List of token IDs as strings (converted from token objects)
get cached items corresponding to the given hash values
"""
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
req = pickle.dumps(mm_hashes)
socket.send_multipart([b"", req])
_, resp = socket.recv_multipart()
mm_items = pickle.loads(resp)
data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}")
prompt_token_template = self.tokenizer.apply_chat_template(
request,
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
**kwargs,
)
prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace(
"<|video@placeholder|>", ""
)
request["prompt_tokens"] = prompt_token_template
tokens = self.tokenizer.tokenize(prompt_token_str)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
data_processor_logger.info(
f"req_id:{request.get('request_id', ''), } tokens: {tokens}, token_ids: {token_ids}"
)
return token_ids
return mm_items
def update_processor_cache(self, socket, mm_hashes: list[str], mm_items):
"""
update cache data
"""
req = pickle.dumps((mm_hashes, mm_items))
socket.send_multipart([b"", req])
data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}")

View File

@@ -46,6 +46,7 @@ class InputPreprocessor:
limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
tool_parser: str = None,
enable_processor_cache: bool = False,
) -> None:
self.model_config = model_config
self.model_name_or_path = self.model_config.model
@@ -53,6 +54,7 @@ class InputPreprocessor:
self.limit_mm_per_prompt = limit_mm_per_prompt
self.mm_processor_kwargs = mm_processor_kwargs
self.tool_parser = tool_parser
self.enable_processor_cache = enable_processor_cache
def create_processor(self):
reasoning_parser_obj = None
@@ -104,6 +106,19 @@ class InputPreprocessor:
mm_processor_kwargs=self.mm_processor_kwargs,
reasoning_parser_obj=reasoning_parser_obj,
tool_parser_obj=tool_parser_obj,
enable_processor_cache=self.enable_processor_cache,
)
elif "PaddleOCRVL" in architecture:
from fastdeploy.input.paddleocr_vl_processor import (
PaddleOCRVLProcessor,
)
self.processor = PaddleOCRVLProcessor(
config=self.model_config,
model_name_or_path=self.model_name_or_path,
limit_mm_per_prompt=self.limit_mm_per_prompt,
mm_processor_kwargs=self.mm_processor_kwargs,
reasoning_parser_obj=reasoning_parser_obj,
)
elif "PaddleOCRVL" in architecture:
from fastdeploy.input.paddleocr_vl_processor import (
@@ -126,5 +141,6 @@ class InputPreprocessor:
limit_mm_per_prompt=self.limit_mm_per_prompt,
mm_processor_kwargs=self.mm_processor_kwargs,
reasoning_parser_obj=reasoning_parser_obj,
enable_processor_cache=self.enable_processor_cache,
)
return self.processor

View File

@@ -15,17 +15,23 @@
# limitations under the License.
"""
from typing import Any, Dict, List, Tuple, Union
import pickle
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import zmq
from paddleformers.transformers import AutoTokenizer
from PIL import Image
from fastdeploy.engine.request import ImagePosition
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.ernie4_5_vl_processor import read_video_decord
from fastdeploy.input.utils import IDS_TYPE_FLAG
from fastdeploy.multimodal.hasher import MultimodalHasher
from fastdeploy.utils import data_processor_logger
from .image_processor import ImageProcessor
from .process_video import read_frames, sample_frames
from .process_video import sample_frames
class DataProcessor:
@@ -49,8 +55,11 @@ class DataProcessor:
def __init__(
self,
model_path: str,
enable_processor_cache: bool = False,
video_min_frames: int = 4,
video_max_frames: int = 768,
video_target_frames: int = -1,
video_fps: int = -1,
tokens_per_second: int = 2,
tokenizer=None,
**kwargs,
@@ -67,6 +76,8 @@ class DataProcessor:
"""
self.min_frames = video_min_frames
self.max_frames = video_max_frames
self.target_frames = video_target_frames
self.fps = video_fps
# Initialize tokenizer with left padding and fast tokenizer
if tokenizer is None:
@@ -75,6 +86,7 @@ class DataProcessor:
else:
self.tokenizer = tokenizer
self.image_processor = ImageProcessor.from_pretrained(model_path) # Initialize image processor
self.enable_processor_cache = enable_processor_cache
# Convolution sizes for patch aggregation
self.spatial_conv_size = self.image_processor.merge_size
@@ -99,41 +111,7 @@ class DataProcessor:
"assistant": "Assistant: ",
}
def _pack_outputs(self, outputs):
"""
Pack and convert all output data into numpy arrays with appropriate types.
Args:
outputs (dict): Dictionary containing model outputs with keys:
- images: List of visual features
- grid_thw: List of spatial dimensions
- image_type_ids: List of content type indicators
- input_ids: List of token IDs
- token_type_ids: List of type identifiers
- position_ids: List of position embeddings
Returns:
dict: Processed outputs with all values converted to numpy arrays
"""
# Process visual outputs - stack if exists or set to None if empty
if not outputs["images"]:
outputs["images"] = None # No images case
outputs["grid_thw"] = None # No spatial dimensions
outputs["image_type_ids"] = None # No type IDs
else:
outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically
outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions
outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array
# Convert all outputs to numpy arrays with appropriate types
outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64
outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64
outputs["position_ids"] = np.concatenate(
outputs["position_ids"], axis=1, dtype=np.int64
) # Concatenate position IDs
return outputs
def text2ids(self, text, images=None, videos=None):
def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None):
"""
Convert text with image/video placeholders into model inputs.
@@ -141,6 +119,8 @@ class DataProcessor:
text: Input text with <|image@placeholder|> and <|video@placeholder|> markers
images: List of PIL Images corresponding to image placeholders
videos: List of video data corresponding to video placeholders
image_uuid: List of unique identifiers for each image, used for caching or hashing.
video_uuid: List of unique identifiers for each video, used for caching or hashing.
Returns:
Dict containing:
@@ -161,8 +141,10 @@ class DataProcessor:
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
"fps": [],
"mm_positions": [],
"mm_hashes": [],
}
# Define placeholders and their lengths
@@ -186,23 +168,30 @@ class DataProcessor:
break
if ed == image_pos:
outputs["pic_cnt"] += 1
self._add_image(images[image_idx], outputs)
image = images[image_idx]
uuid = image_uuid[image_idx] if image_uuid else None
if not isinstance(image, tuple):
self._add_image(image, outputs, uuid)
else:
self._add_processed_image(image, outputs, uuid)
image_idx += 1
st = ed + IMAGE_PLACEHOLDER_LEN
else:
item = videos[video_idx]
if isinstance(item, dict):
frames, meta = self._load_and_process_video(item["video"], item)
uuid = video_uuid[video_idx] if video_uuid else None
if not isinstance(item, tuple):
if isinstance(item, dict):
frames, meta = self._load_and_process_video(item["video"], item)
else:
frames, meta = self._load_and_process_video(item, {})
self._add_video(frames, meta, outputs, uuid)
else:
frames, meta = self._load_and_process_video(item, {})
outputs["video_cnt"] += 1
self._add_video(frames, meta, outputs)
# cached frames are already processed
self._add_processed_video(item, outputs, uuid)
video_idx += 1
st = ed + VIDEO_PLACEHOLDER_LEN
return self._pack_outputs(outputs)
return outputs
def request2ids(
self, request: Dict[str, Any], tgts: List[str] = None
@@ -220,74 +209,84 @@ class DataProcessor:
Dict with same structure as text2ids() output
"""
outputs = {
"input_ids": [],
"token_type_ids": [],
"position_ids": [],
"images": [],
"grid_thw": [],
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"pic_cnt": 0,
"video_cnt": 0,
}
# Parse and validate chat messages
messages = parse_chat_messages(request.get("messages"))
image_message_list = [] # Store visual content messages
mm_items = []
for msg in messages:
role = msg.get("role")
assert role in self.role_prefixes, f"Unsupported role: {role}"
# Normalize content to list format
content_items = msg.get("content")
if not isinstance(content_items, list):
content_items = [content_items]
content = msg.get("content")
if not isinstance(content, list):
content = [content]
# Collect all visual content items
for item in content_items:
if isinstance(item, dict) and item.get("type") in ["image", "video"]:
image_message_list.append(item)
for item in content:
if item.get("type") in ["image", "video"]:
mm_items.append(item)
raw_messages = request["messages"]
request["messages"] = messages
missing_hashes, missing_idx = [], []
for idx, item in enumerate(mm_items):
if not item.get("data"):
# raw data not provided, should be retrieved from processor cache
missing_hashes.append(item.get("uuid"))
missing_idx.append(idx)
prompt_token_ids = self.apply_chat_template(request)
if len(prompt_token_ids) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
request["messages"] = raw_messages
if len(missing_hashes) > 0 and not self.enable_processor_cache:
raise ValueError("Missing items cannot be retrieved without processor cache.")
vision_start_index = 0
vision_message_index = 0
for i in range(len(prompt_token_ids)):
if prompt_token_ids[i] == self.vision_start_id:
self._add_text(prompt_token_ids[vision_start_index : i + 1], outputs)
if self.enable_processor_cache:
context = zmq.Context()
dealer = context.socket(zmq.DEALER)
dealer.connect("ipc:///dev/shm/processor_cache.ipc")
vision_start_index = i + 1
image_message = image_message_list[vision_message_index]
missing_items = self.get_processor_cache(dealer, missing_hashes)
for idx in range(len(missing_items)):
if not missing_items[idx]:
raise ValueError(f"Missing item {idx} not found in processor cache")
mm_items[missing_idx[idx]]["data"] = missing_items[idx]
if image_message["type"] == "image":
img = image_message.get("image")
if img is None:
continue
outputs["pic_cnt"] += 1
self._add_image(img, outputs)
images, videos = [], []
image_uuid, video_uuid = [], []
for item in mm_items:
if item.get("type") == "image":
images.append(item["data"])
image_uuid.append(item["uuid"])
elif item.get("type") == "video":
videos.append(item["data"])
video_uuid.append(item["uuid"])
else:
raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
elif image_message["type"] == "video":
video_bytes = image_message.get("video")
if video_bytes is None:
continue
frames, meta = self._load_and_process_video(video_bytes, image_message)
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat template.")
outputs["video_cnt"] += 1
self._add_video(frames, meta, outputs)
chat_template_kwargs = request.get("chat_template_kwargs", {})
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
**chat_template_kwargs,
)
request["prompt_tokens"] = prompt
vision_message_index += 1
outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid)
self._add_text(prompt_token_ids[vision_start_index:], outputs)
return self._pack_outputs(outputs)
if self.enable_processor_cache:
missing_idx = set(missing_idx)
hashes_to_cache, items_to_cache = [], []
for idx in range(len(mm_items)):
if idx in missing_idx:
continue
meta = {}
t, h, w = outputs["grid_thw"][idx]
meta["thw"] = (t, h, w)
meta["fps"] = outputs["fps"][idx]
hashes_to_cache.append(outputs["mm_hashes"][idx])
items_to_cache.append((outputs["images"][idx], meta))
self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
return outputs
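For context, a hedged sketch of the two shapes a multimodal message item can take on the paths handled above (the field names "type", "data", "uuid" follow request2ids; everything else is illustrative):

# Illustrative only: field names follow request2ids above; values are made up.
from PIL import Image

# Raw data attached: the item is preprocessed, hashed (or keyed by its uuid) and,
# when the processor cache is enabled, written back to it.
item_with_data = {"type": "image", "data": Image.new("RGB", (640, 480)), "uuid": None}

# Data omitted: only valid with enable_processor_cache=True; the payload is
# fetched from the cache by uuid before preprocessing continues.
item_cached_only = {"type": "image", "data": None, "uuid": "3f9a0c..."}  # hypothetical uuid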
def _add_text(self, tokens, outputs: Dict) -> None:
"""
@@ -312,9 +311,9 @@ class DataProcessor:
outputs["input_ids"].extend(tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)
position_ids = self._compute_text_positions(outputs["cur_position"], num_tokens)
outputs["position_ids"].append(position_ids)
outputs["cur_position"] = position_ids.max() + 1
pos_ids = self._compute_text_positions(outputs["cur_position"], num_tokens)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
def _compute_text_positions(self, start_pos: int, num_tokens: int) -> np.ndarray:
"""
@@ -332,7 +331,7 @@ class DataProcessor:
position = text_index + start_pos
return position
def _add_image(self, img, outputs: Dict) -> None:
def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
"""
Add image data to model inputs dictionary.
@@ -349,20 +348,47 @@ class DataProcessor:
num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
grid_thw = ret["grid_thw"].tolist()
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
outputs["images"].append(ret["pixel_values"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(grid_thw)
outputs["image_type_ids"].append(0)
t, h, w = grid_thw
position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)
outputs["position_ids"].append(position_ids)
outputs["cur_position"] = position_ids.max() + 1
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
def _add_video(self, frames, meta: Dict, outputs: Dict) -> None:
outputs["fps"].append(0)
def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
img, meta = img_cache
num_tokens = img.shape[0] // self.image_processor.merge_size**2
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
_, h, w = meta["thw"]
pos_ids = self._compute_vision_positions(outputs["cur_position"], 1, h, w, 0)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
outputs["images"].append(img)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[1, h, w]]))
outputs["image_type_ids"].append(0)
outputs["fps"].append(0)
def _add_video(self, frames, meta: Dict, outputs: Dict, uuid: Optional[str]) -> None:
"""
Add video data to model inputs dictionary.
@@ -380,20 +406,49 @@ class DataProcessor:
num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
grid_thw = ret["grid_thw"].tolist()
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.video_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["images"].append(ret["pixel_values"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(grid_thw)
outputs["image_type_ids"].extend([1] * grid_thw[0])
fps = meta["fps"]
second_per_grid_t = self.temporal_conv_size / fps
t, h, w = grid_thw
position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
outputs["position_ids"].append(position_ids)
outputs["cur_position"] = position_ids.max() + 1
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
outputs["fps"].append(fps)
def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
frames, meta = frames_cache
num_tokens = frames.shape[0] // self.image_processor.merge_size**2
t, h, w = meta["thw"]
outputs["images"].append(frames)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[t, h, w]]))
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["image_type_ids"].extend([1] * t)
fps = meta["fps"]
second_per_grid_t = self.temporal_conv_size / fps
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
outputs["fps"].append(fps)
def _compute_vision_positions(
self, start_pos: int, t: int, h: int, w: int, second_per_grid_t: float
@@ -441,20 +496,20 @@ class DataProcessor:
- frames: Processed video frames as numpy array
- metadata: Updated video metadata dictionary
"""
frames, meta = read_frames(url)
reader, meta, _ = read_video_decord(url, save_to_disk=False)
# Apply frame sampling if fps or target_frames specified
fps = item.get("fps", None)
num_frames = item.get("target_frames", None)
fps = item.get("fps", self.fps)
num_frames = item.get("target_frames", self.target_frames)
if fps is not None or num_frames is not None:
frame_indices = list(range(meta["num_of_frame"]))
if fps > 0 or num_frames > 0:
# Get frame sampling constraints
min_frames = item.get("min_frames", self.min_frames)
max_frames = item.get("max_frames", self.max_frames)
# Sample frames according to specifications
frames = sample_frames(
video=frames,
frame_indices = sample_frames(
frame_factor=self.temporal_conv_size, # Ensure divisible by temporal patch size
min_frames=min_frames,
max_frames=max_frames,
@@ -464,42 +519,38 @@ class DataProcessor:
)
# Update metadata with new frame count and fps
meta["num_of_frame"] = frames.shape[0]
meta["num_of_frame"] = len(frame_indices)
if fps is not None:
meta["fps"] = fps # Use specified fps
meta["duration"] = frames.shape[0] / fps
meta["duration"] = len(frame_indices) / fps
else:
meta["fps"] = frames.shape[0] / meta["duration"] # Calculate fps from sampled frames
meta["fps"] = len(frame_indices) / meta["duration"] # Calculate fps from sampled frames
frames = []
for idx in frame_indices:
frame = reader[idx].asnumpy()
image = Image.fromarray(frame, "RGB")
frames.append(image)
frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
return frames, meta
def apply_chat_template(self, request):
def get_processor_cache(self, socket, mm_hashes: list[str]) -> list:
"""
Apply chat template to convert messages into token sequence.
Args:
request: Dictionary containing chat messages
Returns:
List of token IDs
Raises:
ValueError: If model doesn't support chat templates
Get cached items corresponding to the given hash values.
"""
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
req = pickle.dumps(mm_hashes)
socket.send_multipart([b"", req])
_, resp = socket.recv_multipart()
mm_items = pickle.loads(resp)
data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}")
raw_prompt = self.tokenizer.apply_chat_template(
request["messages"],
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
)
prompt_token_str = raw_prompt.replace(self.image_token, "").replace(self.video_token, "")
request["prompt_tokens"] = raw_prompt
return mm_items
tokens = self.tokenizer.tokenize(prompt_token_str)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
data_processor_logger.info(
f"req_id:{request.get('request_id', ''), } prompt: {raw_prompt} tokens: {tokens}, token_ids: {token_ids}"
)
return token_ids
def update_processor_cache(self, socket, mm_hashes: list[str], mm_items):
"""
Update cached data for the given hash values.
"""
req = pickle.dumps((mm_hashes, mm_items))
socket.send_multipart([b"", req])
data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}")

View File

@@ -18,50 +18,9 @@ import math
from typing import Optional, Union
import numpy as np
from PIL import Image
from fastdeploy.input.ernie4_5_vl_processor import read_video_decord
def read_frames(video_path):
"""
Read and decode video frames from the given path
This function reads a video file and decodes it into individual RGB frames
using decord video reader. It also extracts video metadata including fps,
duration and frame count.
Args:
video_path (str): Path to the video file or bytes object containing video data
Returns:
tuple: A tuple containing:
frames (numpy.ndarray): Array of shape (num_frames, height, width, 3)
containing decoded RGB video frames
meta (dict): Dictionary containing video metadata:
- fps (float): Frames per second
- duration (float): Video duration in seconds
- num_of_frame (int): Total number of frames
- width (int): Frame width in pixels
- height (int): Frame height in pixels
Note:
- The function uses decord library for efficient video reading
- All frames are converted to RGB format regardless of input format
"""
reader, meta, _ = read_video_decord(video_path, save_to_disk=False)
frames = []
for i in range(meta["num_of_frame"]):
frame = reader[i].asnumpy()
image = Image.fromarray(frame, "RGB")
frames.append(image)
frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
return frames, meta
def sample_frames(
video: np.ndarray,
frame_factor: int,
min_frames: int,
max_frames: int,
@@ -73,7 +32,6 @@ def sample_frames(
Sample frames from video according to specified criteria.
Args:
video: Input video frames as numpy array
frame_factor: Ensure sampled frames are multiples of this factor
min_frames: Minimum number of frames to sample
max_frames: Maximum number of frames to sample
@@ -89,18 +47,15 @@ def sample_frames(
or if required metadata is missing,
or if requested frames exceed available frames
"""
if fps is not None and num_frames is not None:
if fps > 0 and num_frames > 0:
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
if fps is None and num_frames is None:
return video
total_num_frames = video.shape[0]
total_num_frames = metadata["num_of_frame"]
# If num_frames is not given but fps is, calculate num_frames from fps
if num_frames is not None:
if num_frames > 0:
num_frames = round(num_frames / frame_factor) * frame_factor
elif fps is not None:
elif fps > 0:
if metadata is None:
raise ValueError(
"Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
@@ -110,7 +65,6 @@ def sample_frames(
num_frames = total_num_frames / metadata["fps"] * fps
num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames)
num_frames = math.floor(num_frames / frame_factor) * frame_factor
if num_frames > total_num_frames:
raise ValueError(
f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
@@ -118,14 +72,11 @@ def sample_frames(
)
# Calculate frame indices based on sampling strategy
if num_frames is not None:
if num_frames > 0:
# Evenly spaced sampling for target frame count
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(np.int32)
else:
# Keep all frames if no sampling requested
indices = np.arange(0, total_num_frames).astype(np.int32)
# Apply frame selection
video = video[indices]
return video
return indices
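sample_frames now returns frame indices instead of a decoded frame array, so only the frames that survive sampling are decoded. A standalone sketch of the evenly spaced index selection, under the simplifying assumption that a positive target frame count is requested:

# Standalone sketch of the index selection above; names and simplifications are mine.
import math

import numpy as np


def sample_indices_sketch(total_num_frames: int, target_frames: int, frame_factor: int) -> np.ndarray:
    # Round the target to a multiple of frame_factor (the temporal patch size),
    # clamp to the frames actually available, then pick evenly spaced indices.
    num_frames = round(target_frames / frame_factor) * frame_factor
    num_frames = math.floor(min(num_frames, total_num_frames) / frame_factor) * frame_factor
    return np.arange(0, total_num_frames, total_num_frames / num_frames).astype(np.int32)


# e.g. 100 decoded frames sampled down to 8 (frame_factor=2) gives [0, 12, 25, ...];
# the caller then decodes only reader[idx] for those indices.
print(sample_indices_sketch(100, 8, 2))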

View File

@@ -47,6 +47,7 @@ class QwenVLProcessor(TextProcessor):
mm_processor_kwargs=None,
reasoning_parser_obj=None,
tool_parser_obj=None,
enable_processor_cache=False,
):
"""
Initialize QwenVLProcessor instance.
@@ -65,6 +66,7 @@ class QwenVLProcessor(TextProcessor):
processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs)
self.processor = DataProcessor(
model_path=model_name_or_path,
enable_processor_cache=enable_processor_cache,
tokens_per_second=config.vision_config.tokens_per_second,
tokenizer=self.tokenizer,
**processor_kwargs,
@@ -271,7 +273,7 @@ class QwenVLProcessor(TextProcessor):
return request
def append_completion_tokens(self, outputs, completion_token_ids):
def append_completion_tokens(self, multimodal_inputs, completion_token_ids):
"""
Append completion tokens to existing outputs.
@@ -279,19 +281,14 @@ class QwenVLProcessor(TextProcessor):
multimodal_inputs: Current multimodal inputs to extend
completion_token_ids: completion tokens to append
"""
out = {"input_ids": [], "token_type_ids": [], "position_ids": [], "cur_position": outputs["cur_position"]}
self.processor._add_text(completion_token_ids, out)
outputs["input_ids"] = np.concatenate(
[outputs["input_ids"], np.array(out["input_ids"], dtype=np.int64)], axis=0
)
outputs["token_type_ids"] = np.concatenate(
[outputs["token_type_ids"], np.array(out["token_type_ids"], dtype=np.int64)], axis=0
)
outputs["position_ids"] = np.concatenate(
[outputs["position_ids"], out["position_ids"][0]], axis=1, dtype=np.int64
)
outputs["cur_position"] = out["cur_position"]
num_tokens = len(completion_token_ids)
multimodal_inputs["input_ids"].extend(completion_token_ids)
multimodal_inputs["token_type_ids"].extend([0] * num_tokens)
pos_ids = self.processor._compute_text_positions(multimodal_inputs["cur_position"], num_tokens)
multimodal_inputs["position_ids"].append(pos_ids)
multimodal_inputs["cur_position"] += num_tokens
def pack_outputs(self, outputs):
"""
@@ -303,7 +300,24 @@ class QwenVLProcessor(TextProcessor):
Returns:
dict: Packed output dictionary with all required fields
"""
if not outputs["images"]:
outputs["images"] = None # No images case
outputs["grid_thw"] = None # No spatial dimensions
outputs["image_type_ids"] = None # No type IDs
else:
outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically
outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions
outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array
# Convert all outputs to numpy arrays with appropriate types
outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64
outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64
outputs["position_ids"] = np.concatenate(
outputs["position_ids"], axis=1, dtype=np.int64
) # Concatenate position IDs
outputs["image_patch_id"] = self.processor.image_token_id
outputs["video_patch_id"] = self.processor.video_token_id
outputs["position_ids"] = outputs["position_ids"].transpose(1, 0)
return outputs

View File

@@ -1,7 +1,7 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#

View File

@@ -1,7 +1,7 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#

View File

@@ -0,0 +1,35 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import hashlib
import pickle
import numpy as np
from fastdeploy.utils import data_processor_logger
class MultimodalHasher:
@classmethod
def hash_features(cls, obj: object) -> str:
if isinstance(obj, np.ndarray):
return hashlib.sha256((obj.tobytes())).hexdigest()
data_processor_logger.warning(
f"Unsupported type for hashing features: {type(obj)}" + ", use pickle for serialization"
)
return hashlib.sha256((pickle.dumps(obj))).hexdigest()
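A quick illustration of how the hasher is used as a content-addressed key when no user-supplied uuid is available (the array contents and shape here are arbitrary):

# Illustration: identical pixel arrays map to the same content-addressed key,
# so repeated images dedupe in the processor/encoder caches without a uuid.
import numpy as np

from fastdeploy.multimodal.hasher import MultimodalHasher

pixels = np.random.randint(0, 255, size=(4, 1176), dtype=np.uint8)
key = MultimodalHasher.hash_features(pixels)                  # sha256 over raw bytes
assert key == MultimodalHasher.hash_features(pixels.copy())   # same content, same key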

View File

@@ -1,7 +1,7 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#

View File

@@ -0,0 +1,36 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
class MultimodalRegistry:
"""
A registry for multimodal models
"""
mm_models: set[str] = {
"Ernie4_5_VLMoeForConditionalGeneration",
"Ernie5MoeForCausalLM",
"Qwen2_5_VLForConditionalGeneration",
"Ernie5ForCausalLM",
"Ernie4_5_VLMoeForProcessRewardModel",
}
@classmethod
def contains_model(cls, name: str) -> bool:
"""
Check if the given name exists in registry.
"""
return name in cls.mm_models
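Usage is a simple membership check on the model architecture name; a minimal example (the import path is an assumption based on the repository layout):

# Hypothetical import path; adjust to wherever MultimodalRegistry actually lives.
from fastdeploy.multimodal.registry import MultimodalRegistry

assert MultimodalRegistry.contains_model("Qwen2_5_VLForConditionalGeneration")
assert not MultimodalRegistry.contains_model("LlamaForCausalLM")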

View File

@@ -1,7 +1,7 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#

View File

@@ -1,7 +1,7 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#

View File

@@ -24,13 +24,12 @@ from typing import Dict, List, Optional, Tuple
import crcmod
from redis import ConnectionPool
from fastdeploy import envs
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler import utils
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.scheduler.storage import AdaptedRedis
from fastdeploy.scheduler.workers import Task, Workers
from fastdeploy.utils import scheduler_logger
from fastdeploy.utils import envs, scheduler_logger
class GlobalScheduler:
@@ -534,32 +533,33 @@ class GlobalScheduler:
continue
request: ScheduledRequest = ScheduledRequest.unserialize(serialized_request)
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
remaining_request.append((request_queue_name, serialized_request))
continue
if required_total_blocks > available_blocks:
remaining_request.append((request_queue_name, serialized_request))
continue
if not envs.FD_ENABLE_MAX_PREFILL:
if self.enable_chunked_prefill:
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
long_partial_requests += 1
if long_partial_requests > self.max_long_partial_prefills:
if not envs.FD_ENABLE_MAX_PREFILL:
if self.enable_chunked_prefill:
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
long_partial_requests += 1
if long_partial_requests > self.max_long_partial_prefills:
remaining_request.append((request_queue_name, serialized_request))
continue
else:
short_partial_requests += 1
if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
remaining_request.append((request_queue_name, serialized_request))
continue
else:
short_partial_requests += 1
if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
remaining_request.append((request_queue_name, serialized_request))
continue
else:
if current_prefill_tokens > max_num_batched_tokens:
remaining_request.append((request_queue_name, serialized_request))
continue
if current_prefill_tokens > max_num_batched_tokens:
remaining_request.append((request_queue_name, serialized_request))
continue
scheduled_requests.append(request)

View File

@@ -18,10 +18,9 @@ import threading
import time
from typing import Dict, List, Optional, Tuple
from fastdeploy import envs
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.utils import scheduler_logger
from fastdeploy.utils import envs, scheduler_logger
class LocalScheduler:
@@ -247,35 +246,40 @@ class LocalScheduler:
self.wait_request_timeout,
)
required_total_blocks = 0
current_prefill_tokens = 0
requests: List[Request] = []
long_partial_requests, short_partial_requests = 0, 0
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
required_total_blocks = 0
current_prefill_tokens = 0
long_partial_requests, short_partial_requests = 0, 0
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
if not envs.FD_ENABLE_MAX_PREFILL:
if self.enable_chunked_prefill:
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
# long request
long_partial_requests += 1
if long_partial_requests > self.max_long_partial_prefills:
if not envs.FD_ENABLE_MAX_PREFILL:
if self.enable_chunked_prefill:
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
# long request
long_partial_requests += 1
if long_partial_requests > self.max_long_partial_prefills:
break
else:
short_partial_requests += 1
if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
break
else:
short_partial_requests += 1
if current_prefill_tokens > max_num_batched_tokens:
break
requests.append(request.raw)
else:
for request_id in batch_ids:
request = self.requests[request_id]
requests.append(request.raw)
if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
break
else:
if current_prefill_tokens > max_num_batched_tokens:
break
requests.append(request.raw)
self.ids_read_cursor += len(requests)
if len(batch_ids) > 0 and len(requests) == 0:

View File

@@ -137,6 +137,11 @@ class GPUModelRunner(ModelRunnerBase):
"fused_gemm_epilogue",
]
if self.cache_config.max_encoder_cache > 0:
self.encoder_cache: dict[str, paddle.Tensor] = {}
else:
self.encoder_cache = None
# Sampler
if not self.speculative_decoding:
self.sampler = Sampler(fd_config)
@@ -297,6 +302,185 @@ class GPUModelRunner(ModelRunnerBase):
schemata_key,
)
def get_chunked_inputs(self, req: Request):
"""
Get the multimodal inputs for the current prefill chunk
"""
prefill_start_index = req.prefill_start_index
prefill_end_index = req.prefill_end_index
inputs = req.multimodal_inputs
input_ids = inputs["input_ids"][prefill_start_index:prefill_end_index]
token_type_ids = inputs["token_type_ids"][prefill_start_index:prefill_end_index]
image_type_ids = inputs["image_type_ids"][req.image_type_ids_start : req.image_type_ids_end]
images = inputs["images"][req.image_start : req.image_end]
grid_thw = inputs["grid_thw"][req.num_image_start : req.num_image_end]
mm_hashes = inputs["mm_hashes"][req.num_image_start : req.num_image_end]
return (
input_ids,
token_type_ids,
image_type_ids,
images,
grid_thw,
mm_hashes,
)
def batch_uncached_inputs(self, req: Request):
"""
Batch uncached multimodal inputs
"""
(input_ids, token_type_ids, image_type_ids, images, grid_thw, mm_hashes) = self.get_chunked_inputs(req)
image_type_ids_size = grid_thw[:, 0]
image_type_ids_split = np.cumsum(image_type_ids_size)[:-1]
image_type_ids_lst = np.array_split(image_type_ids, image_type_ids_split, axis=0)
images_size = np.prod(grid_thw, axis=1)
images_split = np.cumsum(images_size)[:-1]
images_lst = np.array_split(images, images_split, axis=0)
assert len(image_type_ids_lst) == len(
mm_hashes
), f"image_type_ids_lst length {len(image_type_ids_lst)} != mm_hashes length {len(mm_hashes)}"
assert len(images_lst) == len(
mm_hashes
), f"images_lst length {len(images_lst)} != mm_hashes length {len(mm_hashes)}"
uncached_image_type_ids = []
uncached_images = []
uncached_grid_thw = []
uncached_mm_hashes = []
for i, mm_hash in enumerate(mm_hashes):
if mm_hash in self.encoder_cache:
continue
uncached_image_type_ids.append(image_type_ids_lst[i])
uncached_images.append(images_lst[i])
uncached_grid_thw.append(grid_thw[i])
uncached_mm_hashes.append(mm_hash)
uncached_input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64)
uncached_token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
if len(uncached_mm_hashes) > 0:
uncached_image_type_ids = paddle.to_tensor(np.hstack(uncached_image_type_ids), dtype=paddle.int64)
uncached_images = paddle.to_tensor(
np.vstack(uncached_images), dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
)
uncached_grid_thw = paddle.to_tensor(uncached_grid_thw, dtype=paddle.int64)
return (
uncached_input_ids,
uncached_token_type_ids,
uncached_image_type_ids,
uncached_images,
uncached_grid_thw,
uncached_mm_hashes,
)
def scatter_and_cache_features(self, image_features, inputs):
"""
Split batched image features and cache them
"""
merge_size = 2
grid_thw = inputs["grid_thw"]
mm_hashes = inputs["mm_hashes"]
image_features_size = (paddle.prod(grid_thw[:, 1:], axis=1) // (merge_size**2)).tolist()
image_features_lst = paddle.split(image_features, image_features_size, axis=0)
assert len(image_features_lst) == len(
mm_hashes
), f"image_features_lst length {len(image_features_lst)} != mm_hashes length {len(mm_hashes)}"
for i, mm_hash in enumerate(mm_hashes):
self.encoder_cache[mm_hash] = image_features_lst[i].cpu()
def _apply_mm_inputs(self, request: Request, multi_vision_inputs: dict, rope_3d_position_ids: dict):
"""
Apply multimodal inputs to share_inputs
- add image_features: extract vision features from the model and cache them
- add rope_emb: rotary position embeddings
"""
if self.encoder_cache:
evict_mm_hashes = request.get("evict_mm_hashes", None)
if evict_mm_hashes:
for mm_hash in evict_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
inputs = request.multimodal_inputs
if request.with_image:
if envs.FD_ENABLE_MAX_PREFILL:
multi_vision_inputs["images_lst"].append(
inputs["images"][request.image_start : request.image_end].cuda()
)
multi_vision_inputs["grid_thw_lst"].extend(
inputs["grid_thw"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["cu_seqlens"].extend(
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["vit_position_ids_lst"].extend(
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
)
else:
vision_inputs = inputs
if self.encoder_cache:
(
vision_inputs["input_ids"],
vision_inputs["token_type_ids"],
vision_inputs["image_type_ids"],
vision_inputs["images"],
vision_inputs["grid_thw"],
vision_inputs["mm_hashes"],
) = self.batch_uncached_inputs(request)
if len(vision_inputs["mm_hashes"]) > 0:
# uncached multimodal inputs exist
image_features = self.extract_vision_features(vision_inputs)
self.scatter_and_cache_features(image_features, vision_inputs)
full_image_features_lst = []
for mm_hash in inputs["mm_hashes"][request.num_image_start : request.num_image_end]:
feature = self.encoder_cache[mm_hash].cuda()
full_image_features_lst.append(feature)
image_features = paddle.concat(full_image_features_lst, axis=0)
else:
(
input_ids,
token_type_ids,
image_type_ids,
images,
grid_thw,
mm_hashes,
) = self.get_chunked_inputs(request)
vision_inputs["input_ids"] = paddle.to_tensor(input_ids, dtype=paddle.int64)
vision_inputs["token_type_ids"] = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
vision_inputs["image_type_ids"] = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
vision_inputs["images"] = paddle.to_tensor(
images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
)
vision_inputs["grid_thw"] = paddle.to_tensor(grid_thw, dtype=paddle.int64)
vision_inputs["mm_hashes"] = mm_hashes
image_features = self.extract_vision_features(vision_inputs)
# part of the first image may already be cached
if "ernie" in self.model_config.model_type:
actual_image_token_num = paddle.sum(vision_inputs["input_ids"] == self.model_config.im_patch_id)
elif "qwen" in self.model_config.model_type:
actual_image_token_num = paddle.sum(
vision_inputs["input_ids"] == vision_inputs["image_patch_id"]
) + paddle.sum(vision_inputs["input_ids"] == vision_inputs["video_patch_id"])
else:
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
self.share_inputs["image_features"] = image_features[-actual_image_token_num:]
else:
self.share_inputs["image_features"] = None
position_ids = request.multimodal_inputs["position_ids"]
rope_3d_position_ids["position_ids_idx"].append(request.idx)
rope_3d_position_ids["position_ids_lst"].append(position_ids)
rope_3d_position_ids["position_ids_offset"].append(
position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
)
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None):
"""
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
@@ -326,51 +510,7 @@ class GPUModelRunner(ModelRunnerBase):
prefill_end_index = request.prefill_end_index
length = prefill_end_index - prefill_start_index
if self.enable_mm:
inputs = request.multimodal_inputs
if request.with_image:
if envs.FD_ENABLE_MAX_PREFILL:
multi_vision_inputs["images_lst"].append(
inputs["images"][request.image_start : request.image_end].cuda()
)
multi_vision_inputs["grid_thw_lst"].extend(
inputs["grid_thw"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["cu_seqlens"].extend(
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["vit_position_ids_lst"].extend(
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
)
else:
vision_inputs = {}
vision_inputs["input_ids"] = paddle.to_tensor(
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["token_type_ids"] = paddle.to_tensor(
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["image_type_ids"] = paddle.to_tensor(
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
dtype=paddle.int64,
)
vision_inputs["images"] = paddle.to_tensor(
inputs["images"][request.image_start : request.image_end],
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
)
vision_inputs["grid_thw"] = paddle.to_tensor(
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
)
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
else:
self.share_inputs["image_features"] = None
position_ids = request.multimodal_inputs["position_ids"]
rope_3d_position_ids["position_ids_idx"].append(idx)
rope_3d_position_ids["position_ids_lst"].append(position_ids)
rope_3d_position_ids["position_ids_offset"].append(
position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
)
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
self._apply_mm_inputs(request, multi_vision_inputs, rope_3d_position_ids)
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
# Enable thinking

View File

@@ -653,6 +653,12 @@ def parse_args():
help="Flag to specify dtype of lm_head as FP32",
)
parser.add_argument(
"--max_encoder_cache",
type=int,
help="Maximum encoder cache tokens(use 0 to disable).",
)
parser.add_argument(
"--cache-transfer-protocol",
type=str,

View File

@@ -91,17 +91,18 @@ class TestQwenVLProcessor(unittest.TestCase):
config.vision_config.tokens_per_second = 2
self.patcher_parse_image = patch(
"fastdeploy.entrypoints.chat_utils.MultiModalPartParser.parse_image", return_value=mock_pil_image(480, 640)
"fastdeploy.entrypoints.chat_utils.MultimodalPartParser.parse_image", return_value=mock_pil_image(480, 640)
)
self.patcher_parse_image.start()
self.patcher_parse_video = patch(
"fastdeploy.entrypoints.chat_utils.MultiModalPartParser.parse_video", return_value=b"123"
"fastdeploy.entrypoints.chat_utils.MultimodalPartParser.parse_video", return_value=b"123"
)
self.patcher_parse_video.start()
self.patcher_read_frames = patch(
"fastdeploy.input.qwen_vl_processor.process.read_frames", return_value=mock_read_frames(480, 640, 5, 2)
"fastdeploy.input.qwen_vl_processor.process.DataProcessor._load_and_process_video",
return_value=mock_read_frames(480, 640, 5, 2),
)
self.patcher_read_frames.start()
@@ -163,8 +164,6 @@ class TestQwenVLProcessor(unittest.TestCase):
self.assertEqual(
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
)
self.assertEqual(result.multimodal_inputs["pic_cnt"], 1)
self.assertEqual(result.multimodal_inputs["video_cnt"], 1)
def test_process_request_dict(self):
"""
@@ -204,8 +203,6 @@ class TestQwenVLProcessor(unittest.TestCase):
self.assertEqual(
result["multimodal_inputs"]["image_type_ids"].shape[0], result["multimodal_inputs"]["grid_thw"][:, 0].sum()
)
self.assertEqual(result["multimodal_inputs"]["pic_cnt"], 1)
self.assertEqual(result["multimodal_inputs"]["video_cnt"], 1)
def test_prompt(self):
"""
@@ -240,8 +237,6 @@ class TestQwenVLProcessor(unittest.TestCase):
self.assertEqual(
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
)
self.assertEqual(result.multimodal_inputs["pic_cnt"], 1)
self.assertEqual(result.multimodal_inputs["video_cnt"], 1)
def test_message_and_prompt(self):
"""

View File

@@ -0,0 +1,46 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import pickle
import unittest
import numpy as np
from fastdeploy.multimodal.hasher import MultimodalHasher
class TestHashFeatures(unittest.TestCase):
def test_hash_features_ndarray(self):
"""Test hash features with numpy ndarray"""
arr = np.random.randint(low=0, high=255, size=(28, 28), dtype=np.uint8)
arr_hash = MultimodalHasher.hash_features(arr)
target_hash = hashlib.sha256((arr.tobytes())).hexdigest()
assert arr_hash == target_hash, f"Ndarray hash mismatch: {arr_hash} != {target_hash}"
def test_hash_features_object(self):
"""Test hash features with unsupported object type"""
obj = {"key": "value"}
obj_hash = MultimodalHasher.hash_features(obj)
target_hash = hashlib.sha256((pickle.dumps(obj))).hexdigest()
assert obj_hash == target_hash, f"Dict hash mismatch: {obj_hash} != {target_hash}"
obj = "test hasher str"
obj_hash = MultimodalHasher.hash_features(obj)
target_hash = hashlib.sha256((pickle.dumps(obj))).hexdigest()
assert obj_hash == target_hash, f"Str hash mismatch: {obj_hash} != {target_hash}"
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,82 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastdeploy.cache_manager.multimodal_cache_manager import EncoderCacheManager
from fastdeploy.engine.request import ImagePosition
def test_mm_encoder_cache():
max_encoder_cache = 4096
encoder_cache = EncoderCacheManager(max_encoder_cache=max_encoder_cache)
mm_hashes = ["mm_hash1", "mm_hash2"]
mm_positions = [ImagePosition(offset=120, length=400), ImagePosition(offset=620, length=800)]
cache_length = mm_positions[0].length + mm_positions[1].length
evict_hashes = encoder_cache.apply_cache(mm_hashes=mm_hashes, mm_items=mm_positions)
assert evict_hashes == [], "The evicted hashes should be empty."
assert list(encoder_cache.cache.keys()) == [
"mm_hash1",
"mm_hash2",
], "The cache should contain mm_hash1 and mm_hash2."
assert (
encoder_cache.current_cache_size == cache_length
), "The cache size should be the sum of the lengths of mm_hash1 and mm_hash2."
assert (
encoder_cache.current_cache_size <= max_encoder_cache
), "The cache size should be less than or equal to the max_encoder_cache."
mm_hashes = ["mm_hash3", "mm_hash4"]
mm_positions = [ImagePosition(offset=20, length=1204), ImagePosition(offset=1800, length=2048)]
cache_length += mm_positions[0].length + mm_positions[1].length - 400
evict_hashes = encoder_cache.apply_cache(mm_hashes=mm_hashes, mm_items=mm_positions)
assert evict_hashes == ["mm_hash1"], "The evicted hashes should be mm_hash1."
assert list(encoder_cache.cache.keys()) == [
"mm_hash2",
"mm_hash3",
"mm_hash4",
], "The cache should contain mm_hash2, mm_hash3, and mm_hash4."
assert (
encoder_cache.current_cache_size == cache_length
), "The cache size should be the sum of the lengths of mm_hash2, mm_hash3, and mm_hash4."
assert (
encoder_cache.current_cache_size <= max_encoder_cache
), "The cache size should be less than or equal to the max_encoder_cache."
evict_hashes = encoder_cache.apply_cache(mm_hashes=["mm_hash2"], mm_items=[ImagePosition(offset=620, length=800)])
assert evict_hashes == [], "The evicted hashes should be empty."
assert (
encoder_cache.current_cache_size == cache_length
), "The cache size should be the sum of the lengths of mm_hash2, mm_hash3, and mm_hash4."
assert (
encoder_cache.current_cache_size <= max_encoder_cache
), "The cache size should be less than or equal to the max_encoder_cache."
cache_length -= 1204
evict_hashes = encoder_cache.evict_cache(needed=800)
assert evict_hashes == ["mm_hash3"], "The evicted hashes should be mm_hash3."
assert list(encoder_cache.cache.keys()) == [
"mm_hash4",
"mm_hash2",
], "The cache should contain mm_hash2 and mm_hash4."
assert (
encoder_cache.current_cache_size == cache_length
), "The cache size should be the sum of the lengths of mm_hash2 and mm_hash4."
assert (
encoder_cache.current_cache_size <= max_encoder_cache
), "The cache size should be less than or equal to the max_encoder_cache."
encoder_cache.clear_cache()
assert encoder_cache.current_cache_size == 0, "The cache size should be 0."
assert list(encoder_cache.cache.keys()) == [], "The cache should be empty."

View File

@@ -0,0 +1,259 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict
from types import SimpleNamespace
from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.request import ImagePosition, Request
from fastdeploy.scheduler import SchedulerConfig
def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_override=100, max_num_batched_tokens=3200):
engine_args = EngineArgs(
max_num_seqs=max_num_seqs,
num_gpu_blocks_override=num_gpu_blocks_override,
max_num_batched_tokens=max_num_batched_tokens,
)
args = asdict(engine_args)
cache_cfg = CacheConfig(args)
model_cfg = SimpleNamespace(enable_mm=enable_mm, max_model_len=8192)
speculative_cfg = SimpleNamespace(method=None)
model_cfg.print = print
cache_cfg.bytes_per_layer_per_block = 1
parallel_cfg = ParallelConfig(args)
scheduler_cfg = SchedulerConfig(args)
graph_opt_cfg = engine_args.create_graph_optimization_config()
fd_config = FDConfig(
model_config=model_cfg,
cache_config=cache_cfg,
parallel_config=parallel_cfg,
graph_opt_config=graph_opt_cfg,
speculative_config=speculative_cfg,
scheduler_config=scheduler_cfg,
)
return PrefixCacheManager(config=fd_config, tensor_parallel_size=8, splitwise_role="mixed")
def test_normal_case():
block_size = 64
cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=False, num_gpu_blocks_override=100)
req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
req2 = Request.from_dict(
{"request_id": "req2", "prompt_token_ids": [1] * 1600 + [2] * 1600, "prompt_token_ids_len": 3200}
)
req3 = Request.from_dict(
{"request_id": "req3", "prompt_token_ids": [1] * 1600 + [3] * 1600, "prompt_token_ids_len": 3200}
)
(common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req1, block_size)
assert len(common_block_ids) == 0
assert matched_token_num == 0
assert len(cache_manager.gpu_free_block_list) == 100
req1.block_tables.extend(common_block_ids)
# allocate for req1 inputs
num_new_block = 50
req1.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
req1.num_computed_tokens += 50 * block_size
cache_manager.update_cache_blocks(req1, block_size, req1.num_computed_tokens)
assert len(cache_manager.gpu_free_block_list) == 50
# allocate for req2 inputs
(common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req2, block_size)
assert len(common_block_ids) == 25
assert matched_token_num == 25 * block_size
req2.num_cached_tokens = matched_token_num
req2.num_computed_tokens = 25 * block_size
num_new_block = 25
req2.block_tables.extend(common_block_ids)
req2.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
cache_manager.update_cache_blocks(req2, block_size, req2.num_computed_tokens)
# allocate for req3 input
(common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req3, block_size)
assert len(common_block_ids) == 25
assert matched_token_num == 25 * block_size
req3.num_cached_tokens = matched_token_num
req3.num_computed_tokens = 25 * block_size
assert len(cache_manager.gpu_free_block_list) == 25
req3.block_tables.extend(common_block_ids)
num_new_block = 25
assert cache_manager.can_allocate_gpu_blocks(num_new_block)
req3.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
cache_manager.update_cache_blocks(req3, block_size, req3.num_computed_tokens)
assert len(cache_manager.gpu_free_block_list) == 0
def test_mm_extra_keys():
block_size = 64
cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=True)
prompt_token_ids = [1] * 100 + [2] * 100
req1 = {
"request_id": "req1",
"prompt_token_ids": prompt_token_ids,
"prompt_token_ids_len": len(prompt_token_ids),
}
for idx in range(0, len(prompt_token_ids), block_size):
token_ids_lens = min(block_size, len(prompt_token_ids[idx:]))
mm_idx, extra_keys = cache_manager.get_block_hash_extra_keys(
request=Request.from_dict(req1),
start_idx=idx,
end_idx=idx + token_ids_lens,
mm_idx=0,
)
assert extra_keys == [], f"extra_keys {extra_keys} != [], start_idx {idx}, end_idx {idx + token_ids_lens}"
assert mm_idx == 0, f"mm_idx {mm_idx} != 0, start_idx {idx}, end_idx {idx + token_ids_lens}"
# block 1
prompt_token_ids = [1] * 30 + [-1] * 34
mm_positions = [ImagePosition(offset=30, length=80)]
mm_hashes = ["image1"]
extra_keys_list = [(0, ["image1"])]
# block 2
prompt_token_ids += [-1] * 46 + [2] * 18
extra_keys_list.append((1, ["image1"]))
# block 3
prompt_token_ids += [-1] * 100
mm_positions.append(ImagePosition(offset=128, length=100))
mm_hashes.append("image2")
extra_keys_list.append((1, ["image2"]))
# blocks 4 and 5
prompt_token_ids += [3] * 40
extra_keys_list.append((1, ["image2"]))
extra_keys_list.append((1, []))
req2 = {
"request_id": "req2",
"prompt_token_ids": prompt_token_ids,
"prompt_token_ids_len": len(prompt_token_ids),
"multimodal_inputs": {
"mm_positions": mm_positions,
"mm_hashes": mm_hashes,
},
}
mm_idx, key_idx = 0, 0
for idx in range(0, len(prompt_token_ids), block_size):
token_ids_lens = min(block_size, len(prompt_token_ids[idx:]))
mm_idx, extra_keys = cache_manager.get_block_hash_extra_keys(
request=Request.from_dict(req2),
start_idx=idx,
end_idx=idx + token_ids_lens,
mm_idx=mm_idx,
)
target_idx, target_keys = extra_keys_list[key_idx]
assert (
mm_idx == target_idx
), f"mm_idx {mm_idx} != target_idx {target_idx}, start_idx {idx}, end_idx {idx + token_ids_lens}"
assert (
extra_keys == target_keys
), f"extra_keys {extra_keys} != target_keys {target_keys}, start_idx {idx}, end_idx {idx + token_ids_lens}"
key_idx += 1
def test_mm_prefix_cache():
block_size = 64
cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=True, num_gpu_blocks_override=100)
multimodal_inputs = {
"mm_positions": [ImagePosition(offset=120, length=1200)],
"mm_hashes": ["image1"],
}
req1_dict = {
"request_id": "req1",
"prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120,
"prompt_token_ids_len": 1440,
"multimodal_inputs": multimodal_inputs,
}
req1 = Request.from_dict(req1_dict)
multimodal_inputs = dict(multimodal_inputs)
multimodal_inputs["mm_positions"].append(ImagePosition(offset=1836, length=587))
multimodal_inputs["mm_hashes"].append("image2")
req2_dict = {
"request_id": "req2",
"prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120 + [3] * 396 + [-1] * 587,
"prompt_token_ids_len": 2423,
"multimodal_inputs": multimodal_inputs,
}
req2 = Request.from_dict(req2_dict)
multimodal_inputs = dict(multimodal_inputs)
multimodal_inputs["mm_hashes"] = ["image3", "image4"]
req3_dict = {
"request_id": "req3",
"prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120 + [3] * 396 + [-1] * 587,
"prompt_token_ids_len": 2423,
"multimodal_inputs": multimodal_inputs,
}
req3 = Request.from_dict(req3_dict)
multimodal_inputs = dict(multimodal_inputs)
multimodal_inputs["mm_positions"] = [ImagePosition(offset=120, length=1200)]
multimodal_inputs["mm_hashes"] = ["image3"]
req4_dict = {
"request_id": "req4",
"prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120 + [3] * 352,
"prompt_token_ids_len": 1792,
"multimodal_inputs": multimodal_inputs,
}
req4 = Request.from_dict(req4_dict)
(common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req1, block_size)
assert len(common_block_ids) == 0
assert matched_token_num == 0
assert len(cache_manager.gpu_free_block_list) == 100
req1.block_tables.extend(common_block_ids)
# allocate for req1 inputs
num_new_block = 22
req1.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
req1.num_computed_tokens += 22 * block_size
cache_manager.update_cache_blocks(req1, block_size, req1.num_computed_tokens)
assert len(cache_manager.gpu_free_block_list) == 78
# allocate for req2 inputs
(common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req2, block_size)
assert len(common_block_ids) == 22
assert matched_token_num == 22 * block_size
req2.num_cached_tokens = matched_token_num
req2.num_computed_tokens = matched_token_num
num_new_block = 15
req2.block_tables.extend(common_block_ids)
req2.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
req2.num_computed_tokens += 15 * block_size
cache_manager.update_cache_blocks(req2, block_size, req2.num_computed_tokens)
# allocate for req3 input
(common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req3, block_size)
assert len(common_block_ids) == 1
assert matched_token_num == 1 * block_size
req3.num_cached_tokens = matched_token_num
req3.num_computed_tokens = matched_token_num
assert len(cache_manager.gpu_free_block_list) == 63
req3.block_tables.extend(common_block_ids)
num_new_block = 36
assert cache_manager.can_allocate_gpu_blocks(num_new_block)
req3.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
req3.num_computed_tokens += 36 * block_size
cache_manager.update_cache_blocks(req3, block_size, req3.num_computed_tokens)
assert len(cache_manager.gpu_free_block_list) == 27
# allocate for req4 input
(common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req4, block_size)
assert len(common_block_ids) == 28
assert matched_token_num == 28 * block_size

View File

@@ -1,73 +0,0 @@
from dataclasses import asdict
from types import SimpleNamespace
from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig, SchedulerConfig
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.request import Request
def test_normal_case():
max_num_seqs = 3
block_size = 64
engine_args = EngineArgs(max_num_seqs=max_num_seqs, num_gpu_blocks_override=100, max_num_batched_tokens=3200)
args = asdict(engine_args)
cache_cfg = CacheConfig(args)
model_cfg = SimpleNamespace(enable_mm=False)
speculative_cfg = SimpleNamespace(method=None)
model_cfg.print = print
model_cfg.max_model_len = 5120
cache_cfg.bytes_per_layer_per_block = 1
parallel_cfg = ParallelConfig(args)
scheduler_cfg = SchedulerConfig(args)
graph_opt_cfg = engine_args.create_graph_optimization_config()
fd_config = FDConfig(
model_config=model_cfg,
cache_config=cache_cfg,
parallel_config=parallel_cfg,
graph_opt_config=graph_opt_cfg,
speculative_config=speculative_cfg,
scheduler_config=scheduler_cfg,
)
cache_manager = PrefixCacheManager(config=fd_config, tensor_parallel_size=8, splitwise_role="mixed")
req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
req2 = Request.from_dict(
{"request_id": "req2", "prompt_token_ids": [1] * 1600 + [2] * 1600, "prompt_token_ids_len": 3200}
)
req3 = Request.from_dict(
{"request_id": "req3", "prompt_token_ids": [1] * 1600 + [3] * 1600, "prompt_token_ids_len": 3200}
)
(common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req1, block_size)
assert len(common_block_ids) == 0
assert matched_token_num == 0
assert len(cache_manager.gpu_free_block_list) == 100
req1.block_tables.extend(common_block_ids)
# allocate for req1 inputs
num_new_block = 50
req1.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
req1.num_computed_tokens += 50 * block_size
cache_manager.update_cache_blocks(req1, block_size, req1.num_computed_tokens)
assert len(cache_manager.gpu_free_block_list) == 50
# allocate for req2 inputs
(common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req2, block_size)
assert len(common_block_ids) == 25
assert matched_token_num == 25 * block_size
req2.num_cached_tokens = matched_token_num
req2.num_computed_tokens == 25 * block_size
num_new_block = 25
req2.block_tables.extend(common_block_ids)
req2.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
cache_manager.update_cache_blocks(req2, block_size, req2.num_computed_tokens)
# allocate for req3 input
(common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req3, block_size)
assert len(common_block_ids) == 25
assert matched_token_num == 25 * block_size
req3.num_cached_tokens = matched_token_num
req3.num_computed_tokens == 25 * block_size
assert len(cache_manager.gpu_free_block_list) == 25
req3.block_tables.extend(common_block_ids)
num_new_block = 25
assert cache_manager.can_allocate_gpu_blocks(num_new_block)
req3.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
cache_manager.update_cache_blocks(req3, block_size, req3.num_computed_tokens)
assert len(cache_manager.gpu_free_block_list) == 0