diff --git a/docs/parameters.md b/docs/parameters.md index 0de0effa0..7a41dac37 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -55,6 +55,8 @@ When using FastDeploy to deploy models (including offline inference and service | ```tool_call_parser``` | `str` | Specify the function call parser to be used for extracting function call content from the model's output. | | ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to be registered, so as to register parsers that are not in the code repository. The code format within these parsers must adhere to the format used in the code repository. | | ```load_choices``` | `str` | By default, the "default" loader is used for weight loading. To load Torch weights or enable weight acceleration, "default_v1" must be used.| +| ```max_encoder_cache``` | `int` | Maximum number of tokens in the encoder cache (use 0 to disable). | +| ```max_processor_cache``` | `float` | Maximum size of the processor cache in GiB (use 0 to disable). | ## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```? diff --git a/docs/usage/log.md b/docs/usage/log.md index 1c2dbf90c..5c963a065 100644 --- a/docs/usage/log.md +++ b/docs/usage/log.md @@ -13,7 +13,7 @@ By default, logs are stored in the `log` directory under the execution path. To * `fastdeploy.log` : Records configuration information during instance startup, as well as request and response details during runtime. * `workerlog.*` : Tracks model loading progress and inference operator errors. Each GPU card has a corresponding file. * `worker_process.log` : Logs engine inference data for each iteration. -* `prefix_cache_manager.log` : Records KV Cache logical index allocation for each request and cache hit status. +* `cache_manager.log` : Records KV Cache logical index allocation for each request and cache hit status. * `launch_worker.log` : Logs model startup information and error messages. * `gpu_worker.log` : Records KV Cache block count information during profiling. * `gpu_model_runner.log` : Contains model details and loading time. diff --git a/docs/zh/parameters.md b/docs/zh/parameters.md index 1a374281f..859685ff6 100644 --- a/docs/zh/parameters.md +++ b/docs/zh/parameters.md @@ -53,6 +53,8 @@ | ```tool_call_parser``` | `str` | 指定要使用的function call解析器,以便从模型输出中抽取 function call内容| | ```tool_parser_plugin``` | `str` | 指定要注册的tool parser文件路径,以便注册不在代码库中的parser,parser中代码格式需遵循代码库中格式| | ```load_choices``` | `str` | 默认使用"default" loader进行权重加载,加载torch权重/权重加速需开启 "default_v1"| +| ```max_encoder_cache``` | `int` | 编码器缓存的最大token数(使用0表示禁用)。 | +| ```max_processor_cache``` | `float` | 处理器缓存的最大容量(以GiB为单位,使用0表示禁用)。 | ## 1. KVCache分配与```num_gpu_blocks_override```、```block_size```的关系?
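For reference, `max_encoder_cache` is a token budget for cached image features, while `max_processor_cache` is given in GiB and converted to bytes by `ResourceManagerV1` before the processor cache is constructed (see later in this patch). The sketch below illustrates the LRU accounting behind the encoder budget; it assumes the `EncoderCacheManager` and `ImagePosition` classes added by this patch are importable as shown, and the hash strings and token counts are made up for illustration.

```python
from fastdeploy.cache_manager.multimodal_cache_manager import EncoderCacheManager
from fastdeploy.engine.request import ImagePosition

cache = EncoderCacheManager(max_encoder_cache=1000)  # budget of 1000 encoder tokens

# Two 400-token images fit within the budget, so nothing is evicted.
evicted = cache.apply_cache(["img-a", "img-b"], [ImagePosition(0, 400), ImagePosition(400, 400)])
assert evicted == []

# A third 400-token image pushes usage past the budget, so the least recently
# used entry ("img-a") is evicted and its hash is returned to the caller.
evicted = cache.apply_cache(["img-c"], [ImagePosition(800, 400)])
assert evicted == ["img-a"]
```

The scheduler stores the returned hashes on the request as `evict_mm_hashes` (see `resource_manager_v1.py` later in this patch), so stale encoder outputs can be identified downstream.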
diff --git a/docs/zh/usage/log.md b/docs/zh/usage/log.md index 4c54cc7a0..7d276655b 100644 --- a/docs/zh/usage/log.md +++ b/docs/zh/usage/log.md @@ -13,7 +13,7 @@ FastDeploy 在部署过程中,会产生如下日志文件,各日志含义说 * `fastdeploy.log` : 记录当前实例启动的各个 config 的信息,运行中记录用户请求的 request 及 response 信息 * `workerlog.*` : 记录模型启动加载进度及推理算子报错信息,每个卡对应一个文件 * `worker_process.log` : 记录引擎每一轮推理的数据 -* `prefix_cache_manager.log` : 记录每一个请求分配 KV Cache 的逻辑索引,以及当前请求的命中情况 +* `cache_manager.log` : 记录每一个请求分配 KV Cache 的逻辑索引,以及当前请求的命中情况 * `launch_worker.log` : 记录模型启动信息及报错信息 * `gpu_worker.log` : 记录 profile 时计算 KV Cache block 数目的信息 * `gpu_model_runner.log` : 当前的模型信息及加载时间 diff --git a/fastdeploy/cache_manager/cache_data.py b/fastdeploy/cache_manager/cache_data.py index dc8ef406d..631f5efb0 100644 --- a/fastdeploy/cache_manager/cache_data.py +++ b/fastdeploy/cache_manager/cache_data.py @@ -18,7 +18,7 @@ from enum import Enum from fastdeploy.utils import get_logger -logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log") +logger = get_logger("prefix_cache_manager", "cache_manager.log") DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = { diff --git a/fastdeploy/cache_manager/cache_metrics.py b/fastdeploy/cache_manager/cache_metrics.py index 3946357c8..a7dec1bd6 100644 --- a/fastdeploy/cache_manager/cache_metrics.py +++ b/fastdeploy/cache_manager/cache_metrics.py @@ -17,7 +17,7 @@ from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.utils import get_logger -logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log") +logger = get_logger("prefix_cache_manager", "cache_manager.log") class CacheMetrics: diff --git a/fastdeploy/cache_manager/multimodal_cache_manager.py b/fastdeploy/cache_manager/multimodal_cache_manager.py new file mode 100644 index 000000000..febce1bc2 --- /dev/null +++ b/fastdeploy/cache_manager/multimodal_cache_manager.py @@ -0,0 +1,163 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import pickle +import threading +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Any, Tuple + +import numpy as np +import zmq + +from fastdeploy import envs +from fastdeploy.engine.request import ImagePosition +from fastdeploy.utils import get_logger + +logger = get_logger("prefix_cache_manager", "cache_manager.log") + + +class MultimodalLRUCache(ABC): + """ + General lru cache for multimodal data + """ + + def __init__(self, max_cache_size): + self.cache = OrderedDict() + self.current_cache_size = 0 + self.max_cache_size = max_cache_size + + def apply_cache(self, mm_hashes: list[str], mm_items: list[Any]) -> list[str]: + """ + apply data cache, return evicted data + """ + assert len(mm_hashes) == len(mm_items), "mm_hashes and mm_items should have same length" + + evicted_hashes = [] + for idx in range(len(mm_hashes)): + if mm_hashes[idx] in self.cache: + self.cache.move_to_end(mm_hashes[idx]) + else: + item_size = self.get_item_size(mm_items[idx]) + if self.current_cache_size + item_size >= self.max_cache_size: + if item_size > self.max_cache_size: + # cannot be inserted even if we clear all cached data, skip it directly + continue + needed = item_size - (self.max_cache_size - self.current_cache_size) + evicted_hashes.extend(self.evict_cache(needed)) + self.cache[mm_hashes[idx]] = mm_items[idx] + self.current_cache_size += item_size + + return evicted_hashes + + def evict_cache(self, needed: int) -> list[str]: + """ + evict data cache with needed size + """ + reduced_size, evicted_hashes = 0, [] + while reduced_size < needed and len(self.cache): + mm_hash, mm_item = self.cache.popitem(last=False) + evicted_hashes.append(mm_hash) + reduced_size += self.get_item_size(mm_item) + self.current_cache_size -= self.get_item_size(mm_item) + + return evicted_hashes + + def get_cache(self, mm_hashes: list[str]) -> list[Any]: + """ + get cached data correspond to given hash values + """ + mm_items = [] + for mm_hash in mm_hashes: + if mm_hash not in self.cache: + mm_items.append(None) + continue + mm_items.append(self.cache[mm_hash]) + + return mm_items + + def clear_cache(self): + """ + clear all cached data + """ + evicted_hashes = list(self.cache.keys()) + self.cache.clear() + self.current_cache_size = 0 + + return evicted_hashes + + @abstractmethod + def get_item_size(self, item: Any) -> int: + raise NotImplementedError("Subclasses must define how to get size of an item") + + +class EncoderCacheManager(MultimodalLRUCache): + """ + EncoderCacheManager is used to cache image features + """ + + def __init__(self, max_encoder_cache): + super().__init__(max_encoder_cache) + + def get_item_size(self, item: ImagePosition) -> int: + return item.length + + +class ProcessorCacheManager(MultimodalLRUCache): + """ + ProcessorCacheManager is used to cache processed data + """ + + def __init__(self, max_processor_cache): + super().__init__(max_processor_cache) + + self.context = zmq.Context() + + self.router = self.context.socket(zmq.ROUTER) + self.router.setsockopt(zmq.SNDHWM, int(envs.FD_ZMQ_SNDHWM)) + self.router.setsockopt(zmq.ROUTER_MANDATORY, 1) + self.router.setsockopt(zmq.SNDTIMEO, -1) + self.router.bind("ipc:///dev/shm/processor_cache.ipc") + + self.poller = zmq.Poller() + self.poller.register(self.router, zmq.POLLIN) + + self.handler_thread = threading.Thread(target=self.cache_request_handler, daemon=True) + self.handler_thread.start() + + def get_item_size(self, item: Tuple[np.ndarray, dict]) -> int: + return item[0].nbytes + + def 
cache_request_handler(self): + try: + while True: + events = dict(self.poller.poll()) + + if self.router in events: + client, _, content = self.router.recv_multipart() + req = pickle.loads(content) + + if isinstance(req, tuple): + # apply cache request, in format of (mm_hashes, mm_items) + self.apply_cache(req[0], req[1]) + logger.info(f"Apply processor cache of mm_hashes: {req[0]}") + else: + # get cache request + resp = self.get_cache(req) + logger.info(f"Get processor cache of mm_hashes: {req}") + self.router.send_multipart([client, b"", pickle.dumps(resp)]) + except Exception as e: + logger.error(f"Error happened while handling processor cache request: {e}") diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 8934bb7f1..1d5dc9c33 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -14,8 +14,10 @@ # limitations under the License. """ +import hashlib import heapq import os +import pickle import subprocess import sys import threading @@ -35,7 +37,7 @@ from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTre from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.utils import get_logger -logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log") +logger = get_logger("prefix_cache_manager", "cache_manager.log") class PrefixCacheManager: @@ -575,31 +577,18 @@ class PrefixCacheManager: """ try: req_id = task.request_id - block_tables = task.block_tables - last_node, num_cached_tokens = self.cache_info[req_id] - if isinstance(task.prompt_token_ids, np.ndarray): - prompt_token_ids = task.prompt_token_ids.tolist() - else: - prompt_token_ids = task.prompt_token_ids - input_ids = prompt_token_ids + task.output_token_ids can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size - left_input_ids = input_ids[num_cached_tokens:can_cache_computed_tokens] - gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :] if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later self.leaf_req_map[last_node].remove(req_id) with self.request_release_lock: - current_time = time.time() - leaf_node = self.build_path( - req_id=req_id, - current_time=current_time, - input_ids=input_ids, - left_input_ids=left_input_ids, - gpu_block_ids=gpu_extra_block_ids, + leaf_node = self.mm_build_path( + request=task, + num_computed_tokens=num_computed_tokens, block_size=block_size, last_node=last_node, - reverved_dec_block_num=0, + num_cached_tokens=num_cached_tokens, ) self.req_leaf_map[req_id] = leaf_node self.leaf_req_map[leaf_node].add(req_id) @@ -636,10 +625,9 @@ class PrefixCacheManager: prompt_token_ids = task.prompt_token_ids.tolist() else: prompt_token_ids = task.prompt_token_ids - input_ids = prompt_token_ids + task.output_token_ids req_id = task.request_id logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}") - input_token_num = len(input_ids) + input_token_num = len(prompt_token_ids + task.output_token_ids) common_block_ids = [] # 1. 
match block ( @@ -649,7 +637,7 @@ class PrefixCacheManager: match_block_node, gpu_match_token_num, cpu_match_token_num, - ) = self.match_block(req_id, input_ids, block_size) + ) = self.mm_match_block(task, block_size) # update matched node info self._update_matched_node_info(req_id, match_block_node, current_time=time.time()) @@ -1145,6 +1133,173 @@ class PrefixCacheManager: """ return hash(tuple(block)) + def get_block_hash_extra_keys(self, request, start_idx, end_idx, mm_idx): + """ + Retrieves additional hash keys for block identification. + + Args: + request: The input request object containing the data to be processed. + start_idx (int): The starting index of the block segment to hash. + end_idx (int): The ending index of the block segment to hash. + mm_idx: The multimodal index identifier for specialized content handling. + + Returns: + mm_idx: next multimodal index + hash_keys: A list of additional hash keys + """ + hash_keys = [] + mm_inputs = request.multimodal_inputs + if ( + mm_inputs is None + or "mm_positions" not in mm_inputs + or "mm_hashes" not in mm_inputs + or len(mm_inputs["mm_positions"]) == 0 + ): + return mm_idx, hash_keys + + assert start_idx < end_idx, f"start_idx {start_idx} >= end_idx {end_idx}" + assert ( + start_idx >= 0 and start_idx < request.num_total_tokens + ), f"start_idx {start_idx} out of range {request.num_total_tokens}" + assert ( + end_idx >= 0 and end_idx <= request.num_total_tokens + ), f"end_idx {end_idx} out of range {request.num_total_tokens}" + assert len(mm_inputs["mm_positions"]) == len( + mm_inputs["mm_hashes"] + ), f"mm_positions {len(mm_inputs['mm_positions'])} != mm_hashes {len(mm_inputs['mm_hashes'])}" + assert mm_idx >= 0 and mm_idx < len( + mm_inputs["mm_hashes"] + ), f"mm_idx {mm_idx} out of range {len(mm_inputs['mm_hashes'])}" + + if mm_inputs["mm_positions"][-1].offset + mm_inputs["mm_positions"][-1].length < start_idx: + # non images in current block + return mm_idx, hash_keys + + for img_idx in range(mm_idx, len(mm_inputs["mm_positions"])): + image_offset = mm_inputs["mm_positions"][img_idx].offset + image_length = mm_inputs["mm_positions"][img_idx].length + + if image_offset + image_length < start_idx: + # image before block + continue + elif image_offset >= end_idx: + # image after block + return img_idx, hash_keys + elif image_offset + image_length > end_idx: + hash_keys.append(mm_inputs["mm_hashes"][img_idx]) + return img_idx, hash_keys + else: + hash_keys.append(mm_inputs["mm_hashes"][img_idx]) + return len(mm_inputs["mm_positions"]) - 1, hash_keys + + def hash_block_features(self, input_ids, extra_keys: list = []): + """ + calculate hash value of a block with additional keys + + Args: + input_ids: Input token IDs + extra_keys: Additional keys for block identification + """ + return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest() + + def mm_match_block(self, request, block_size): + """ + Match and retrieve cached blocks for multimodal requests using a radix tree structure. + + Args: + request: The multimodal request object containing prompt and output token IDs. + block_size (int): The size of each token block for matching and processing. 
+ + Returns: + tuple: A tuple containing: + - match_gpu_block_ids (list): List of block IDs matched in GPU cache + - match_cpu_block_ids (list): List of block IDs matched in CPU cache + - swap_node_ids (list): List of node IDs scheduled for GPU-CPU swapping + - current_match_node: The last matched node in the radix tree traversal + - gpu_match_token_num (int): Total number of tokens matched in GPU cache + - cpu_match_token_num (int): Total number of tokens matched in CPU cache + """ + if isinstance(request.prompt_token_ids, np.ndarray): + prompt_token_ids = request.prompt_token_ids.tolist() + else: + prompt_token_ids = request.prompt_token_ids + input_ids = prompt_token_ids + request.output_token_ids + total_token_num = len(input_ids) + current_match_node = self.radix_tree_root # 从根节点开始搜 + match_gpu_block_ids = [] + match_cpu_block_ids = [] + match_node_ids = [] + mm_idx = 0 + match_token_num = 0 + cpu_match_token_num = 0 + gpu_match_token_num = 0 + swap_node_ids = [] + matche_nodes = [] + has_modified_gpu_lru_leaf_heap = False + has_modified_cpu_lru_leaf_heap = False + + with self.cache_status_lock: + while match_token_num < total_token_num: + token_block = input_ids[match_token_num : match_token_num + block_size] + token_num = len(token_block) + if token_num != block_size: + break + mm_idx, extra_keys = self.get_block_hash_extra_keys( + request=request, + start_idx=match_token_num, + end_idx=match_token_num + block_size, + mm_idx=mm_idx, + ) + hash_value = self.hash_block_features(token_block, extra_keys) + if hash_value in current_match_node.children: + child = current_match_node.children[hash_value] + matche_nodes.append(child) + match_node_ids.append(child.node_id) + if child in self.gpu_lru_leaf_set: + self.gpu_lru_leaf_set.remove(child) + self.gpu_lru_leaf_heap.remove(child) + has_modified_gpu_lru_leaf_heap = True + elif child in self.cpu_lru_leaf_set: + self.cpu_lru_leaf_set.remove(child) + self.cpu_lru_leaf_heap.remove(child) + has_modified_cpu_lru_leaf_heap = True + if child.has_in_gpu: + match_gpu_block_ids.append(child.block_id) + gpu_match_token_num += block_size + else: + if child.cache_status == CacheStatus.SWAP2CPU: + logger.info( + f"match_block: req_id {request.request_id} matched node" + + f" {child.node_id} which is being SWAP2CPU" + ) + child.cache_status = CacheStatus.GPU + match_gpu_block_ids.append(child.block_id) + gpu_match_token_num += block_size + elif child.cache_status == CacheStatus.CPU: + child.cache_status = CacheStatus.SWAP2GPU + match_cpu_block_ids.append(child.block_id) + cpu_match_token_num += block_size + swap_node_ids.append(child.node_id) + match_token_num = match_token_num + block_size + current_match_node = child + else: + break + + if has_modified_gpu_lru_leaf_heap: + heapq.heapify(self.gpu_lru_leaf_heap) + if has_modified_cpu_lru_leaf_heap: + heapq.heapify(self.cpu_lru_leaf_heap) + + logger.info(f"match_block: req_id {request.request_id} matched nodes: {match_node_ids}") + return ( + match_gpu_block_ids, + match_cpu_block_ids, + swap_node_ids, + current_match_node, + gpu_match_token_num, + cpu_match_token_num, + ) + def match_block(self, req_id, input_ids, block_size): """ Args: @@ -1241,6 +1396,86 @@ class PrefixCacheManager: node.req_id_set.add(req_id) node = node.parent + def mm_build_path(self, request, num_computed_tokens, block_size, last_node, num_cached_tokens): + """ + Constructs a caching path in radix tree for multimodal requests by processing computed tokens. 
+ + Args: + request: The inference request object containing: + - prompt_token_ids: Original input tokens (List[int] or np.ndarray) + - output_token_ids: Generated tokens (List[int]) + - mm_positions: Optional image positions for multimodal content + num_computed_tokens: Total tokens processed so far (cached + newly computed) + block_size: Fixed size of token blocks (must match cache configuration) + last_node: The deepest existing BlockNode in the radix tree for this request + num_cached_tokens: Number of tokens already cached + + Returns: + BlockNode: The new deepest node in the constructed path + """ + if isinstance(request.prompt_token_ids, np.ndarray): + prompt_token_ids = request.prompt_token_ids.tolist() + else: + prompt_token_ids = request.prompt_token_ids + input_ids = prompt_token_ids + request.output_token_ids + can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size + if num_cached_tokens == can_cache_computed_tokens: + return last_node + + mm_idx = 0 + node = last_node + unique_node_ids = [] + new_last_node = last_node + has_unfilled_block = False + current_time = time.time() + + input_hash_value = self.hash_block_features(input_ids) + gpu_block_ids = request.block_tables[num_cached_tokens // block_size :].copy() + for i in range(num_cached_tokens, can_cache_computed_tokens, block_size): + current_block = input_ids[i : i + block_size] + current_block_size = len(current_block) # the last block may not be full + if current_block_size != block_size: + has_unfilled_block = True + else: + mm_idx, extra_keys = self.get_block_hash_extra_keys( + request=request, + start_idx=i, + end_idx=i + block_size, + mm_idx=mm_idx, + ) + hash_value = self.hash_block_features(current_block, extra_keys) + allocated_block_id = gpu_block_ids.pop(0) + node_id = self.node_id_pool.pop() + unique_node_ids.append(node_id) + new_last_node = BlockNode( + node_id, + input_ids, + input_hash_value, + node.depth + 1, + allocated_block_id, + current_block_size, + hash_value, + current_time, + parent=node, + shared_count=1, + reverved_dec_block_ids=[], + ) + new_last_node.req_id_set.add(request.request_id) + self.node_map[node_id] = new_last_node + node.children[hash_value] = new_last_node + node = new_last_node + + reverved_dec_block_ids = [] + if has_unfilled_block is True: + reverved_dec_block_ids.append(gpu_block_ids.pop(0)) + + if new_last_node == self.radix_tree_root: + self.unfilled_req_block_map[request.request_id] = reverved_dec_block_ids + else: + new_last_node.reverved_dec_block_ids.extend(reverved_dec_block_ids) + logger.info(f"build_path: allocate unique node ids {unique_node_ids} for req_id {request.request_id}") + return new_last_node + def build_path( self, req_id, diff --git a/fastdeploy/config.py b/fastdeploy/config.py index b51b8190f..a2cc620a8 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1137,6 +1137,8 @@ class CacheConfig: enc_dec_block_num (int): Number of encoder-decoder blocks. prealloc_dec_block_slot_num_threshold (int): Number of token slot threadshold to allocate next blocks for decoding, used when ENABLE_V1_KVCACHE_SCHEDULER=1. enable_prefix_caching (bool): Enable prefix caching. + max_encoder_cache (int): Maximum number of tokens in the encoder cache. + max_processor_cache (float): Maximum size of the processor cache in GiB.
""" self.block_size = 64 self.gpu_memory_utilization = 0.9 @@ -1157,6 +1159,8 @@ class CacheConfig: self.enable_ssd_cache = False self.cache_queue_port = None self.swap_space = None + self.max_encoder_cache = None + self.max_processor_cache = None for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) @@ -1440,7 +1444,7 @@ class FDConfig: self.max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3")) if current_platform.is_xpu(): self.max_prefill_batch = 1 - if self.model_config is not None and self.model_config.enable_mm: + if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER: self.max_prefill_batch = 1 # TODO:当前多模prefill阶段只支持并行度为1,待优化 else: self.max_prefill_batch = self.scheduler_config.max_num_seqs @@ -1500,7 +1504,7 @@ class FDConfig: self.cache_config.postprocess(self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_seqs) self.cache_config.max_block_num_per_seq = int(self.model_config.max_model_len // self.cache_config.block_size) - if self.model_config is not None and self.model_config.enable_mm: + if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER: self.cache_config.enable_prefix_caching = False if ( @@ -1513,6 +1517,20 @@ class FDConfig: else: self.structured_outputs_config.guided_decoding_backend = "xgrammar" + if self.model_config.enable_mm: + if self.cache_config.max_encoder_cache is None or self.cache_config.max_encoder_cache < 0: + self.cache_config.max_encoder_cache = self.scheduler_config.max_num_batched_tokens + elif self.cache_config.max_encoder_cache != 0: + if self.cache_config.max_encoder_cache < self.scheduler_config.max_num_batched_tokens: + logger.warning( + f"max_encoder_cache {self.cache_config.max_encoder_cache} is less than " + f"max_num_batched_tokens {self.scheduler_config.max_num_batched_tokens}, " + f"setting it to max_num_batched_tokens." + ) + self.cache_config.max_encoder_cache = self.scheduler_config.max_num_batched_tokens + else: + self.cache_config.max_encoder_cache = 0 + # Adjustment GraphOptConfig if self.load_config is not None and self.load_config.dynamic_load_weight is True: self.graph_opt_config.graph_opt_level = 0 diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 8fbb615a3..90fd47b59 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -123,6 +123,14 @@ class EngineArgs: """ Limitation of numbers of multi-modal data. """ + max_encoder_cache: int = -1 + """ + Maximum number of tokens in the encoder cache. + """ + max_processor_cache: float = -1 + """ + Maximum size of the processor cache in GiB.
+ """ reasoning_parser: str = None """ specifies the reasoning parser to use for extracting reasoning content from the model output """ @@ -526,6 +534,18 @@ class EngineArgs: type=json.loads, help="Additional keyword arguments for the multi-modal processor.", ) + model_group.add_argument( + "--max-encoder-cache", + default=EngineArgs.max_encoder_cache, + type=int, + help="Maximum number of tokens in the encoder cache (use 0 to disable).", + ) + model_group.add_argument( + "--max-processor-cache", + default=EngineArgs.max_processor_cache, + type=float, + help="Maximum processor cache size in GiB (use 0 to disable).", + ) model_group.add_argument( "--enable-mm", action=DeprecatedOptionWarning, diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index bfd5e83d4..cf2662e0b 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -634,13 +634,9 @@ class EngineService: int(self.resource_manager.available_batch()), self.cfg.max_prefill_batch, ) - if self.cfg.model_config.enable_mm: - available_blocks = self.resource_manager.available_block_num() - else: - available_blocks = self.cfg.cache_config.max_block_num_per_seq tasks = self.scheduler.get_requests( - available_blocks=available_blocks, + available_blocks=self.cfg.cache_config.max_block_num_per_seq, block_size=self.cfg.cache_config.block_size, reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, max_num_batched_tokens=self.cfg.model_config.max_model_len, diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 4799f84e5..1525b48d7 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -528,6 +528,7 @@ class LLMEngine: f" --load_choices {self.cfg.load_config.load_choices}" f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'" f" --ips {ips}" + f" --max_encoder_cache {self.cfg.cache_config.max_encoder_cache}" f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}" f" --runner {self.cfg.model_config.runner}" f" --convert {self.cfg.model_config.convert}" diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index ab5c3103b..f18ac86b2 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -46,6 +46,12 @@ class RequestType(Enum): EXTEND = 3 +@dataclass +class ImagePosition: + offset: int = 0 + length: int = 0 + + @dataclass class Request: def __init__( diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 41f1589ca..989434367 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -28,10 +28,21 @@ import numpy as np import paddle from fastdeploy import envs -from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType +from fastdeploy.cache_manager.multimodal_cache_manager import ( + EncoderCacheManager, + ProcessorCacheManager, +) +from fastdeploy.engine.request import ( + ImagePosition, + Request, + RequestOutput, + RequestStatus, + RequestType, +) from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.inter_communicator import IPCSignal from fastdeploy.metrics.metrics import main_process_metrics +from fastdeploy.multimodal.hasher import MultimodalHasher from fastdeploy.platforms import current_platform from fastdeploy.utils import llm_logger @@ -175,6 +186,15 @@ class ResourceManagerV1(ResourceManager): self.need_block_num_map = dict() + self.encoder_cache = None + if 
config.model_config.enable_mm and config.cache_config.max_encoder_cache > 0: + self.encoder_cache = EncoderCacheManager(config.cache_config.max_encoder_cache) + + self.processor_cache = None + if config.model_config.enable_mm and config.cache_config.max_processor_cache > 0: + max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024) + self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes) + def allocated_slots(self, request: Request): return len(request.block_tables) * self.config.cache_config.block_size @@ -273,6 +293,44 @@ class ResourceManagerV1(ResourceManager): break return can_schedule + def _update_mm_hashes(self, request): + if request.multimodal_inputs is None: + return + + inputs = request.multimodal_inputs + if ( + inputs.get("images", None) is not None + and inputs.get("image_patch_id", None) is not None + and inputs.get("grid_thw", None) is not None + and len(inputs["grid_thw"]) != 0 + ): + grid_thw = [] + new_mm_positions, new_mm_hashes = [], [] + image_st = 0 + for idx, one in enumerate(inputs["grid_thw"]): + t, h, w = one[0], one[1], one[2] + if t == 1: + grid_thw.append(one) + new_mm_positions.append(inputs["mm_positions"][idx]) + new_mm_hashes.append(inputs["mm_hashes"][idx]) + image_st += h * w + else: + grid_thw.extend([[2, h, w]] * (t // 2)) + token_st = inputs["mm_positions"][idx].offset + for _ in range(t // 2): + new_mm_positions.append(ImagePosition(token_st, h * w // 4)) + # videos are split into patches every 2 frames, need to rehash + new_mm_hashes.append( + MultimodalHasher.hash_features(inputs["images"][image_st : image_st + 2 * h * w]) + ) + image_st += 2 * h * w + token_st += h * w // 4 + inputs["mm_positions"] = new_mm_positions + inputs["mm_hashes"] = new_mm_hashes + else: + inputs["mm_positions"] = [] + inputs["mm_hashes"] = [] + def _get_num_new_tokens(self, request, token_budget): # TODO: set condition to new _get_num_new_tokens num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens @@ -333,11 +391,12 @@ class ResourceManagerV1(ResourceManager): if request.multimodal_img_boundaries is None: grid_thw = [] - for one in inputs["grid_thw"]: - if one[0] == 1: + for idx, one in enumerate(inputs["grid_thw"]): + t, h, w = one[0], one[1], one[2] + if t == 1: grid_thw.append(one) else: - grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2)) + grid_thw.extend([[2, h, w]] * (t // 2)) grid_thw = paddle.to_tensor(grid_thw, dtype="int64") if current_platform.is_xpu(): @@ -398,6 +457,11 @@ class ResourceManagerV1(ResourceManager): request.image_start = np.sum(np.prod(grid_thw[: request.num_image_start], axis=1)) request.image_end = np.sum(np.prod(grid_thw[: request.num_image_end], axis=1)) + cur_mm_hashes = inputs["mm_hashes"][request.num_image_start : request.num_image_end] + cur_mm_positions = inputs["mm_positions"][request.num_image_start : request.num_image_end] + if self.encoder_cache: + request.evict_mm_hashes = self.encoder_cache.apply_cache(cur_mm_hashes, cur_mm_positions) + # Compatible with scenarios without images and videos. 
return num_new_tokens @@ -553,6 +617,7 @@ class ResourceManagerV1(ResourceManager): break request = self.waiting[0] if request.status == RequestStatus.WAITING: + self._update_mm_hashes(request) # Enable prefix caching if self.config.cache_config.enable_prefix_caching: if ( diff --git a/fastdeploy/entrypoints/chat_utils.py b/fastdeploy/entrypoints/chat_utils.py index 9850430a5..08abc8ed3 100644 --- a/fastdeploy/entrypoints/chat_utils.py +++ b/fastdeploy/entrypoints/chat_utils.py @@ -17,7 +17,6 @@ import os import time import uuid -from copy import deepcopy from pathlib import Path from typing import List, Literal, Optional, Union from urllib.parse import urlparse @@ -29,6 +28,7 @@ from openai.types.chat import ( from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessageParam, ) +from openai.types.chat.chat_completion_content_part_image_param import ImageURL from typing_extensions import Required, TypeAlias, TypedDict from fastdeploy.multimodal.image import ImageMediaIO @@ -36,6 +36,17 @@ from fastdeploy.multimodal.video import VideoMediaIO from fastdeploy.utils import api_server_logger +class CustomChatCompletionContentPartImageParam(TypedDict, total=False): + """Custom Image URL object""" + + type: Required[Literal["image_url"]] + """The type of the content part.""" + + image_url: Optional[ImageURL] + + uuid: Optional[str] + + class VideoURL(TypedDict, total=False): """Video URL object""" @@ -46,14 +57,17 @@ class VideoURL(TypedDict, total=False): class CustomChatCompletionContentPartVideoParam(TypedDict, total=False): """Custom Video URL object""" - video_url: Required[VideoURL] - type: Required[Literal["video_url"]] - """The type of the content type.""" + """The type of the content part.""" + + video_url: Optional[VideoURL] + + uuid: Optional[str] CustomChatCompletionContentPartParam: TypeAlias = Union[ OpenAIChatCompletionContentPartParam, + CustomChatCompletionContentPartImageParam, CustomChatCompletionContentPartVideoParam, ] @@ -77,7 +91,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, CustomChatCompletionMessageParam] -class MultiModalPartParser: +class MultimodalPartParser: """Multi Modal Part parser""" def __init__(self): @@ -139,32 +153,46 @@ def parse_content_part(mm_parser, part): return part if part_type == "image_url": - content = part.get("image_url", {}).get("url", None) - image = mm_parser.parse_image(content) - parsed = deepcopy(part) - del parsed["image_url"]["url"] - parsed["image"] = image - parsed["type"] = "image" - return parsed + if not part.get("image_url", None) and not part.get("uuid", None): + raise ValueError("Both image_url and uuid are missing") + if part.get("image_url", None): + url = part["image_url"]["url"] + image = mm_parser.parse_image(url) + else: + image = None + + parsed = {} + parsed["type"] = "image" + parsed["data"] = image + parsed["uuid"] = part.get("uuid", None) + + return parsed if part_type == "video_url": - content = part.get("video_url", {}).get("url", None) - video = mm_parser.parse_video(content) - parsed = deepcopy(part) - del parsed["video_url"]["url"] - parsed["video"] = video + if not part.get("video_url", None) and not part.get("uuid", None): + raise ValueError("Both video_url and uuid are missing") + + if part.get("video_url", None): + url = part["video_url"]["url"] + video = mm_parser.parse_video(url) + else: + video = None + + parsed = {} parsed["type"] = "video" + parsed["data"] = video + parsed["uuid"] = 
part.get("uuid", None) + return parsed raise ValueError(f"Unknown content part type: {part_type}") # TODO async -# def parse_chat_messages(messages: List[ChatCompletionMessageParam]): -def parse_chat_messages(messages): +def parse_chat_messages(messages: List[ChatCompletionMessageParam]): """Parse chat messages to [dict]""" - mm_parser = MultiModalPartParser() + mm_parser = MultimodalPartParser() conversation = [] for message in messages: diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 10bc9d32a..e9664ddb6 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -68,16 +68,19 @@ class EngineClient: tool_parser=None, enable_prefix_caching=None, splitwise_role=None, + max_processor_cache=0, ): model_config = ModelConfig({"model": model_name_or_path}) + self.enable_mm = model_config.enable_mm + enable_processor_cache = self.enable_mm and max_processor_cache > 0 input_processor = InputPreprocessor( model_config, reasoning_parser, limit_mm_per_prompt, mm_processor_kwargs, tool_parser, + enable_processor_cache, ) - self.enable_mm = model_config.enable_mm self.enable_logprob = enable_logprob self.reasoning_parser = reasoning_parser self.data_processor = input_processor.create_processor() diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 695107fc7..1537d94ed 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -193,6 +193,7 @@ async def lifespan(app: FastAPI): tool_parser=args.tool_call_parser, enable_prefix_caching=args.enable_prefix_caching, splitwise_role=args.splitwise_role, + max_processor_cache=args.max_processor_cache, ) await engine_client.connection_manager.initialize() app.state.dynamic_load_weight = args.dynamic_load_weight diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index 7329af8a4..9e75d23f2 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -470,6 +470,8 @@ class CompletionRequest(BaseModel): max_streaming_response_tokens: Optional[int] = None return_token_ids: Optional[bool] = None prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None + + mm_hashes: Optional[list] = None # doc: end-completion-extra-params def to_dict_for_infer(self, request_id=None, prompt=None): @@ -527,6 +529,9 @@ class CompletionRequest(BaseModel): if item is not None: req_dict[key] = item + if self.mm_hashes is not None and len(self.mm_hashes) > 0: + req_dict["mm_hashes"] = self.mm_hashes + return req_dict @model_validator(mode="before") @@ -553,6 +558,9 @@ class CompletionRequest(BaseModel): "('guided_json', 'guided_regex', 'guided_choice', 'guided_grammar')." ) + if data.get("mm_hashes", None): + assert isinstance(data["mm_hashes"], list), "`mm_hashes` must be a list." 
+ return data @@ -618,6 +626,8 @@ class ChatCompletionRequest(BaseModel): prompt_token_ids: Optional[List[int]] = None max_streaming_response_tokens: Optional[int] = None disable_chat_template: Optional[bool] = False + + mm_hashes: Optional[list] = None completion_token_ids: Optional[List[int]] = None # doc: end-chat-completion-extra-params @@ -694,6 +704,9 @@ class ChatCompletionRequest(BaseModel): if item is not None: req_dict[key] = item + if self.mm_hashes is not None and len(self.mm_hashes) > 0: + req_dict["mm_hashes"] = self.mm_hashes + return req_dict @model_validator(mode="before") @@ -721,6 +734,9 @@ class ChatCompletionRequest(BaseModel): "('guided_json', 'guided_regex', 'guided_choice', 'guided_grammar', 'structural_tag')." ) + if data.get("mm_hashes", None): + assert isinstance(data["mm_hashes"], list), "`mm_hashes` must be a list." + return data @model_validator(mode="before") diff --git a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py index 439b752c2..1211eccf5 100644 --- a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py +++ b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py @@ -37,6 +37,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor): mm_processor_kwargs=None, reasoning_parser_obj=None, tool_parser_obj=None, + enable_processor_cache=False, ): data_processor_logger.info(f"model_name_or_path: {model_name_or_path}") tokenizer_path = model_name_or_path @@ -46,6 +47,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor): self.ernie4_5_processor = DataProcessor( tokenizer_name=tokenizer_path, image_preprocessor_name=preprocessor_path, + enable_processor_cache=enable_processor_cache, **processor_kwargs, ) self.ernie4_5_processor.eval() diff --git a/fastdeploy/input/ernie4_5_vl_processor/process.py b/fastdeploy/input/ernie4_5_vl_processor/process.py index c3671c943..5a7767df4 100644 --- a/fastdeploy/input/ernie4_5_vl_processor/process.py +++ b/fastdeploy/input/ernie4_5_vl_processor/process.py @@ -18,16 +18,20 @@ """ process.py """ import copy import os +import pickle from collections import defaultdict -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np +import zmq from paddleformers.transformers.image_utils import ChannelDimension from PIL import Image +from fastdeploy.engine.request import ImagePosition from fastdeploy.entrypoints.chat_utils import parse_chat_messages from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer from fastdeploy.input.utils import IDS_TYPE_FLAG +from fastdeploy.multimodal.hasher import MultimodalHasher from fastdeploy.utils import data_processor_logger from .image_preprocessor.image_preprocessor_adaptive import AdaptiveImageProcessor @@ -84,6 +88,7 @@ class DataProcessor: self, tokenizer_name: str, image_preprocessor_name: str, + enable_processor_cache: bool = False, spatial_conv_size: int = 2, temporal_conv_size: int = 2, image_min_pixels: int = 4 * 28 * 28, @@ -102,6 +107,7 @@ class DataProcessor: self._load_tokenizer() self.tokenizer.ignored_index = -100 self.image_preprocessor = AdaptiveImageProcessor.from_pretrained(image_preprocessor_name) + self.enable_processor_cache = enable_processor_cache # Convolution sizes for patch aggregation self.spatial_conv_size = spatial_conv_size @@ -163,10 +169,18 @@ class DataProcessor: """Enable evaluation mode (doesn't produce labels).""" self.is_training = False - def text2ids(self, text, images=None, videos=None): + def 
text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None): """ Convert chat text into model inputs. - Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels. + + Args: + text (str): The chat text containing placeholders for images and videos. + images (list, optional): List of images to be processed and inserted at image placeholders. + videos (list, optional): List of videos to be processed and inserted at video placeholders. + image_uuid (list, optional): List of unique identifiers for each image, used for caching or hashing. + video_uuid (list, optional): List of unique identifiers for each video, used for caching or hashing. + Returns: + dict: A dictionary with keys input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels, etc. """ outputs = { @@ -178,8 +192,9 @@ class DataProcessor: "image_type_ids": [], "labels": [], "cur_position": 0, - "pic_cnt": 0, "video_cnt": 0, + "mm_positions": [], + "mm_hashes": [], } IMAGE_PLACEHOLDER = "<|image@placeholder|>" @@ -199,17 +214,27 @@ class DataProcessor: break if ed == image_pos: - self._add_image(images[image_idx], outputs) + image = images[image_idx] + uuid = image_uuid[image_idx] if image_uuid else None + if not isinstance(image, tuple): + self._add_image(image, outputs, uuid) + else: + # cached images are already processed + self._add_processed_image(image, outputs, uuid) image_idx += 1 st = ed + IMAGE_PLACEHOLDER_LEN else: item = videos[video_idx] - if isinstance(item, dict): - frames = self._load_and_process_video(item["video"], item) + uuid = video_uuid[video_idx] if video_uuid else None + if not isinstance(item, tuple): + if isinstance(item, dict): + frames = self._load_and_process_video(item["video"], item) + else: + frames = self._load_and_process_video(item, {}) + self._add_video(frames, outputs, uuid) else: - frames = self._load_and_process_video(item, {}) - - self._add_video(frames, outputs) + # cached frames are already processed + self._add_processed_video(item, outputs, uuid) video_idx += 1 st = ed + VIDEO_PLACEHOLDER_LEN @@ -223,66 +248,82 @@ class DataProcessor: Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels. 
""" - outputs = { - "input_ids": [], - "token_type_ids": [], - "position_ids": [], - "images": [], - "grid_thw": [], - "image_type_ids": [], - "labels": [], - "cur_position": 0, - "pic_cnt": 0, - "video_cnt": 0, - } - messages = parse_chat_messages(request.get("messages")) - image_message_list = [] + mm_items = [] for msg in messages: role = msg.get("role") assert role in self.role_prefixes, f"Unsupported role: {role}" - content_items = msg.get("content") - if not isinstance(content_items, list): - content_items = [content_items] - for item in content_items: - if isinstance(item, dict) and item.get("type") in [ - "image", - "video", - ]: - image_message_list.append(item) + content = msg.get("content") + if not isinstance(content, list): + content = [content] + for item in content: + if item.get("type") in ["image", "video"]: + mm_items.append(item) + + missing_hashes, missing_idx = [], [] + for idx, item in enumerate(mm_items): + if not item.get("data"): + # raw data not provided, should be retrieved from processor cache + missing_hashes.append(item.get("uuid")) + missing_idx.append(idx) + + if len(missing_hashes) > 0 and not self.enable_processor_cache: + raise ValueError("Missing items cannot be retrieved without processor cache.") + + if self.enable_processor_cache: + context = zmq.Context() + dealer = context.socket(zmq.DEALER) + dealer.connect("ipc:///dev/shm/processor_cache.ipc") + + missing_items = self.get_processor_cache(dealer, missing_hashes) + for idx in range(len(missing_items)): + if not missing_items[idx]: + raise ValueError(f"Missing item {idx} not found in processor cache") + mm_items[missing_idx[idx]]["data"] = missing_items[idx] + + images, videos = [], [] + image_uuid, video_uuid = [], [] + for item in mm_items: + if item.get("type") == "image": + images.append(item["data"]) + image_uuid.append(item["uuid"]) + elif item.get("type") == "video": + videos.append(item["data"]) + video_uuid.append(item["uuid"]) + else: + raise ValueError(f"Unsupported multimodal type: {item.get('type')}") + + if self.tokenizer.chat_template is None: + raise ValueError("This model does not support chat template.") + chat_template_kwargs = request.get("chat_template_kwargs", {}) - prompt_token_ids = self.apply_chat_template(request, **chat_template_kwargs) - if len(prompt_token_ids) == 0: - raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs") - image_start_index = 0 - image_message_index = 0 - for i in range(len(prompt_token_ids)): - if prompt_token_ids[i] in [ - self.image_start_id, - self.video_start_id, - ]: - self._add_text(prompt_token_ids[image_start_index : i + 1], outputs) - image_start_index = i + 1 - image_message = image_message_list[image_message_index] - if image_message["type"] == "image": - img = image_message.get("image") - if img is None: - continue - outputs["pic_cnt"] += 1 - self._add_image(img, outputs) - elif image_message["type"] == "video": - video_bytes = image_message.get("video") - if video_bytes is None: - continue - frames = self._load_and_process_video(video_bytes, image_message) - outputs["video_cnt"] += 1 - self._add_video(frames, outputs) - image_message_index += 1 - self._add_text(prompt_token_ids[image_start_index:], outputs) + prompt = self.tokenizer.apply_chat_template( + request, + tokenize=False, + add_generation_prompt=request.get("add_generation_prompt", True), + **chat_template_kwargs, + ) + request["prompt_tokens"] = prompt + + outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid) + + if 
self.enable_processor_cache: + missing_idx = set(missing_idx) + hashes_to_cache, items_to_cache = [], [] + for idx in range(len(mm_items)): + if idx in missing_idx: + continue + meta = {} + t, h, w = outputs["grid_thw"][idx][0] + meta["thw"] = (t, h, w) + hashes_to_cache.append(outputs["mm_hashes"][idx]) + items_to_cache.append((outputs["images"][idx], meta)) + self.update_processor_cache(dealer, hashes_to_cache, items_to_cache) if self.is_training: - assert tgts, "training must give tgt !" + assert tgts, "Training must give tgt" self._extract_labels(outputs, tgts) + return outputs def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None: @@ -304,7 +345,7 @@ class DataProcessor: outputs["position_ids"].append([start + i] * 3) outputs["cur_position"] += len(tokens) - def _add_image(self, img, outputs: Dict) -> None: + def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None: patches_h, patches_w = self.image_preprocessor.get_smarted_resize( img.height, img.width, @@ -313,6 +354,7 @@ class DataProcessor: )[1] num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2) + outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) outputs["input_ids"].extend([self.image_patch_id] * num_tokens) outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens) @@ -330,10 +372,32 @@ class DataProcessor: input_data_format=ChannelDimension.LAST, ) outputs["images"].append(ret["pixel_values"]) + if not uuid: + outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"])) + else: + outputs["mm_hashes"].append(uuid) outputs["grid_thw"].append(ret["image_grid_thw"]) outputs["image_type_ids"].append(0) - def _add_video(self, frames, outputs: Dict) -> None: + def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None: + img, meta = img_cache + num_tokens = img.shape[0] // (self.spatial_conv_size**2) + + outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) + outputs["input_ids"].extend([self.image_patch_id] * num_tokens) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens) + + _, h, w = meta["thw"] + pos_ids = self._compute_3d_positions(1, h, w, outputs["cur_position"]) + outputs["position_ids"].extend(pos_ids) + outputs["cur_position"] = np.max(pos_ids) + 1 + + outputs["images"].append(img) + outputs["mm_hashes"].append(uuid) + outputs["grid_thw"].append(np.array([[1, h, w]])) + outputs["image_type_ids"].append(0) + + def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None: patches_h, patches_w = self.image_preprocessor.get_smarted_resize( frames[0].height, frames[0].width, @@ -354,9 +418,14 @@ class DataProcessor: input_data_format=ChannelDimension.LAST, ) outputs["images"].append(ret["pixel_values_videos"]) + if not uuid: + outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values_videos"])) + else: + outputs["mm_hashes"].append(uuid) outputs["grid_thw"].append(ret["video_grid_thw"]) outputs["image_type_ids"].extend([1] * num_frames) + outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) outputs["input_ids"].extend([self.image_patch_id] * num_tokens) outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens) @@ -364,6 +433,24 @@ class DataProcessor: outputs["position_ids"].extend(pos_ids) outputs["cur_position"] = np.max(pos_ids) + 1 + def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None: + 
frames, meta = frames_cache + num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size) + + t, h, w = meta["thw"] + outputs["images"].append(frames) + outputs["mm_hashes"].append(uuid) + outputs["grid_thw"].append(np.array([[t, h, w]])) + + outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) + outputs["input_ids"].extend([self.image_patch_id] * num_tokens) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens) + outputs["image_type_ids"].extend([1] * t) + + pos_ids = self._compute_3d_positions(t, h, w, outputs["cur_position"]) + outputs["position_ids"].extend(pos_ids) + outputs["cur_position"] = np.max(pos_ids) + 1 + def _extract_labels(self, outputs: Dict, tgts: List[str]) -> None: input_ids = copy.deepcopy(outputs["input_ids"]) labels = [self.tokenizer.ignored_index] * len(input_ids) @@ -480,33 +567,22 @@ class DataProcessor: break self.tokenizer = Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path) - def apply_chat_template(self, request, **kwargs): + def get_processor_cache(self, socket, mm_hashes: list[str]) -> list: """ - Convert multi-turn messages into ID sequences. - - Args: - messages: Either a request dict containing 'messages' field, - or a list of message dicts directly - - Returns: - List of token IDs as strings (converted from token objects) + get cache correspond to given hash values """ - if self.tokenizer.chat_template is None: - raise ValueError("This model does not support chat_template.") + req = pickle.dumps(mm_hashes) + socket.send_multipart([b"", req]) + _, resp = socket.recv_multipart() + mm_items = pickle.loads(resp) + data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}") - prompt_token_template = self.tokenizer.apply_chat_template( - request, - tokenize=False, - add_generation_prompt=request.get("add_generation_prompt", True), - **kwargs, - ) - prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace( - "<|video@placeholder|>", "" - ) - request["prompt_tokens"] = prompt_token_template - tokens = self.tokenizer.tokenize(prompt_token_str) - token_ids = self.tokenizer.convert_tokens_to_ids(tokens) - data_processor_logger.info( - f"req_id:{request.get('request_id', ''), } tokens: {tokens}, token_ids: {token_ids}" - ) - return token_ids + return mm_items + + def update_processor_cache(self, socket, mm_hashes: list[str], mm_items): + """ + update cache data + """ + req = pickle.dumps((mm_hashes, mm_items)) + socket.send_multipart([b"", req]) + data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}") diff --git a/fastdeploy/input/preprocess.py b/fastdeploy/input/preprocess.py index 5ee64af41..1abb804ac 100644 --- a/fastdeploy/input/preprocess.py +++ b/fastdeploy/input/preprocess.py @@ -46,6 +46,7 @@ class InputPreprocessor: limit_mm_per_prompt: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, tool_parser: str = None, + enable_processor_cache: bool = False, ) -> None: self.model_config = model_config self.model_name_or_path = self.model_config.model @@ -53,6 +54,7 @@ class InputPreprocessor: self.limit_mm_per_prompt = limit_mm_per_prompt self.mm_processor_kwargs = mm_processor_kwargs self.tool_parser = tool_parser + self.enable_processor_cache = enable_processor_cache def create_processor(self): reasoning_parser_obj = None @@ -104,6 +106,19 @@ class InputPreprocessor: mm_processor_kwargs=self.mm_processor_kwargs, reasoning_parser_obj=reasoning_parser_obj, tool_parser_obj=tool_parser_obj, + 
enable_processor_cache=self.enable_processor_cache, + ) + elif "PaddleOCRVL" in architecture: + from fastdeploy.input.paddleocr_vl_processor import ( + PaddleOCRVLProcessor, + ) + + self.processor = PaddleOCRVLProcessor( + config=self.model_config, + model_name_or_path=self.model_name_or_path, + limit_mm_per_prompt=self.limit_mm_per_prompt, + mm_processor_kwargs=self.mm_processor_kwargs, + reasoning_parser_obj=reasoning_parser_obj, ) elif "PaddleOCRVL" in architecture: from fastdeploy.input.paddleocr_vl_processor import ( @@ -126,5 +141,6 @@ class InputPreprocessor: limit_mm_per_prompt=self.limit_mm_per_prompt, mm_processor_kwargs=self.mm_processor_kwargs, reasoning_parser_obj=reasoning_parser_obj, + enable_processor_cache=self.enable_processor_cache, ) return self.processor diff --git a/fastdeploy/input/qwen_vl_processor/process.py b/fastdeploy/input/qwen_vl_processor/process.py index 53a9381ae..88a0d7b50 100644 --- a/fastdeploy/input/qwen_vl_processor/process.py +++ b/fastdeploy/input/qwen_vl_processor/process.py @@ -15,17 +15,23 @@ # limitations under the License. """ -from typing import Any, Dict, List, Tuple, Union +import pickle +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np +import zmq from paddleformers.transformers import AutoTokenizer +from PIL import Image +from fastdeploy.engine.request import ImagePosition from fastdeploy.entrypoints.chat_utils import parse_chat_messages +from fastdeploy.input.ernie4_5_vl_processor import read_video_decord from fastdeploy.input.utils import IDS_TYPE_FLAG +from fastdeploy.multimodal.hasher import MultimodalHasher from fastdeploy.utils import data_processor_logger from .image_processor import ImageProcessor -from .process_video import read_frames, sample_frames +from .process_video import sample_frames class DataProcessor: @@ -49,8 +55,11 @@ class DataProcessor: def __init__( self, model_path: str, + enable_processor_cache: bool = False, video_min_frames: int = 4, video_max_frames: int = 768, + video_target_frames: int = -1, + video_fps: int = -1, tokens_per_second: int = 2, tokenizer=None, **kwargs, @@ -67,6 +76,8 @@ class DataProcessor: """ self.min_frames = video_min_frames self.max_frames = video_max_frames + self.target_frames = video_target_frames + self.fps = video_fps # Initialize tokenizer with left padding and fast tokenizer if tokenizer is None: @@ -75,6 +86,7 @@ class DataProcessor: else: self.tokenizer = tokenizer self.image_processor = ImageProcessor.from_pretrained(model_path) # Initialize image processor + self.enable_processor_cache = enable_processor_cache # Convolution sizes for patch aggregation self.spatial_conv_size = self.image_processor.merge_size @@ -99,41 +111,7 @@ class DataProcessor: "assistant": "Assistant: ", } - def _pack_outputs(self, outputs): - """ - Pack and convert all output data into numpy arrays with appropriate types. 
- - Args: - outputs (dict): Dictionary containing model outputs with keys: - - images: List of visual features - - grid_thw: List of spatial dimensions - - image_type_ids: List of content type indicators - - input_ids: List of token IDs - - token_type_ids: List of type identifiers - - position_ids: List of position embeddings - - Returns: - dict: Processed outputs with all values converted to numpy arrays - """ - # Process visual outputs - stack if exists or set to None if empty - if not outputs["images"]: - outputs["images"] = None # No images case - outputs["grid_thw"] = None # No spatial dimensions - outputs["image_type_ids"] = None # No type IDs - else: - outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically - outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions - outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array - - # Convert all outputs to numpy arrays with appropriate types - outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64 - outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64 - outputs["position_ids"] = np.concatenate( - outputs["position_ids"], axis=1, dtype=np.int64 - ) # Concatenate position IDs - return outputs - - def text2ids(self, text, images=None, videos=None): + def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None): """ Convert text with image/video placeholders into model inputs. @@ -141,6 +119,8 @@ class DataProcessor: text: Input text with <|image@placeholder|> and <|video@placeholder|> markers images: List of PIL Images corresponding to image placeholders videos: List of video data corresponding to video placeholders + image_uuid: List of unique identifiers for each image, used for caching or hashing. + video_uuid: List of unique identifiers for each video, used for caching or hashing. 
Returns: Dict containing: @@ -161,8 +141,10 @@ class DataProcessor: "image_type_ids": [], "labels": [], "cur_position": 0, - "pic_cnt": 0, "video_cnt": 0, + "fps": [], + "mm_positions": [], + "mm_hashes": [], } # Define placeholders and their lengths @@ -186,23 +168,30 @@ class DataProcessor: break if ed == image_pos: - outputs["pic_cnt"] += 1 - self._add_image(images[image_idx], outputs) + image = images[image_idx] + uuid = image_uuid[image_idx] if image_uuid else None + if not isinstance(image, tuple): + self._add_image(image, outputs, uuid) + else: + self._add_processed_image(image, outputs, uuid) image_idx += 1 st = ed + IMAGE_PLACEHOLDER_LEN else: item = videos[video_idx] - if isinstance(item, dict): - frames, meta = self._load_and_process_video(item["video"], item) + uuid = video_uuid[video_idx] if video_uuid else None + if not isinstance(item, tuple): + if isinstance(item, dict): + frames, meta = self._load_and_process_video(item["video"], item) + else: + frames, meta = self._load_and_process_video(item, {}) + self._add_video(frames, meta, outputs, uuid) else: - frames, meta = self._load_and_process_video(item, {}) - - outputs["video_cnt"] += 1 - self._add_video(frames, meta, outputs) + # cached frames are already processed + self._add_processed_video(item, outputs, uuid) video_idx += 1 st = ed + VIDEO_PLACEHOLDER_LEN - return self._pack_outputs(outputs) + return outputs def request2ids( self, request: Dict[str, Any], tgts: List[str] = None @@ -220,74 +209,84 @@ class DataProcessor: Dict with same structure as text2ids() output """ - outputs = { - "input_ids": [], - "token_type_ids": [], - "position_ids": [], - "images": [], - "grid_thw": [], - "image_type_ids": [], - "labels": [], - "cur_position": 0, - "pic_cnt": 0, - "video_cnt": 0, - } - # Parse and validate chat messages messages = parse_chat_messages(request.get("messages")) - image_message_list = [] # Store visual content messages - + mm_items = [] for msg in messages: role = msg.get("role") assert role in self.role_prefixes, f"Unsupported role: {role}" # Normalize content to list format - content_items = msg.get("content") - if not isinstance(content_items, list): - content_items = [content_items] - + content = msg.get("content") + if not isinstance(content, list): + content = [content] # Collect all visual content items - for item in content_items: - if isinstance(item, dict) and item.get("type") in ["image", "video"]: - image_message_list.append(item) + for item in content: + if item.get("type") in ["image", "video"]: + mm_items.append(item) - raw_messages = request["messages"] - request["messages"] = messages + missing_hashes, missing_idx = [], [] + for idx, item in enumerate(mm_items): + if not item.get("data"): + # raw data not provided, should be retrieved from processor cache + missing_hashes.append(item.get("uuid")) + missing_idx.append(idx) - prompt_token_ids = self.apply_chat_template(request) - if len(prompt_token_ids) == 0: - raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs") - request["messages"] = raw_messages + if len(missing_hashes) > 0 and not self.enable_processor_cache: + raise ValueError("Missing items cannot be retrieved without processor cache.") - vision_start_index = 0 - vision_message_index = 0 - for i in range(len(prompt_token_ids)): - if prompt_token_ids[i] == self.vision_start_id: - self._add_text(prompt_token_ids[vision_start_index : i + 1], outputs) + if self.enable_processor_cache: + context = zmq.Context() + dealer = context.socket(zmq.DEALER) + 
dealer.connect("ipc:///dev/shm/processor_cache.ipc") - vision_start_index = i + 1 - image_message = image_message_list[vision_message_index] + missing_items = self.get_processor_cache(dealer, missing_hashes) + for idx in range(len(missing_items)): + if not missing_items[idx]: + raise ValueError(f"Missing item {idx} not found in processor cache") + mm_items[missing_idx[idx]]["data"] = missing_items[idx] - if image_message["type"] == "image": - img = image_message.get("image") - if img is None: - continue - outputs["pic_cnt"] += 1 - self._add_image(img, outputs) + images, videos = [], [] + image_uuid, video_uuid = [], [] + for item in mm_items: + if item.get("type") == "image": + images.append(item["data"]) + image_uuid.append(item["uuid"]) + elif item.get("type") == "video": + videos.append(item["data"]) + video_uuid.append(item["uuid"]) + else: + raise ValueError(f"Unsupported multimodal type: {item.get('type')}") - elif image_message["type"] == "video": - video_bytes = image_message.get("video") - if video_bytes is None: - continue - frames, meta = self._load_and_process_video(video_bytes, image_message) + if self.tokenizer.chat_template is None: + raise ValueError("This model does not support chat template.") - outputs["video_cnt"] += 1 - self._add_video(frames, meta, outputs) + chat_template_kwargs = request.get("chat_template_kwargs", {}) + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=request.get("add_generation_prompt", True), + **chat_template_kwargs, + ) + request["prompt_tokens"] = prompt - vision_message_index += 1 + outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid) - self._add_text(prompt_token_ids[vision_start_index:], outputs) - return self._pack_outputs(outputs) + if self.enable_processor_cache: + missing_idx = set(missing_idx) + hashes_to_cache, items_to_cache = [], [] + for idx in range(len(mm_items)): + if idx in missing_idx: + continue + meta = {} + t, h, w = outputs["grid_thw"][idx] + meta["thw"] = (t, h, w) + meta["fps"] = outputs["fps"][idx] + hashes_to_cache.append(outputs["mm_hashes"][idx]) + items_to_cache.append((outputs["images"][idx], meta)) + self.update_processor_cache(dealer, hashes_to_cache, items_to_cache) + + return outputs def _add_text(self, tokens, outputs: Dict) -> None: """ @@ -312,9 +311,9 @@ class DataProcessor: outputs["input_ids"].extend(tokens) outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens) - position_ids = self._compute_text_positions(outputs["cur_position"], num_tokens) - outputs["position_ids"].append(position_ids) - outputs["cur_position"] = position_ids.max() + 1 + pos_ids = self._compute_text_positions(outputs["cur_position"], num_tokens) + outputs["position_ids"].append(pos_ids) + outputs["cur_position"] = pos_ids.max() + 1 def _compute_text_positions(self, start_pos: int, num_tokens: int) -> np.ndarray: """ @@ -332,7 +331,7 @@ class DataProcessor: position = text_index + start_pos return position - def _add_image(self, img, outputs: Dict) -> None: + def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None: """ Add image data to model inputs dictionary. 
@@ -349,20 +348,47 @@ class DataProcessor: num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2 grid_thw = ret["grid_thw"].tolist() + outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) outputs["input_ids"].extend([self.image_token_id] * num_tokens) outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens) outputs["images"].append(ret["pixel_values"]) + if not uuid: + outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"])) + else: + outputs["mm_hashes"].append(uuid) outputs["grid_thw"].append(grid_thw) outputs["image_type_ids"].append(0) t, h, w = grid_thw - position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0) + pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0) - outputs["position_ids"].append(position_ids) - outputs["cur_position"] = position_ids.max() + 1 + outputs["position_ids"].append(pos_ids) + outputs["cur_position"] = pos_ids.max() + 1 - def _add_video(self, frames, meta: Dict, outputs: Dict) -> None: + outputs["fps"].append(0) + + def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None: + img, meta = img_cache + num_tokens = img.shape[0] // self.image_processor.merge_size**2 + + outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) + outputs["input_ids"].extend([self.image_patch_id] * num_tokens) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens) + + _, h, w = meta["thw"] + pos_ids = self._compute_vision_positions(outputs["cur_position"], 1, h, w, 0) + outputs["position_ids"].append(pos_ids) + outputs["cur_position"] = pos_ids.max() + 1 + + outputs["images"].append(img) + outputs["mm_hashes"].append(uuid) + outputs["grid_thw"].append(np.array([[1, h, w]])) + outputs["image_type_ids"].append(0) + + outputs["fps"].append(0) + + def _add_video(self, frames, meta: Dict, outputs: Dict, uuid: Optional[str]) -> None: """ Add video data to model inputs dictionary. 
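When no client-supplied uuid is available, the cache key for an image or video is derived from the pixel data itself via MultimodalHasher.hash_features (added in fastdeploy/multimodal/hasher.py later in this diff). A small illustration of that content-addressing property, using only hashlib and numpy; the array shape is an arbitrary stand-in for ret["pixel_values"]:

import hashlib

import numpy as np

# Identical pixel arrays produce the same key, so a repeated image maps to the
# same mm_hash and can reuse the cached processor/encoder entry.
pixels = np.zeros((4, 1176), dtype=np.uint8)  # stand-in for ret["pixel_values"]
key_a = hashlib.sha256(pixels.tobytes()).hexdigest()
key_b = hashlib.sha256(pixels.copy().tobytes()).hexdigest()
assert key_a == key_b  # same content, same cache key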
@@ -380,20 +406,49 @@ class DataProcessor: num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2 grid_thw = ret["grid_thw"].tolist() + outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) outputs["input_ids"].extend([self.video_token_id] * num_tokens) outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens) outputs["images"].append(ret["pixel_values"]) + if not uuid: + outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"])) + else: + outputs["mm_hashes"].append(uuid) outputs["grid_thw"].append(grid_thw) outputs["image_type_ids"].extend([1] * grid_thw[0]) fps = meta["fps"] second_per_grid_t = self.temporal_conv_size / fps t, h, w = grid_thw - position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t) + pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t) - outputs["position_ids"].append(position_ids) - outputs["cur_position"] = position_ids.max() + 1 + outputs["position_ids"].append(pos_ids) + outputs["cur_position"] = pos_ids.max() + 1 + + outputs["fps"].append(fps) + + def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None: + frames, meta = frames_cache + num_tokens = frames.shape[0] // self.image_processor.merge_size**2 + + t, h, w = meta["thw"] + outputs["images"].append(frames) + outputs["mm_hashes"].append(uuid) + outputs["grid_thw"].append(np.array([[t, h, w]])) + + outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) + outputs["input_ids"].extend([self.image_patch_id] * num_tokens) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens) + outputs["image_type_ids"].extend([1] * t) + + fps = meta["fps"] + second_per_grid_t = self.temporal_conv_size / fps + pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t) + outputs["position_ids"].append(pos_ids) + outputs["cur_position"] = pos_ids.max() + 1 + + outputs["fps"].append(fps) def _compute_vision_positions( self, start_pos: int, t: int, h: int, w: int, second_per_grid_t: float @@ -441,20 +496,20 @@ class DataProcessor: - frames: Processed video frames as numpy array - metadata: Updated video metadata dictionary """ - frames, meta = read_frames(url) + reader, meta, _ = read_video_decord(url, save_to_disk=False) # Apply frame sampling if fps or target_frames specified - fps = item.get("fps", None) - num_frames = item.get("target_frames", None) + fps = item.get("fps", self.fps) + num_frames = item.get("target_frames", self.target_frames) - if fps is not None or num_frames is not None: + frame_indices = list(range(meta["num_of_frame"])) + if fps > 0 or num_frames > 0: # Get frame sampling constraints min_frames = item.get("min_frames", self.min_frames) max_frames = item.get("max_frames", self.max_frames) # Sample frames according to specifications - frames = sample_frames( - video=frames, + frame_indices = sample_frames( frame_factor=self.temporal_conv_size, # Ensure divisible by temporal patch size min_frames=min_frames, max_frames=max_frames, @@ -464,42 +519,38 @@ class DataProcessor: ) # Update metadata with new frame count and fps - meta["num_of_frame"] = frames.shape[0] + meta["num_of_frame"] = len(frame_indices) if fps is not None: meta["fps"] = fps # Use specified fps - meta["duration"] = frames.shape[0] / fps + meta["duration"] = len(frame_indices) / fps else: - meta["fps"] = frames.shape[0] / meta["duration"] # Calculate fps 
from sampled frames + meta["fps"] = len(frame_indices) / meta["duration"] # Calculate fps from sampled frames + + frames = [] + for idx in frame_indices: + frame = reader[idx].asnumpy() + image = Image.fromarray(frame, "RGB") + frames.append(image) + frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0) return frames, meta - def apply_chat_template(self, request): + def get_processor_cache(self, socket, mm_hashes: list[str]) -> list: """ - Apply chat template to convert messages into token sequence. - - Args: - request: Dictionary containing chat messages - - Returns: - List of token IDs - - Raises: - ValueError: If model doesn't support chat templates + get cache correspond to given hash values """ - if self.tokenizer.chat_template is None: - raise ValueError("This model does not support chat_template.") + req = pickle.dumps(mm_hashes) + socket.send_multipart([b"", req]) + _, resp = socket.recv_multipart() + mm_items = pickle.loads(resp) + data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}") - raw_prompt = self.tokenizer.apply_chat_template( - request["messages"], - tokenize=False, - add_generation_prompt=request.get("add_generation_prompt", True), - ) - prompt_token_str = raw_prompt.replace(self.image_token, "").replace(self.video_token, "") - request["prompt_tokens"] = raw_prompt + return mm_items - tokens = self.tokenizer.tokenize(prompt_token_str) - token_ids = self.tokenizer.convert_tokens_to_ids(tokens) - data_processor_logger.info( - f"req_id:{request.get('request_id', ''), } prompt: {raw_prompt} tokens: {tokens}, token_ids: {token_ids}" - ) - return token_ids + def update_processor_cache(self, socket, mm_hashes: list[str], mm_items): + """ + update cache data + """ + req = pickle.dumps((mm_hashes, mm_items)) + socket.send_multipart([b"", req]) + data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}") diff --git a/fastdeploy/input/qwen_vl_processor/process_video.py b/fastdeploy/input/qwen_vl_processor/process_video.py index e6a39a23a..c7089d26d 100644 --- a/fastdeploy/input/qwen_vl_processor/process_video.py +++ b/fastdeploy/input/qwen_vl_processor/process_video.py @@ -18,50 +18,9 @@ import math from typing import Optional, Union import numpy as np -from PIL import Image - -from fastdeploy.input.ernie4_5_vl_processor import read_video_decord - - -def read_frames(video_path): - """ - Read and decode video frames from the given path - - This function reads a video file and decodes it into individual RGB frames - using decord video reader. It also extracts video metadata including fps, - duration and frame count. 
- - Args: - video_path (str): Path to the video file or bytes object containing video data - - Returns: - tuple: A tuple containing: - frames (numpy.ndarray): Array of shape (num_frames, height, width, 3) - containing decoded RGB video frames - meta (dict): Dictionary containing video metadata: - - fps (float): Frames per second - - duration (float): Video duration in seconds - - num_of_frame (int): Total number of frames - - width (int): Frame width in pixels - - height (int): Frame height in pixels - - Note: - - The function uses decord library for efficient video reading - - All frames are converted to RGB format regardless of input format - """ - reader, meta, _ = read_video_decord(video_path, save_to_disk=False) - - frames = [] - for i in range(meta["num_of_frame"]): - frame = reader[i].asnumpy() - image = Image.fromarray(frame, "RGB") - frames.append(image) - frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0) - return frames, meta def sample_frames( - video: np.ndarray, frame_factor: int, min_frames: int, max_frames: int, @@ -73,7 +32,6 @@ def sample_frames( Sample frames from video according to specified criteria. Args: - video: Input video frames as numpy array frame_factor: Ensure sampled frames are multiples of this factor min_frames: Minimum number of frames to sample max_frames: Maximum number of frames to sample @@ -89,18 +47,15 @@ def sample_frames( or if required metadata is missing, or if requested frames exceed available frames """ - if fps is not None and num_frames is not None: + if fps > 0 and num_frames > 0: raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!") - if fps is None and num_frames is None: - return video - - total_num_frames = video.shape[0] + total_num_frames = metadata["num_of_frame"] # If num_frames is not given but fps is, calculate num_frames from fps - if num_frames is not None: + if num_frames > 0: num_frames = round(num_frames / frame_factor) * frame_factor - elif fps is not None: + elif fps > 0: if metadata is None: raise ValueError( "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. " @@ -110,7 +65,6 @@ def sample_frames( num_frames = total_num_frames / metadata["fps"] * fps num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames) num_frames = math.floor(num_frames / frame_factor) * frame_factor - if num_frames > total_num_frames: raise ValueError( f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. 
" @@ -118,14 +72,11 @@ def sample_frames( ) # Calculate frame indices based on sampling strategy - if num_frames is not None: + if num_frames > 0: # Evenly spaced sampling for target frame count indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(np.int32) else: # Keep all frames if no sampling requested indices = np.arange(0, total_num_frames).astype(np.int32) - # Apply frame selection - video = video[indices] - - return video + return indices diff --git a/fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py b/fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py index 115ff9569..cc649c9ba 100644 --- a/fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py +++ b/fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py @@ -47,6 +47,7 @@ class QwenVLProcessor(TextProcessor): mm_processor_kwargs=None, reasoning_parser_obj=None, tool_parser_obj=None, + enable_processor_cache=False, ): """ Initialize QwenVLProcessor instance. @@ -65,6 +66,7 @@ class QwenVLProcessor(TextProcessor): processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs) self.processor = DataProcessor( model_path=model_name_or_path, + enable_processor_cache=enable_processor_cache, tokens_per_second=config.vision_config.tokens_per_second, tokenizer=self.tokenizer, **processor_kwargs, @@ -271,7 +273,7 @@ class QwenVLProcessor(TextProcessor): return request - def append_completion_tokens(self, outputs, completion_token_ids): + def append_completion_tokens(self, multimodal_inputs, completion_token_ids): """ Append completion tokens to existing outputs. @@ -279,19 +281,14 @@ class QwenVLProcessor(TextProcessor): outputs: Current model outputs completion_token_ids: completion tokens to append """ - out = {"input_ids": [], "token_type_ids": [], "position_ids": [], "cur_position": outputs["cur_position"]} - self.processor._add_text(completion_token_ids, out) - outputs["input_ids"] = np.concatenate( - [outputs["input_ids"], np.array(out["input_ids"], dtype=np.int64)], axis=0 - ) - outputs["token_type_ids"] = np.concatenate( - [outputs["token_type_ids"], np.array(out["token_type_ids"], dtype=np.int64)], axis=0 - ) - outputs["position_ids"] = np.concatenate( - [outputs["position_ids"], out["position_ids"][0]], axis=1, dtype=np.int64 - ) - outputs["cur_position"] = out["cur_position"] + num_tokens = len(completion_token_ids) + multimodal_inputs["input_ids"].extend(completion_token_ids) + multimodal_inputs["token_type_ids"].extend([0] * num_tokens) + + pos_ids = self.processor._compute_text_positions(multimodal_inputs["cur_position"], num_tokens) + multimodal_inputs["position_ids"].append(pos_ids) + multimodal_inputs["cur_position"] += num_tokens def pack_outputs(self, outputs): """ @@ -303,7 +300,24 @@ class QwenVLProcessor(TextProcessor): Returns: dict: Packed output dictionary with all required fields """ + if not outputs["images"]: + outputs["images"] = None # No images case + outputs["grid_thw"] = None # No spatial dimensions + outputs["image_type_ids"] = None # No type IDs + else: + outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically + outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions + outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array + + # Convert all outputs to numpy arrays with appropriate types + outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64 + outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64 + 
outputs["position_ids"] = np.concatenate( + outputs["position_ids"], axis=1, dtype=np.int64 + ) # Concatenate position ID + outputs["image_patch_id"] = self.processor.image_token_id outputs["video_patch_id"] = self.processor.video_token_id outputs["position_ids"] = outputs["position_ids"].transpose(1, 0) + return outputs diff --git a/fastdeploy/multimodal/audio.py b/fastdeploy/multimodal/audio.py index 97c73b26e..0d65bada4 100644 --- a/fastdeploy/multimodal/audio.py +++ b/fastdeploy/multimodal/audio.py @@ -1,7 +1,7 @@ """ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License" +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # diff --git a/fastdeploy/multimodal/base.py b/fastdeploy/multimodal/base.py index 962b186d2..b58b18362 100644 --- a/fastdeploy/multimodal/base.py +++ b/fastdeploy/multimodal/base.py @@ -1,7 +1,7 @@ """ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License" +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # diff --git a/fastdeploy/multimodal/hasher.py b/fastdeploy/multimodal/hasher.py new file mode 100644 index 000000000..1f2d01f8c --- /dev/null +++ b/fastdeploy/multimodal/hasher.py @@ -0,0 +1,35 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import hashlib +import pickle + +import numpy as np + +from fastdeploy.utils import data_processor_logger + + +class MultimodalHasher: + + @classmethod + def hash_features(cls, obj: object) -> str: + if isinstance(obj, np.ndarray): + return hashlib.sha256((obj.tobytes())).hexdigest() + + data_processor_logger.warning( + f"Unsupported type for hashing features: {type(obj)}" + ", use pickle for serialization" + ) + return hashlib.sha256((pickle.dumps(obj))).hexdigest() diff --git a/fastdeploy/multimodal/image.py b/fastdeploy/multimodal/image.py index 908e55489..cfbc40de0 100644 --- a/fastdeploy/multimodal/image.py +++ b/fastdeploy/multimodal/image.py @@ -1,7 +1,7 @@ """ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License" +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # diff --git a/fastdeploy/multimodal/registry.py b/fastdeploy/multimodal/registry.py new file mode 100644 index 000000000..cc84449ad --- /dev/null +++ b/fastdeploy/multimodal/registry.py @@ -0,0 +1,36 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + + +class MultimodalRegistry: + """ + A registry for multimodal models + """ + + mm_models: set[str] = { + "Ernie4_5_VLMoeForConditionalGeneration", + "Ernie5MoeForCausalLM", + "Qwen2_5_VLForConditionalGeneration", + "Ernie5ForCausalLM", + "Ernie4_5_VLMoeForProcessRewardModel", + } + + @classmethod + def contains_model(cls, name: str) -> bool: + """ + Check if the given name exists in registry. + """ + return name in cls.mm_models diff --git a/fastdeploy/multimodal/utils.py b/fastdeploy/multimodal/utils.py index ea45bd710..fa67be2a3 100644 --- a/fastdeploy/multimodal/utils.py +++ b/fastdeploy/multimodal/utils.py @@ -1,7 +1,7 @@ """ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License" +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # diff --git a/fastdeploy/multimodal/video.py b/fastdeploy/multimodal/video.py index b1aacc2a1..583bfc5ea 100644 --- a/fastdeploy/multimodal/video.py +++ b/fastdeploy/multimodal/video.py @@ -1,7 +1,7 @@ """ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License" +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # diff --git a/fastdeploy/scheduler/global_scheduler.py b/fastdeploy/scheduler/global_scheduler.py index acdd99767..1debc0a11 100644 --- a/fastdeploy/scheduler/global_scheduler.py +++ b/fastdeploy/scheduler/global_scheduler.py @@ -24,13 +24,12 @@ from typing import Dict, List, Optional, Tuple import crcmod from redis import ConnectionPool -from fastdeploy import envs from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.scheduler import utils from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse from fastdeploy.scheduler.storage import AdaptedRedis from fastdeploy.scheduler.workers import Task, Workers -from fastdeploy.utils import scheduler_logger +from fastdeploy.utils import envs, scheduler_logger class GlobalScheduler: @@ -534,32 +533,33 @@ class GlobalScheduler: continue request: ScheduledRequest = ScheduledRequest.unserialize(serialized_request) - required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size) + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: + required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size) - current_prefill_tokens += request.prompt_tokens_ids_len - required_total_blocks += required_input_blocks + reserved_output_blocks + current_prefill_tokens += request.prompt_tokens_ids_len + required_total_blocks += required_input_blocks + reserved_output_blocks - if required_total_blocks > available_blocks: - remaining_request.append((request_queue_name, serialized_request)) - continue + if required_total_blocks > available_blocks: + remaining_request.append((request_queue_name, serialized_request)) + continue - if not envs.FD_ENABLE_MAX_PREFILL: - if self.enable_chunked_prefill: - if request.prompt_tokens_ids_len > self.long_prefill_token_threshold: - long_partial_requests += 1 - if long_partial_requests > self.max_long_partial_prefills: + if not envs.FD_ENABLE_MAX_PREFILL: + if self.enable_chunked_prefill: + if request.prompt_tokens_ids_len > self.long_prefill_token_threshold: + long_partial_requests += 1 + if long_partial_requests > self.max_long_partial_prefills: + remaining_request.append((request_queue_name, serialized_request)) + continue + else: + short_partial_requests += 1 + + if short_partial_requests + long_partial_requests > self.max_num_partial_prefills: remaining_request.append((request_queue_name, serialized_request)) continue else: - short_partial_requests += 1 - - if short_partial_requests + long_partial_requests > self.max_num_partial_prefills: - remaining_request.append((request_queue_name, serialized_request)) - continue - else: - if current_prefill_tokens > max_num_batched_tokens: - remaining_request.append((request_queue_name, serialized_request)) - continue + if current_prefill_tokens > max_num_batched_tokens: + remaining_request.append((request_queue_name, serialized_request)) + continue scheduled_requests.append(request) diff --git a/fastdeploy/scheduler/local_scheduler.py b/fastdeploy/scheduler/local_scheduler.py index b77a0dabd..26989f3dc 100644 --- a/fastdeploy/scheduler/local_scheduler.py +++ b/fastdeploy/scheduler/local_scheduler.py @@ -18,10 +18,9 @@ import threading import time from typing import Dict, List, Optional, Tuple -from fastdeploy import envs from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse -from fastdeploy.utils import scheduler_logger +from fastdeploy.utils import envs, scheduler_logger class 
LocalScheduler: @@ -247,35 +246,40 @@ class LocalScheduler: self.wait_request_timeout, ) - required_total_blocks = 0 - current_prefill_tokens = 0 requests: List[Request] = [] - long_partial_requests, short_partial_requests = 0, 0 - for request_id in batch_ids: - request = self.requests[request_id] - required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size) - current_prefill_tokens += request.prompt_tokens_ids_len - required_total_blocks += required_input_blocks + reserved_output_blocks - if required_total_blocks > available_blocks: - break + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: + required_total_blocks = 0 + current_prefill_tokens = 0 + long_partial_requests, short_partial_requests = 0, 0 + for request_id in batch_ids: + request = self.requests[request_id] + required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size) + current_prefill_tokens += request.prompt_tokens_ids_len + required_total_blocks += required_input_blocks + reserved_output_blocks + if required_total_blocks > available_blocks: + break - if not envs.FD_ENABLE_MAX_PREFILL: - if self.enable_chunked_prefill: - if request.prompt_tokens_ids_len > self.long_prefill_token_threshold: - # 长请求 - long_partial_requests += 1 - if long_partial_requests > self.max_long_partial_prefills: + if not envs.FD_ENABLE_MAX_PREFILL: + if self.enable_chunked_prefill: + if request.prompt_tokens_ids_len > self.long_prefill_token_threshold: + # 长请求 + long_partial_requests += 1 + if long_partial_requests > self.max_long_partial_prefills: + break + else: + short_partial_requests += 1 + + if short_partial_requests + long_partial_requests > self.max_num_partial_prefills: break else: - short_partial_requests += 1 + if current_prefill_tokens > max_num_batched_tokens: + break + requests.append(request.raw) + else: + for request_id in batch_ids: + request = self.requests[request_id] + requests.append(request.raw) - if short_partial_requests + long_partial_requests > self.max_num_partial_prefills: - break - else: - if current_prefill_tokens > max_num_batched_tokens: - break - - requests.append(request.raw) self.ids_read_cursor += len(requests) if len(batch_ids) > 0 and len(requests) == 0: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index a20e57b2a..db56abf57 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -137,6 +137,11 @@ class GPUModelRunner(ModelRunnerBase): "fused_gemm_epilogue", ] + if self.cache_config.max_encoder_cache > 0: + self.encoder_cache: dict[str, paddle.Tensor] = {} + else: + self.encoder_cache = None + # Sampler if not self.speculative_decoding: self.sampler = Sampler(fd_config) @@ -297,6 +302,185 @@ class GPUModelRunner(ModelRunnerBase): schemata_key, ) + def get_chunked_inputs(self, req: Request): + """ + Get inputs in current chunk + """ + prefill_start_index = req.prefill_start_index + prefill_end_index = req.prefill_end_index + inputs = req.multimodal_inputs + input_ids = inputs["input_ids"][prefill_start_index:prefill_end_index] + token_type_ids = inputs["token_type_ids"][prefill_start_index:prefill_end_index] + image_type_ids = inputs["image_type_ids"][req.image_type_ids_start : req.image_type_ids_end] + images = inputs["images"][req.image_start : req.image_end] + grid_thw = inputs["grid_thw"][req.num_image_start : req.num_image_end] + mm_hashes = inputs["mm_hashes"][req.num_image_start : req.num_image_end] + + return ( + input_ids, + token_type_ids, + image_type_ids, + 
images, + grid_thw, + mm_hashes, + ) + + def batch_uncached_inputs(self, req: Request): + """ + Batch uncached multimodal inputs + """ + (input_ids, token_type_ids, image_type_ids, images, grid_thw, mm_hashes) = self.get_chunked_inputs(req) + + image_type_ids_size = grid_thw[:, 0] + image_type_ids_split = np.cumsum(image_type_ids_size)[:-1] + image_type_ids_lst = np.array_split(image_type_ids, image_type_ids_split, axis=0) + + images_size = np.prod(grid_thw, axis=1) + images_split = np.cumsum(images_size)[:-1] + images_lst = np.array_split(images, images_split, axis=0) + + assert len(image_type_ids_lst) == len( + mm_hashes + ), f"image_type_ids_lst length {len(image_type_ids_lst)} != mm_hashes length {len(mm_hashes)}" + assert len(images_lst) == len( + mm_hashes + ), f"images_lst length {len(images_lst)} != mm_hashes length {len(mm_hashes)}" + + uncached_image_type_ids = [] + uncached_images = [] + uncached_grid_thw = [] + uncached_mm_hashes = [] + for i, mm_hash in enumerate(mm_hashes): + if mm_hash in self.encoder_cache: + continue + uncached_image_type_ids.append(image_type_ids_lst[i]) + uncached_images.append(images_lst[i]) + uncached_grid_thw.append(grid_thw[i]) + uncached_mm_hashes.append(mm_hash) + + uncached_input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64) + uncached_token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64) + if len(uncached_mm_hashes) > 0: + uncached_image_type_ids = paddle.to_tensor(np.hstack(uncached_image_type_ids), dtype=paddle.int64) + uncached_images = paddle.to_tensor( + np.vstack(uncached_images), dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16" + ) + uncached_grid_thw = paddle.to_tensor(uncached_grid_thw, dtype=paddle.int64) + + return ( + uncached_input_ids, + uncached_token_type_ids, + uncached_image_type_ids, + uncached_images, + uncached_grid_thw, + uncached_mm_hashes, + ) + + def scatter_and_cache_features(self, image_features, inputs): + """ + Split batched image features and cache them + """ + merge_size = 2 + grid_thw = inputs["grid_thw"] + mm_hashes = inputs["mm_hashes"] + image_features_size = (paddle.prod(grid_thw[:, 1:], axis=1) // (merge_size**2)).tolist() + image_features_lst = paddle.split(image_features, image_features_size, axis=0) + + assert len(image_features_lst) == len( + mm_hashes + ), f"image_features_lst length {len(image_features_lst)} != mm_hashes length {len(mm_hashes)}" + for i, mm_hash in enumerate(mm_hashes): + self.encoder_cache[mm_hash] = image_features_lst[i].cpu() + + def _apply_mm_inputs(self, request: Request, multi_vision_inputs: dict, rope_3d_position_ids: dict): + """ + Apply multimodal inputs to share_inputs + - add image_features, extract and cache vision features from model + - add rope_emb, rotate position embeddings + """ + if self.encoder_cache: + evict_mm_hashes = request.get("evict_mm_hashes", None) + if evict_mm_hashes: + for mm_hash in evict_mm_hashes: + self.encoder_cache.pop(mm_hash, None) + + inputs = request.multimodal_inputs + if request.with_image: + if envs.FD_ENABLE_MAX_PREFILL: + multi_vision_inputs["images_lst"].append( + inputs["images"][request.image_start : request.image_end].cuda() + ) + multi_vision_inputs["grid_thw_lst"].extend( + inputs["grid_thw"][request.num_image_start : request.num_image_end] + ) + multi_vision_inputs["cu_seqlens"].extend( + inputs["vit_seqlen"][request.num_image_start : request.num_image_end] + ) + multi_vision_inputs["vit_position_ids_lst"].extend( + inputs["vit_position_ids"][request.num_image_start : 
request.num_image_end] + ) + else: + vision_inputs = inputs + if self.encoder_cache: + ( + vision_inputs["input_ids"], + vision_inputs["token_type_ids"], + vision_inputs["image_type_ids"], + vision_inputs["images"], + vision_inputs["grid_thw"], + vision_inputs["mm_hashes"], + ) = self.batch_uncached_inputs(request) + if len(vision_inputs["mm_hashes"]) > 0: + # uncached multimodal inputs exist + image_features = self.extract_vision_features(vision_inputs) + self.scatter_and_cache_features(image_features, vision_inputs) + + full_image_features_lst = [] + for mm_hash in inputs["mm_hashes"][request.num_image_start : request.num_image_end]: + feature = self.encoder_cache[mm_hash].cuda() + full_image_features_lst.append(feature) + image_features = paddle.concat(full_image_features_lst, axis=0) + else: + ( + input_ids, + token_type_ids, + image_type_ids, + images, + grid_thw, + mm_hashes, + ) = self.get_chunked_inputs(request) + vision_inputs["input_ids"] = paddle.to_tensor(input_ids, dtype=paddle.int64) + vision_inputs["token_type_ids"] = paddle.to_tensor(token_type_ids, dtype=paddle.int64) + vision_inputs["image_type_ids"] = paddle.to_tensor(image_type_ids, dtype=paddle.int64) + vision_inputs["images"] = paddle.to_tensor( + images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16" + ) + vision_inputs["grid_thw"] = paddle.to_tensor(grid_thw, dtype=paddle.int64) + vision_inputs["mm_hashes"] = mm_hashes + + image_features = self.extract_vision_features(vision_inputs) + + # part of the first image may be already cached + if "ernie" in self.model_config.model_type: + actual_image_token_num = paddle.sum(vision_inputs["input_ids"] == self.model_config.im_patch_id) + elif "qwen" in self.model_config.model_type: + actual_image_token_num = paddle.sum( + vision_inputs["input_ids"] == vision_inputs["image_patch_id"] + ) + paddle.sum(vision_inputs["input_ids"] == vision_inputs["video_patch_id"]) + else: + raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported") + self.share_inputs["image_features"] = image_features[-actual_image_token_num:] + else: + self.share_inputs["image_features"] = None + + position_ids = request.multimodal_inputs["position_ids"] + rope_3d_position_ids["position_ids_idx"].append(request.idx) + rope_3d_position_ids["position_ids_lst"].append(position_ids) + rope_3d_position_ids["position_ids_offset"].append( + position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1] + ) + rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048)) + def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None): """ Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1 @@ -326,51 +510,7 @@ class GPUModelRunner(ModelRunnerBase): prefill_end_index = request.prefill_end_index length = prefill_end_index - prefill_start_index if self.enable_mm: - inputs = request.multimodal_inputs - if request.with_image: - if envs.FD_ENABLE_MAX_PREFILL: - multi_vision_inputs["images_lst"].append( - inputs["images"][request.image_start : request.image_end].cuda() - ) - multi_vision_inputs["grid_thw_lst"].extend( - inputs["grid_thw"][request.num_image_start : request.num_image_end] - ) - multi_vision_inputs["cu_seqlens"].extend( - inputs["vit_seqlen"][request.num_image_start : request.num_image_end] - ) - multi_vision_inputs["vit_position_ids_lst"].extend( - inputs["vit_position_ids"][request.num_image_start : request.num_image_end] - ) - else: - vision_inputs = {} - 
vision_inputs["input_ids"] = paddle.to_tensor( - inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64 - ) - vision_inputs["token_type_ids"] = paddle.to_tensor( - inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64 - ) - vision_inputs["image_type_ids"] = paddle.to_tensor( - inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end], - dtype=paddle.int64, - ) - vision_inputs["images"] = paddle.to_tensor( - inputs["images"][request.image_start : request.image_end], - dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16", - ) - vision_inputs["grid_thw"] = paddle.to_tensor( - inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64" - ) - self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs) - else: - self.share_inputs["image_features"] = None - - position_ids = request.multimodal_inputs["position_ids"] - rope_3d_position_ids["position_ids_idx"].append(idx) - rope_3d_position_ids["position_ids_lst"].append(position_ids) - rope_3d_position_ids["position_ids_offset"].append( - position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1] - ) - rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048)) + self._apply_mm_inputs(request, multi_vision_inputs, rope_3d_position_ids) if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None: # Enable thinking diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 56716c925..00457966b 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -653,6 +653,12 @@ def parse_args(): help="Flag to specify dtype of lm_head as FP32", ) + parser.add_argument( + "--max_encoder_cache", + type=int, + help="Maximum encoder cache tokens(use 0 to disable).", + ) + parser.add_argument( "--cache-transfer-protocol", type=str, diff --git a/tests/input/test_qwen_vl_processor.py b/tests/input/test_qwen_vl_processor.py index 43e8c5542..a6db80cde 100644 --- a/tests/input/test_qwen_vl_processor.py +++ b/tests/input/test_qwen_vl_processor.py @@ -91,17 +91,18 @@ class TestQwenVLProcessor(unittest.TestCase): config.vision_config.tokens_per_second = 2 self.patcher_parse_image = patch( - "fastdeploy.entrypoints.chat_utils.MultiModalPartParser.parse_image", return_value=mock_pil_image(480, 640) + "fastdeploy.entrypoints.chat_utils.MultimodalPartParser.parse_image", return_value=mock_pil_image(480, 640) ) self.patcher_parse_image.start() self.patcher_parse_video = patch( - "fastdeploy.entrypoints.chat_utils.MultiModalPartParser.parse_video", return_value=b"123" + "fastdeploy.entrypoints.chat_utils.MultimodalPartParser.parse_video", return_value=b"123" ) self.patcher_parse_video.start() self.patcher_read_frames = patch( - "fastdeploy.input.qwen_vl_processor.process.read_frames", return_value=mock_read_frames(480, 640, 5, 2) + "fastdeploy.input.qwen_vl_processor.process.DataProcessor._load_and_process_video", + return_value=mock_read_frames(480, 640, 5, 2), ) self.patcher_read_frames.start() @@ -163,8 +164,6 @@ class TestQwenVLProcessor(unittest.TestCase): self.assertEqual( result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum() ) - self.assertEqual(result.multimodal_inputs["pic_cnt"], 1) - self.assertEqual(result.multimodal_inputs["video_cnt"], 1) def test_process_request_dict(self): """ @@ -204,8 +203,6 @@ class 
TestQwenVLProcessor(unittest.TestCase): self.assertEqual( result["multimodal_inputs"]["image_type_ids"].shape[0], result["multimodal_inputs"]["grid_thw"][:, 0].sum() ) - self.assertEqual(result["multimodal_inputs"]["pic_cnt"], 1) - self.assertEqual(result["multimodal_inputs"]["video_cnt"], 1) def test_prompt(self): """ @@ -240,8 +237,6 @@ class TestQwenVLProcessor(unittest.TestCase): self.assertEqual( result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum() ) - self.assertEqual(result.multimodal_inputs["pic_cnt"], 1) - self.assertEqual(result.multimodal_inputs["video_cnt"], 1) def test_message_and_prompt(self): """ diff --git a/tests/multimodal/test_hasher.py b/tests/multimodal/test_hasher.py new file mode 100644 index 000000000..a89ff2cf1 --- /dev/null +++ b/tests/multimodal/test_hasher.py @@ -0,0 +1,46 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import pickle +import unittest + +import numpy as np + +from fastdeploy.multimodal.hasher import MultimodalHasher + + +class TestHashFeatures(unittest.TestCase): + def test_hash_features_ndarray(self): + """Test hash features with numpy ndarray""" + arr = np.random.randint(low=0, high=255, size=(28, 28), dtype=np.uint8) + arr_hash = MultimodalHasher.hash_features(arr) + target_hash = hashlib.sha256((arr.tobytes())).hexdigest() + assert arr_hash == target_hash, f"Ndarray hash mismatch: {arr_hash} != {target_hash}" + + def test_hash_features_object(self): + """Test hash features with unsupported object type""" + obj = {"key": "value"} + obj_hash = MultimodalHasher.hash_features(obj) + target_hash = hashlib.sha256((pickle.dumps(obj))).hexdigest() + assert obj_hash == target_hash, f"Dict hash mismatch: {obj_hash} != {target_hash}" + + obj = "test hasher str" + obj_hash = MultimodalHasher.hash_features(obj) + target_hash = hashlib.sha256((pickle.dumps(obj))).hexdigest() + assert obj_hash == target_hash, f"Str hash mismatch: {obj_hash} != {target_hash}" + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/v1/cache_manager/test_encoder_cache.py b/tests/v1/cache_manager/test_encoder_cache.py new file mode 100644 index 000000000..8099cce49 --- /dev/null +++ b/tests/v1/cache_manager/test_encoder_cache.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
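The EncoderCacheManager exercised by the test below lives in fastdeploy/cache_manager/multimodal_cache_manager.py and is not reproduced in this excerpt. As a reading aid, here is a minimal, hypothetical LRU sketch that satisfies the assertions that follow, assuming token-count accounting against max_encoder_cache and least-recently-used eviction; the class name LRUEncoderCache and the choice to store only token counts (rather than the cached feature tensors) are illustrative simplifications, not the actual implementation.

import collections


class LRUEncoderCache:
    """Toy stand-in for EncoderCacheManager: LRU eviction over an encoder-token budget."""

    def __init__(self, max_encoder_cache: int):
        self.max_encoder_cache = max_encoder_cache  # capacity in encoder tokens
        self.cache = collections.OrderedDict()  # mm_hash -> number of cached tokens
        self.current_cache_size = 0

    def evict_cache(self, needed: int) -> list:
        """Evict least-recently-used entries until `needed` extra tokens fit."""
        evicted = []
        while self.current_cache_size + needed > self.max_encoder_cache and self.cache:
            mm_hash, length = self.cache.popitem(last=False)
            self.current_cache_size -= length
            evicted.append(mm_hash)
        return evicted

    def apply_cache(self, mm_hashes: list, mm_items: list) -> list:
        """Insert new entries keyed by hash, returning whatever had to be evicted."""
        evicted = []
        for mm_hash, item in zip(mm_hashes, mm_items):
            if mm_hash in self.cache:
                self.cache.move_to_end(mm_hash)  # refresh recency, size unchanged
                continue
            evicted.extend(self.evict_cache(needed=item.length))
            self.cache[mm_hash] = item.length
            self.current_cache_size += item.length
        return evicted

    def clear_cache(self) -> None:
        self.cache.clear()
        self.current_cache_size = 0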
+ +from fastdeploy.cache_manager.multimodal_cache_manager import EncoderCacheManager +from fastdeploy.engine.request import ImagePosition + + +def test_mm_encoder_cache(): + max_encoder_cache = 4096 + encoder_cache = EncoderCacheManager(max_encoder_cache=max_encoder_cache) + + mm_hashes = ["mm_hash1", "mm_hash2"] + mm_positions = [ImagePosition(offset=120, length=400), ImagePosition(offset=620, length=800)] + + cache_length = mm_positions[0].length + mm_positions[1].length + evict_hashes = encoder_cache.apply_cache(mm_hashes=mm_hashes, mm_items=mm_positions) + assert evict_hashes == [], "The evicted hashes should be empty." + assert list(encoder_cache.cache.keys()) == [ + "mm_hash1", + "mm_hash2", + ], "The cache should contain mm_hash1 and mm_hash2." + assert ( + encoder_cache.current_cache_size == cache_length + ), "The cache size should be the sum of the lengths of mm_hash1 and mm_hash2." + assert ( + encoder_cache.current_cache_size <= max_encoder_cache + ), "The cache size should be less than or equal to the max_encoder_cache." + + mm_hashes = ["mm_hash3", "mm_hash4"] + mm_positions = [ImagePosition(offset=20, length=1204), ImagePosition(offset=1800, length=2048)] + cache_length += mm_positions[0].length + mm_positions[1].length - 400 + evict_hashes = encoder_cache.apply_cache(mm_hashes=mm_hashes, mm_items=mm_positions) + assert evict_hashes == ["mm_hash1"], "The evicted hashes should be mm_hash1." + assert list(encoder_cache.cache.keys()) == [ + "mm_hash2", + "mm_hash3", + "mm_hash4", + ], "The cache should contain mm_hash2, mm_hash3, and mm_hash4." + assert ( + encoder_cache.current_cache_size == cache_length + ), "The cache size should be the sum of the lengths of mm_hash2, mm_hash3, and mm_hash4." + assert ( + encoder_cache.current_cache_size <= max_encoder_cache + ), "The cache size should be less than or equal to the max_encoder_cache." + + evict_hashes = encoder_cache.apply_cache(mm_hashes=["mm_hash2"], mm_items=[ImagePosition(offset=620, length=800)]) + assert evict_hashes == [], "The evicted hashes should be empty." + assert ( + encoder_cache.current_cache_size == cache_length + ), "The cache size should be the sum of the lengths of mm_hash2, mm_hash3, and mm_hash4." + assert ( + encoder_cache.current_cache_size <= max_encoder_cache + ), "The cache size should be less than or equal to the max_encoder_cache." + + cache_length -= 1204 + evict_hashes = encoder_cache.evict_cache(needed=800) + assert evict_hashes == ["mm_hash3"], "The evicted hashes should be mm_hash3." + assert list(encoder_cache.cache.keys()) == [ + "mm_hash4", + "mm_hash2", + ], "The cache should contain mm_hash2 and mm_hash4." + assert ( + encoder_cache.current_cache_size == cache_length + ), "The cache size should be the sum of the lengths of mm_hash2 and mm_hash4." + assert ( + encoder_cache.current_cache_size <= max_encoder_cache + ), "The cache size should be less than or equal to the max_encoder_cache." + + encoder_cache.clear_cache() + assert encoder_cache.current_cache_size == 0, "The cache size should be 0." + assert list(encoder_cache.cache.keys()) == [], "The cache should be empty." diff --git a/tests/v1/cache_manager/test_prefix_cache.py b/tests/v1/cache_manager/test_prefix_cache.py new file mode 100644 index 000000000..f429f709b --- /dev/null +++ b/tests/v1/cache_manager/test_prefix_cache.py @@ -0,0 +1,259 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import asdict +from types import SimpleNamespace + +from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager +from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig +from fastdeploy.engine.args_utils import EngineArgs +from fastdeploy.engine.request import ImagePosition, Request +from fastdeploy.scheduler import SchedulerConfig + + +def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_override=100, max_num_batched_tokens=3200): + engine_args = EngineArgs( + max_num_seqs=max_num_seqs, + num_gpu_blocks_override=num_gpu_blocks_override, + max_num_batched_tokens=max_num_batched_tokens, + ) + args = asdict(engine_args) + cache_cfg = CacheConfig(args) + model_cfg = SimpleNamespace(enable_mm=enable_mm, max_model_len=8192) + speculative_cfg = SimpleNamespace(method=None) + model_cfg.print = print + cache_cfg.bytes_per_layer_per_block = 1 + parallel_cfg = ParallelConfig(args) + scheduler_cfg = SchedulerConfig(args) + graph_opt_cfg = engine_args.create_graph_optimization_config() + fd_config = FDConfig( + model_config=model_cfg, + cache_config=cache_cfg, + parallel_config=parallel_cfg, + graph_opt_config=graph_opt_cfg, + speculative_config=speculative_cfg, + scheduler_config=scheduler_cfg, + ) + return PrefixCacheManager(config=fd_config, tensor_parallel_size=8, splitwise_role="mixed") + + +def test_normal_case(): + block_size = 64 + cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=False, num_gpu_blocks_override=100) + req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200}) + req2 = Request.from_dict( + {"request_id": "req2", "prompt_token_ids": [1] * 1600 + [2] * 1600, "prompt_token_ids_len": 3200} + ) + req3 = Request.from_dict( + {"request_id": "req3", "prompt_token_ids": [1] * 1600 + [3] * 1600, "prompt_token_ids_len": 3200} + ) + (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req1, block_size) + assert len(common_block_ids) == 0 + assert matched_token_num == 0 + assert len(cache_manager.gpu_free_block_list) == 100 + req1.block_tables.extend(common_block_ids) + # allocate for req1 inputs + num_new_block = 50 + req1.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block)) + req1.num_computed_tokens += 50 * block_size + cache_manager.update_cache_blocks(req1, block_size, req1.num_computed_tokens) + assert len(cache_manager.gpu_free_block_list) == 50 + # allocate for req2 inputs + (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req2, block_size) + assert len(common_block_ids) == 25 + assert matched_token_num == 25 * block_size + req2.num_cached_tokens = matched_token_num + req2.num_computed_tokens = 25 * block_size + num_new_block = 25 + req2.block_tables.extend(common_block_ids) + req2.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block)) + cache_manager.update_cache_blocks(req2, block_size, 
req2.num_computed_tokens)
+    # allocate for req3 input
+    (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req3, block_size)
+    assert len(common_block_ids) == 25
+    assert matched_token_num == 25 * block_size
+    req3.num_cached_tokens = matched_token_num
+    req3.num_computed_tokens = 25 * block_size
+    assert len(cache_manager.gpu_free_block_list) == 25
+    req3.block_tables.extend(common_block_ids)
+    num_new_block = 25
+    assert cache_manager.can_allocate_gpu_blocks(num_new_block)
+    req3.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
+    cache_manager.update_cache_blocks(req3, block_size, req3.num_computed_tokens)
+    assert len(cache_manager.gpu_free_block_list) == 0
+
+
+def test_mm_extra_keys():
+    block_size = 64
+    cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=True)
+
+    prompt_token_ids = [1] * 100 + [2] * 100
+    req1 = {
+        "request_id": "req1",
+        "prompt_token_ids": prompt_token_ids,
+        "prompt_token_ids_len": len(prompt_token_ids),
+    }
+    for idx in range(0, len(prompt_token_ids), block_size):
+        token_ids_lens = min(block_size, len(prompt_token_ids[idx:]))
+        mm_idx, extra_keys = cache_manager.get_block_hash_extra_keys(
+            request=Request.from_dict(req1),
+            start_idx=idx,
+            end_idx=idx + token_ids_lens,
+            mm_idx=0,
+        )
+        assert extra_keys == [], f"extra_keys {extra_keys} != [], start_idx {idx}, end_idx {idx + token_ids_lens}"
+        assert mm_idx == 0, f"mm_idx {mm_idx} != 0, start_idx {idx}, end_idx {idx + token_ids_lens}"
+
+    # block 1
+    prompt_token_ids = [1] * 30 + [-1] * 34
+    mm_positions = [ImagePosition(offset=30, length=80)]
+    mm_hashes = ["image1"]
+    extra_keys_list = [(0, ["image1"])]
+
+    # block 2
+    prompt_token_ids += [-1] * 46 + [2] * 18
+    extra_keys_list.append((1, ["image1"]))
+
+    # block 3
+    prompt_token_ids += [-1] * 100
+    mm_positions.append(ImagePosition(offset=128, length=100))
+    mm_hashes.append("image2")
+    extra_keys_list.append((1, ["image2"]))
+
+    # blocks 4 and 5
+    prompt_token_ids += [3] * 40
+    extra_keys_list.append((1, ["image2"]))
+    extra_keys_list.append((1, []))
+
+    req2 = {
+        "request_id": "req2",
+        "prompt_token_ids": prompt_token_ids,
+        "prompt_token_ids_len": len(prompt_token_ids),
+        "multimodal_inputs": {
+            "mm_positions": mm_positions,
+            "mm_hashes": mm_hashes,
+        },
+    }
+
+    mm_idx, key_idx = 0, 0
+    for idx in range(0, len(prompt_token_ids), block_size):
+        token_ids_lens = min(block_size, len(prompt_token_ids[idx:]))
+        mm_idx, extra_keys = cache_manager.get_block_hash_extra_keys(
+            request=Request.from_dict(req2),
+            start_idx=idx,
+            end_idx=idx + token_ids_lens,
+            mm_idx=mm_idx,
+        )
+
+        target_idx, target_keys = extra_keys_list[key_idx]
+        assert (
+            mm_idx == target_idx
+        ), f"mm_idx {mm_idx} != target_idx {target_idx}, start_idx {idx}, end_idx {idx + token_ids_lens}"
+        assert (
+            extra_keys == target_keys
+        ), f"extra_keys {extra_keys} != target_keys {target_keys}, start_idx {idx}, end_idx {idx + token_ids_lens}"
+        key_idx += 1
+
+
+def test_mm_prefix_cache():
+    block_size = 64
+    cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=True, num_gpu_blocks_override=100)
+    multimodal_inputs = {
+        "mm_positions": [ImagePosition(offset=120, length=1200)],
+        "mm_hashes": ["image1"],
+    }
+    req1_dict = {
+        "request_id": "req1",
+        "prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120,
+        "prompt_token_ids_len": 1440,
+        "multimodal_inputs": multimodal_inputs,
+    }
+    req1 = Request.from_dict(req1_dict)
+
+    multimodal_inputs = dict(multimodal_inputs)
+    
multimodal_inputs["mm_positions"].append(ImagePosition(offset=1836, length=587)) + multimodal_inputs["mm_hashes"].append("image2") + req2_dict = { + "request_id": "req2", + "prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120 + [3] * 396 + [-1] * 587, + "prompt_token_ids_len": 2423, + "multimodal_inputs": multimodal_inputs, + } + req2 = Request.from_dict(req2_dict) + + multimodal_inputs = dict(multimodal_inputs) + multimodal_inputs["mm_hashes"] = ["image3", "image4"] + req3_dict = { + "request_id": "req3", + "prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120 + [3] * 396 + [-1] * 587, + "prompt_token_ids_len": 2423, + "multimodal_inputs": multimodal_inputs, + } + req3 = Request.from_dict(req3_dict) + + multimodal_inputs = dict(multimodal_inputs) + multimodal_inputs["mm_positions"] = [ImagePosition(offset=120, length=1200)] + multimodal_inputs["mm_hashes"] = ["image3"] + req4_dict = { + "request_id": "req4", + "prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120 + [3] * 352, + "prompt_token_ids_len": 1792, + "multimodal_inputs": multimodal_inputs, + } + req4 = Request.from_dict(req4_dict) + + (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req1, block_size) + assert len(common_block_ids) == 0 + assert matched_token_num == 0 + assert len(cache_manager.gpu_free_block_list) == 100 + req1.block_tables.extend(common_block_ids) + + # allocate for req1 inputs + num_new_block = 22 + req1.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block)) + req1.num_computed_tokens += 22 * block_size + cache_manager.update_cache_blocks(req1, block_size, req1.num_computed_tokens) + assert len(cache_manager.gpu_free_block_list) == 78 + + # allocate for req2 inputs + (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req2, block_size) + assert len(common_block_ids) == 22 + assert matched_token_num == 22 * block_size + req2.num_cached_tokens = matched_token_num + req2.num_computed_tokens = matched_token_num + num_new_block = 15 + req2.block_tables.extend(common_block_ids) + req2.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block)) + req2.num_computed_tokens += 15 * block_size + cache_manager.update_cache_blocks(req2, block_size, req2.num_computed_tokens) + + # allocate for req3 input + (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req3, block_size) + assert len(common_block_ids) == 1 + assert matched_token_num == 1 * block_size + req3.num_cached_tokens = matched_token_num + req3.num_computed_tokens = matched_token_num + assert len(cache_manager.gpu_free_block_list) == 63 + req3.block_tables.extend(common_block_ids) + num_new_block = 36 + assert cache_manager.can_allocate_gpu_blocks(num_new_block) + req3.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block)) + req3.num_computed_tokens += 36 * block_size + cache_manager.update_cache_blocks(req3, block_size, req3.num_computed_tokens) + assert len(cache_manager.gpu_free_block_list) == 27 + + # allocate for req4 input + (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req4, block_size) + assert len(common_block_ids) == 28 + assert matched_token_num == 28 * block_size diff --git a/tests/v1/test_prefix_cache.py b/tests/v1/test_prefix_cache.py deleted file mode 100644 index a6c7c2bf9..000000000 --- a/tests/v1/test_prefix_cache.py +++ /dev/null @@ -1,73 +0,0 @@ -from dataclasses import asdict -from types import SimpleNamespace - -from 
fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager -from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig, SchedulerConfig -from fastdeploy.engine.args_utils import EngineArgs -from fastdeploy.engine.request import Request - - -def test_normal_case(): - max_num_seqs = 3 - block_size = 64 - engine_args = EngineArgs(max_num_seqs=max_num_seqs, num_gpu_blocks_override=100, max_num_batched_tokens=3200) - args = asdict(engine_args) - cache_cfg = CacheConfig(args) - model_cfg = SimpleNamespace(enable_mm=False) - speculative_cfg = SimpleNamespace(method=None) - model_cfg.print = print - model_cfg.max_model_len = 5120 - cache_cfg.bytes_per_layer_per_block = 1 - parallel_cfg = ParallelConfig(args) - scheduler_cfg = SchedulerConfig(args) - graph_opt_cfg = engine_args.create_graph_optimization_config() - fd_config = FDConfig( - model_config=model_cfg, - cache_config=cache_cfg, - parallel_config=parallel_cfg, - graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, - scheduler_config=scheduler_cfg, - ) - cache_manager = PrefixCacheManager(config=fd_config, tensor_parallel_size=8, splitwise_role="mixed") - req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200}) - req2 = Request.from_dict( - {"request_id": "req2", "prompt_token_ids": [1] * 1600 + [2] * 1600, "prompt_token_ids_len": 3200} - ) - req3 = Request.from_dict( - {"request_id": "req3", "prompt_token_ids": [1] * 1600 + [3] * 1600, "prompt_token_ids_len": 3200} - ) - (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req1, block_size) - assert len(common_block_ids) == 0 - assert matched_token_num == 0 - assert len(cache_manager.gpu_free_block_list) == 100 - req1.block_tables.extend(common_block_ids) - # allocate for req1 inputs - num_new_block = 50 - req1.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block)) - req1.num_computed_tokens += 50 * block_size - cache_manager.update_cache_blocks(req1, block_size, req1.num_computed_tokens) - assert len(cache_manager.gpu_free_block_list) == 50 - # allocate for req2 inputs - (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req2, block_size) - assert len(common_block_ids) == 25 - assert matched_token_num == 25 * block_size - req2.num_cached_tokens = matched_token_num - req2.num_computed_tokens == 25 * block_size - num_new_block = 25 - req2.block_tables.extend(common_block_ids) - req2.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block)) - cache_manager.update_cache_blocks(req2, block_size, req2.num_computed_tokens) - # allocate for req3 input - (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req3, block_size) - assert len(common_block_ids) == 25 - assert matched_token_num == 25 * block_size - req3.num_cached_tokens = matched_token_num - req3.num_computed_tokens == 25 * block_size - assert len(cache_manager.gpu_free_block_list) == 25 - req3.block_tables.extend(common_block_ids) - num_new_block = 25 - assert cache_manager.can_allocate_gpu_blocks(num_new_block) - req3.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block)) - cache_manager.update_cache_blocks(req3, block_size, req3.num_computed_tokens) - assert len(cache_manager.gpu_free_block_list) == 0