[Feature] Support multimodal (mm) prefix caching (#4134)

* support mm prefix caching

* update code

* fix mm_hashes

* support encoder cache

* add encoder cache

* update code

* update encoder cache

* fix features bug

* fix worker bug

* support processor cache (still needs optimization)

* refactor multimodal data cache

* update code

* update code

* update v1 scheduler

* update code

* update code

* update codestyle

* support turning off the processor cache and encoder cache

* update pre-commit

* fix code

* address review comments

* update code

* update code

* update test case

* set processor cache in GiB

* update test case

* support mm prefix caching for qwen model

* fix code style check

* update pre-commit

* fix unit test

* fix unit test

* add ci test case

* fix bug with rescheduled requests

* change text_after_process to prompt_tokens

* fix unit test

* fix chat template

* change model path

* [EP] fix adapter bugs (#4572)

* Update expert_service.py

* Update common_engine.py

* Update expert_service.py

* fix v1 hang bug (#4573)

* fix import image_ops error on some platforms (#4559)

* [CLI] Update parameters in the bench latency CLI tool and fix the collect-env CLI tool (#4558)

* add collect-env

* del files

* [Graph Optimization] Add dy_runnable and introduce cudagraph_switch_threshold for cudagraph mode switching (#4578)

* add new branch for sot

* reorder

* fix batch bug

* [XPU] MoE uses a new operator (#4585)

* [XPU] MoE uses a new operator

* [XPU] MoE uses a new operator

* update response

* [Feature] Support Paddle-OCR (#4396)

* init

* update code

* fix code style & disable thinking

* adapt for common_engine.update_mm_requests_chunk_size

* use 3d rope

* use flash_attn_unpadded

* opt siglip

* update to be compatible with the latest codebase

* fix typo

* optimize OCR performance

* fix bug

* fix bug

* fix bug

* fix bug

* normalize name

* modify xpu rope

* revert logger

* fix bug

* fix bug

* fix bug

* support default_v1

* optimize performance

* fix bug

---------

Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com>
Co-authored-by: zhangyue66 <zhangyue66@baidu.com>

* [DataProcessor] add reasoning_tokens into usage info (#4520)

* add reasoning_tokens into usage info (initial commit)

* add unit tests

* modify unit test

* modify and add unit tests

* fix unit test

* move stream usage to processor

* modify processor

* modify test_logprobs

* modify test_logprobs.py

* modify stream reasoning tokens accumulation

* fix unit test

* perf: Optimize task queue communication from engine to worker (#4531)

* perf: Optimize task queue communication from engine to worker

* perf: get_tasks to numpy

* perf: get_tasks remove to_numpy

* fix: request & replace ENV

* remove test_e2w_perf.py

* fix code style

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>

* Clean up ports after processing results (#4587)

* [CI] Add /re-run command in PR comments to restart failed CI workflows (#4593)

* [Others] API server exits when worker process is dead (#3271)

* [fix] fix terminal hangs when worker process is dead

* [chore] change sleep time of monitor

* [chore] remove redundant comments

* update docs

---------

Co-authored-by: ApplEOFDiscord <wwy640130@163.com>
Co-authored-by: ApplEOFDiscord <31272106+ApplEOFDiscord@users.noreply.github.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Co-authored-by: yinwei <yinwei_hust@163.com>
Co-authored-by: JYChen <zoooo0820@qq.com>
Co-authored-by: qwes5s5 <45442318+qwes5s5@users.noreply.github.com>
Co-authored-by: Ryan <zihaohuang@aliyun.com>
Co-authored-by: yyssys <atyangshuang@foxmail.com>
Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com>
Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com>
Co-authored-by: zhangyue66 <zhangyue66@baidu.com>
Co-authored-by: kxz2002 <115912648+kxz2002@users.noreply.github.com>
Co-authored-by: SunLei <sunlei5788@gmail.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Zhang Yulong <35552275+ZhangYulongg@users.noreply.github.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: 李泳桦 <39643373+liyonghua0910@users.noreply.github.com>
Authored by kevin on 2025-10-27 17:39:51 +08:00, committed by GitHub.
Parent: a4fb3d4ff0
Commit: 8aab4e367f
40 changed files with 1741 additions and 545 deletions


@@ -14,8 +14,10 @@
# limitations under the License.
"""
import hashlib
import heapq
import os
import pickle
import subprocess
import sys
import threading
@@ -35,7 +37,7 @@ from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTre
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
logger = get_logger("prefix_cache_manager", "cache_manager.log")
class PrefixCacheManager:
@@ -575,31 +577,18 @@ class PrefixCacheManager:
"""
try:
req_id = task.request_id
block_tables = task.block_tables
last_node, num_cached_tokens = self.cache_info[req_id]
if isinstance(task.prompt_token_ids, np.ndarray):
prompt_token_ids = task.prompt_token_ids.tolist()
else:
prompt_token_ids = task.prompt_token_ids
input_ids = prompt_token_ids + task.output_token_ids
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
left_input_ids = input_ids[num_cached_tokens:can_cache_computed_tokens]
gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :]
if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later
self.leaf_req_map[last_node].remove(req_id)
with self.request_release_lock:
current_time = time.time()
leaf_node = self.build_path(
req_id=req_id,
current_time=current_time,
input_ids=input_ids,
left_input_ids=left_input_ids,
gpu_block_ids=gpu_extra_block_ids,
leaf_node = self.mm_build_path(
request=task,
num_computed_tokens=num_computed_tokens,
block_size=block_size,
last_node=last_node,
reverved_dec_block_num=0,
num_cached_tokens=num_cached_tokens,
)
self.req_leaf_map[req_id] = leaf_node
self.leaf_req_map[leaf_node].add(req_id)
@@ -636,10 +625,9 @@ class PrefixCacheManager:
prompt_token_ids = task.prompt_token_ids.tolist()
else:
prompt_token_ids = task.prompt_token_ids
input_ids = prompt_token_ids + task.output_token_ids
req_id = task.request_id
logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
input_token_num = len(input_ids)
input_token_num = len(prompt_token_ids + task.output_token_ids)
common_block_ids = []
# 1. match block
(
@@ -649,7 +637,7 @@ class PrefixCacheManager:
match_block_node,
gpu_match_token_num,
cpu_match_token_num,
) = self.match_block(req_id, input_ids, block_size)
) = self.mm_match_block(task, block_size)
# update matched node info
self._update_matched_node_info(req_id, match_block_node, current_time=time.time())
@@ -1145,6 +1133,173 @@ class PrefixCacheManager:
"""
return hash(tuple(block))
def get_block_hash_extra_keys(self, request, start_idx, end_idx, mm_idx):
"""
Retrieves additional hash keys for block identification.
Args:
request: The input request object containing the data to be processed.
start_idx (int): The starting index of the block segment to hash.
end_idx (int): The ending index of the block segment to hash.
mm_idx (int): Index of the first multimodal item to consider when scanning mm_positions.
Returns:
mm_idx (int): Index of the next multimodal item to process.
hash_keys (list): mm_hashes of the multimodal items overlapping this block.
"""
hash_keys = []
mm_inputs = request.multimodal_inputs
if (
mm_inputs is None
or "mm_positions" not in mm_inputs
or "mm_hashes" not in mm_inputs
or len(mm_inputs["mm_positions"]) == 0
):
return mm_idx, hash_keys
assert start_idx < end_idx, f"start_idx {start_idx} >= end_idx {end_idx}"
assert (
start_idx >= 0 and start_idx < request.num_total_tokens
), f"start_idx {start_idx} out of range {request.num_total_tokens}"
assert (
end_idx >= 0 and end_idx <= request.num_total_tokens
), f"end_idx {end_idx} out of range {request.num_total_tokens}"
assert len(mm_inputs["mm_positions"]) == len(
mm_inputs["mm_hashes"]
), f"mm_positions {len(mm_inputs['mm_positions'])} != mm_hashes {len(mm_inputs['mm_hashes'])}"
assert mm_idx >= 0 and mm_idx < len(
mm_inputs["mm_hashes"]
), f"mm_idx {mm_idx} out of range {len(mm_inputs['mm_hashes'])}"
if mm_inputs["mm_positions"][-1].offset + mm_inputs["mm_positions"][-1].length < start_idx:
# no images in the current block
return mm_idx, hash_keys
for img_idx in range(mm_idx, len(mm_inputs["mm_positions"])):
image_offset = mm_inputs["mm_positions"][img_idx].offset
image_length = mm_inputs["mm_positions"][img_idx].length
if image_offset + image_length < start_idx:
# image before block
continue
elif image_offset >= end_idx:
# image after block
return img_idx, hash_keys
elif image_offset + image_length > end_idx:
hash_keys.append(mm_inputs["mm_hashes"][img_idx])
return img_idx, hash_keys
else:
hash_keys.append(mm_inputs["mm_hashes"][img_idx])
return len(mm_inputs["mm_positions"]) - 1, hash_keys
def hash_block_features(self, input_ids, extra_keys: list = []):
"""
Calculate the hash value of a block, optionally mixed with additional keys.
Args:
input_ids: Input token IDs
extra_keys: Additional keys for block identification
"""
return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()
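To illustrate the scheme above with a minimal, self-contained sketch (not taken from the diff; the image hash strings are invented): two blocks with identical token IDs only share a prefix-cache entry when they also carry the same multimodal extra keys.
import hashlib
import pickle
def sketch_hash_block(input_ids, extra_keys=()):
    # Same idea as hash_block_features: hash the (tokens, extra_keys) pair.
    return hashlib.sha256(pickle.dumps((list(input_ids), list(extra_keys)))).hexdigest()
block_tokens = [0] * 64  # a block of identical placeholder token IDs
h_a = sketch_hash_block(block_tokens, ["image_hash_a"])
h_b = sketch_hash_block(block_tokens, ["image_hash_b"])
h_a2 = sketch_hash_block(block_tokens, ["image_hash_a"])
assert h_a != h_b   # same tokens, different image -> distinct cache entries
assert h_a == h_a2  # same tokens, same image -> block can be reused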
def mm_match_block(self, request, block_size):
"""
Match and retrieve cached blocks for multimodal requests using a radix tree structure.
Args:
request: The multimodal request object containing prompt and output token IDs.
block_size (int): The size of each token block for matching and processing.
Returns:
tuple: A tuple containing:
- match_gpu_block_ids (list): List of block IDs matched in GPU cache
- match_cpu_block_ids (list): List of block IDs matched in CPU cache
- swap_node_ids (list): List of node IDs scheduled for GPU-CPU swapping
- current_match_node: The last matched node in the radix tree traversal
- gpu_match_token_num (int): Total number of tokens matched in GPU cache
- cpu_match_token_num (int): Total number of tokens matched in CPU cache
"""
if isinstance(request.prompt_token_ids, np.ndarray):
prompt_token_ids = request.prompt_token_ids.tolist()
else:
prompt_token_ids = request.prompt_token_ids
input_ids = prompt_token_ids + request.output_token_ids
total_token_num = len(input_ids)
current_match_node = self.radix_tree_root  # start matching from the radix tree root
match_gpu_block_ids = []
match_cpu_block_ids = []
match_node_ids = []
mm_idx = 0
match_token_num = 0
cpu_match_token_num = 0
gpu_match_token_num = 0
swap_node_ids = []
matche_nodes = []
has_modified_gpu_lru_leaf_heap = False
has_modified_cpu_lru_leaf_heap = False
with self.cache_status_lock:
while match_token_num < total_token_num:
token_block = input_ids[match_token_num : match_token_num + block_size]
token_num = len(token_block)
if token_num != block_size:
break
mm_idx, extra_keys = self.get_block_hash_extra_keys(
request=request,
start_idx=match_token_num,
end_idx=match_token_num + block_size,
mm_idx=mm_idx,
)
hash_value = self.hash_block_features(token_block, extra_keys)
if hash_value in current_match_node.children:
child = current_match_node.children[hash_value]
matche_nodes.append(child)
match_node_ids.append(child.node_id)
if child in self.gpu_lru_leaf_set:
self.gpu_lru_leaf_set.remove(child)
self.gpu_lru_leaf_heap.remove(child)
has_modified_gpu_lru_leaf_heap = True
elif child in self.cpu_lru_leaf_set:
self.cpu_lru_leaf_set.remove(child)
self.cpu_lru_leaf_heap.remove(child)
has_modified_cpu_lru_leaf_heap = True
if child.has_in_gpu:
match_gpu_block_ids.append(child.block_id)
gpu_match_token_num += block_size
else:
if child.cache_status == CacheStatus.SWAP2CPU:
logger.info(
f"match_block: req_id {request.request_id} matched node"
+ f" {child.node_id} which is being SWAP2CPU"
)
child.cache_status = CacheStatus.GPU
match_gpu_block_ids.append(child.block_id)
gpu_match_token_num += block_size
elif child.cache_status == CacheStatus.CPU:
child.cache_status = CacheStatus.SWAP2GPU
match_cpu_block_ids.append(child.block_id)
cpu_match_token_num += block_size
swap_node_ids.append(child.node_id)
match_token_num = match_token_num + block_size
current_match_node = child
else:
break
if has_modified_gpu_lru_leaf_heap:
heapq.heapify(self.gpu_lru_leaf_heap)
if has_modified_cpu_lru_leaf_heap:
heapq.heapify(self.cpu_lru_leaf_heap)
logger.info(f"match_block: req_id {request.request_id} matched nodes: {match_node_ids}")
return (
match_gpu_block_ids,
match_cpu_block_ids,
swap_node_ids,
current_match_node,
gpu_match_token_num,
cpu_match_token_num,
)
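One detail of mm_match_block worth spelling out: matched children are pulled out of the GPU/CPU LRU leaf structures so they cannot be evicted while a request is reusing them, and because list.remove on the heap-backed list breaks the heap invariant, the heap is rebuilt once after the matching loop. A minimal sketch of that pattern (placeholder values, not from the diff):
import heapq
lru_leaf_heap = [5, 9, 7, 12, 3]  # stand-in priorities for evictable leaf nodes
heapq.heapify(lru_leaf_heap)
lru_leaf_heap.remove(7)       # a matched leaf must no longer be an eviction candidate
heapq.heapify(lru_leaf_heap)  # one re-heapify restores the invariant
assert lru_leaf_heap[0] == min(lru_leaf_heap)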
def match_block(self, req_id, input_ids, block_size):
"""
Args:
@@ -1241,6 +1396,86 @@ class PrefixCacheManager:
node.req_id_set.add(req_id)
node = node.parent
def mm_build_path(self, request, num_computed_tokens, block_size, last_node, num_cached_tokens):
"""
Constructs a caching path in radix tree for multimodal requests by processing computed tokens.
Args:
request: The inference request object containing:
- prompt_token_ids: Original input tokens (List[int] or np.ndarray)
- output_token_ids: Generated tokens (List[int])
- mm_positions: Optional image positions for multimodal content
num_computed_tokens: Total tokens processed so far (cached + newly computed)
block_size: Fixed size of token blocks (must match cache configuration)
last_node: The deepest existing BlockNode in the radix tree for this request
num_cached_tokens: Number of tokens already cached
Returns:
BlockNode: The new deepest node in the constructed path
"""
if isinstance(request.prompt_token_ids, np.ndarray):
prompt_token_ids = request.prompt_token_ids.tolist()
else:
prompt_token_ids = request.prompt_token_ids
input_ids = prompt_token_ids + request.output_token_ids
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
if num_cached_tokens == can_cache_computed_tokens:
return last_node
mm_idx = 0
node = last_node
unique_node_ids = []
new_last_node = last_node
has_unfilled_block = False
current_time = time.time()
input_hash_value = self.hash_block_features(input_ids)
gpu_block_ids = request.block_tables[num_cached_tokens // block_size :].copy()
for i in range(num_cached_tokens, can_cache_computed_tokens, block_size):
current_block = input_ids[i : i + block_size]
current_block_size = len(current_block)  # the last block may not be completely filled
if current_block_size != block_size:
has_unfilled_block = True
else:
mm_idx, extra_keys = self.get_block_hash_extra_keys(
request=request,
start_idx=i,
end_idx=i + block_size,
mm_idx=mm_idx,
)
hash_value = self.hash_block_features(current_block, extra_keys)
allocated_block_id = gpu_block_ids.pop(0)
node_id = self.node_id_pool.pop()
unique_node_ids.append(node_id)
new_last_node = BlockNode(
node_id,
input_ids,
input_hash_value,
node.depth + 1,
allocated_block_id,
current_block_size,
hash_value,
current_time,
parent=node,
shared_count=1,
reverved_dec_block_ids=[],
)
new_last_node.req_id_set.add(request.request_id)
self.node_map[node_id] = new_last_node
node.children[hash_value] = new_last_node
node = new_last_node
reverved_dec_block_ids = []
if has_unfilled_block is True:
reverved_dec_block_ids.append(gpu_block_ids.pop(0))
if new_last_node == self.radix_tree_root:
self.unfilled_req_block_map[request.request_id] = reverved_dec_block_ids
else:
new_last_node.reverved_dec_block_ids.extend(reverved_dec_block_ids)
logger.info(f"build_path: allocate unique node ids {unique_node_ids} for req_id {request.request_id}")
return new_last_node
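A tiny worked example of the alignment logic in mm_build_path (illustrative numbers, not from the diff): only tokens up to the last full block boundary become radix-tree nodes; the trailing partially filled block is left out of the tree.
block_size = 64
num_cached_tokens = 128     # tokens already covered by existing tree nodes
num_computed_tokens = 330   # tokens computed so far for this request
# Round down to a whole number of blocks.
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
assert can_cache_computed_tokens == 320
# Blocks [128, 192), [192, 256) and [256, 320) each get a new BlockNode;
# the partial block holding tokens [320, 330) is not inserted into the tree.
new_full_blocks = (can_cache_computed_tokens - num_cached_tokens) // block_size
assert new_full_blocks == 3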
def build_path(
self,
req_id,