Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[Feature] mm support prefix cache (#4134)
* support mm prefix caching
* update code
* fix mm_hashes
* support encoder cache
* add encoder cache
* update code
* update encoder cache
* fix features bug
* fix worker bug
* support processor cache, need to optimize yet
* refactor multimodal data cache
* update code
* update code
* update v1 scheduler
* update code
* update code
* update codestyle
* support turn off processor cache and encoder cache
* update pre-commit
* fix code
* solve review
* update code
* update code
* update test case
* set processor cache in GiB
* update test case
* support mm prefix caching for qwen model
* fix code style check
* update pre-commit
* fix unit test
* fix unit test
* add ci test case
* fix rescheduled bug
* change text_after_process to prompt_tokens
* fix unit test
* fix chat template
* change model path
* [EP] fix adapter bugs (#4572)
* Update expert_service.py
* Update common_engine.py
* Update expert_service.py
* fix v1 hang bug (#4573)
* fix import image_ops error on some platforms (#4559)
* [CLI]Update parameters in bench latecy cli tool and fix collect-env cli tool (#4558)
* add collect-env
* del files
* [Graph Optimization] Add dy_runnable and introduce cudagraph_switch_threshold for cudagraph mode switching (#4578)
* add new branch for sot
* reorder
* fix batch bug
* [XPU]Moe uses a new operator (#4585)
* [XPU]Moe uses a new operator
* [XPU]Moe uses a new operator
* update response
* [Feature] Support Paddle-OCR (#4396)
* init
* update code
* fix code style & disable thinking
* adapt for common_engine.update_mm_requests_chunk_size
* use 3d rope
* use flash_attn_unpadded
* opt siglip
* update to be compatible with the latest codebase
* fix typo
* optim OCR performance
* fix bug
* fix bug
* fix bug
* fix bug
* normlize name
* modify xpu rope
* revert logger
* fix bug
* fix bug
* fix bug
* support default_v1
* optim performance
* fix bug
---------
Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com>
Co-authored-by: zhangyue66 <zhangyue66@baidu.com>
* [DataProcessor] add reasoning_tokens into usage info (#4520)
* add reasoning_tokens into usage info initial commit
* add unit tests
* modify unit test
* modify and add unit tests
* fix unit test
* move steam usage to processor
* modify processor
* modify test_logprobs
* modify test_logprobs.py
* modify stream reasoning tokens accumulation
* fix unit test
* perf: Optimize task queue communication from engine to worker (#4531)
* perf: Optimize task queue communication from engine to worker
* perf: get_tasks to numpy
* perf: get_tasks remove to_numpy
* fix: request & replace ENV
* remove test_e2w_perf.py
* fix code style
---------
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
* Clean up ports after processing results (#4587)
* [CI] Add /re-run command in PR comments to restart failed CI workflows (#4593)
* [Others] api server exits when worker process is dead (#3271)
* [fix] fix terminal hangs when worker process is dead
* [chore] change sleep time of monitor
* [chore] remove redundant comments
* update docs
---------
Co-authored-by: ApplEOFDiscord <wwy640130@163.com>
Co-authored-by: ApplEOFDiscord <31272106+ApplEOFDiscord@users.noreply.github.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Co-authored-by: yinwei <yinwei_hust@163.com>
Co-authored-by: JYChen <zoooo0820@qq.com>
Co-authored-by: qwes5s5 <45442318+qwes5s5@users.noreply.github.com>
Co-authored-by: Ryan <zihaohuang@aliyun.com>
Co-authored-by: yyssys <atyangshuang@foxmail.com>
Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com>
Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com>
Co-authored-by: zhangyue66 <zhangyue66@baidu.com>
Co-authored-by: kxz2002 <115912648+kxz2002@users.noreply.github.com>
Co-authored-by: SunLei <sunlei5788@gmail.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Zhang Yulong <35552275+ZhangYulongg@users.noreply.github.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: 李泳桦 <39643373+liyonghua0910@users.noreply.github.com>
@@ -18,7 +18,7 @@ from enum import Enum

from fastdeploy.utils import get_logger

-logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
+logger = get_logger("prefix_cache_manager", "cache_manager.log")


DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = {
@@ -17,7 +17,7 @@
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger

-logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
+logger = get_logger("prefix_cache_manager", "cache_manager.log")


class CacheMetrics:
fastdeploy/cache_manager/multimodal_cache_manager.py (new file, 163 lines added)
@@ -0,0 +1,163 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import pickle
import threading
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Tuple

import numpy as np
import zmq

from fastdeploy import envs
from fastdeploy.engine.request import ImagePosition
from fastdeploy.utils import get_logger

logger = get_logger("prefix_cache_manager", "cache_manager.log")


class MultimodalLRUCache(ABC):
    """
    General LRU cache for multimodal data
    """

    def __init__(self, max_cache_size):
        self.cache = OrderedDict()
        self.current_cache_size = 0
        self.max_cache_size = max_cache_size

    def apply_cache(self, mm_hashes: list[str], mm_items: list[Any]) -> list[str]:
        """
        apply data cache, return evicted data
        """
        assert len(mm_hashes) == len(mm_items), "mm_hashes and mm_items should have same length"

        evicted_hashes = []
        for idx in range(len(mm_hashes)):
            if mm_hashes[idx] in self.cache:
                self.cache.move_to_end(mm_hashes[idx])
            else:
                item_size = self.get_item_size(mm_items[idx])
                if self.current_cache_size + item_size >= self.max_cache_size:
                    if item_size > self.max_cache_size:
                        # cannot be inserted even if we clear all cached data, skip it directly
                        continue
                    needed = item_size - (self.max_cache_size - self.current_cache_size)
                    evicted_hashes.extend(self.evict_cache(needed))
                self.cache[mm_hashes[idx]] = mm_items[idx]
                self.current_cache_size += item_size

        return evicted_hashes

    def evict_cache(self, needed: int) -> list[str]:
        """
        evict data cache with needed size
        """
        reduced_size, evicted_hashes = 0, []
        while reduced_size < needed and len(self.cache):
            mm_hash, mm_item = self.cache.popitem(last=False)
            evicted_hashes.append(mm_hash)
            reduced_size += self.get_item_size(mm_item)
            self.current_cache_size -= self.get_item_size(mm_item)

        return evicted_hashes

    def get_cache(self, mm_hashes: list[str]) -> list[Any]:
        """
        get cached data correspond to given hash values
        """
        mm_items = []
        for mm_hash in mm_hashes:
            if mm_hash not in self.cache:
                mm_items.append(None)
                continue
            mm_items.append(self.cache[mm_hash])

        return mm_items

    def clear_cache(self):
        """
        clear all cached data
        """
        evicted_hashes = list(self.cache.keys())
        self.cache.clear()
        self.current_cache_size = 0

        return evicted_hashes

    @abstractmethod
    def get_item_size(self, item: Any) -> int:
        raise NotImplementedError("Subclasses must define how to get size of an item")


class EncoderCacheManager(MultimodalLRUCache):
    """
    EncoderCacheManager is used to cache image features
    """

    def __init__(self, max_encoder_cache):
        super().__init__(max_encoder_cache)

    def get_item_size(self, item: ImagePosition) -> int:
        return item.length


class ProcessorCacheManager(MultimodalLRUCache):
    """
    ProcessorCacheManager is used to cache processed data
    """

    def __init__(self, max_processor_cache):
        super().__init__(max_processor_cache)

        self.context = zmq.Context()

        self.router = self.context.socket(zmq.ROUTER)
        self.router.setsockopt(zmq.SNDHWM, int(envs.FD_ZMQ_SNDHWM))
        self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
        self.router.setsockopt(zmq.SNDTIMEO, -1)
        self.router.bind("ipc:///dev/shm/processor_cache.ipc")

        self.poller = zmq.Poller()
        self.poller.register(self.router, zmq.POLLIN)

        self.handler_thread = threading.Thread(target=self.cache_request_handler, daemon=True)
        self.handler_thread.start()

    def get_item_size(self, item: Tuple[np.ndarray, dict]) -> int:
        return item[0].nbytes

    def cache_request_handler(self):
        try:
            while True:
                events = dict(self.poller.poll())

                if self.router in events:
                    client, _, content = self.router.recv_multipart()
                    req = pickle.loads(content)

                    if isinstance(req, tuple):
                        # apply cache request, in format of (mm_hashes, mm_items)
                        self.apply_cache(req[0], req[1])
                        logger.info(f"Apply processor cache of mm_hashes: {req[0]}")
                    else:
                        # get cache request
                        resp = self.get_cache(req)
                        logger.info(f"Get processor cache of mm_hashes: {req}")
                        self.router.send_multipart([client, b"", pickle.dumps(resp)])
        except Exception as e:
            logger.error(f"Error happened while handling processor cache request: {e}")
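Note (not part of the diff): cache_request_handler above defines a small pickle-over-ZMQ protocol. A tuple payload (mm_hashes, mm_items) is treated as an "apply cache" request and gets no reply, while any other payload (a list of hashes) is a "get cache" request answered with a pickled list aligned to the requested hashes, None for misses. The worker-side client is not shown in this commit; the following is a minimal sketch of what such a client could look like, assuming a DEALER socket connected to the same IPC endpoint (the class and method names here are hypothetical).

import pickle

import zmq


class ProcessorCacheClient:
    """Hypothetical client for ProcessorCacheManager's ROUTER socket (illustration only)."""

    def __init__(self, endpoint: str = "ipc:///dev/shm/processor_cache.ipc"):
        self.context = zmq.Context()
        self.dealer = self.context.socket(zmq.DEALER)
        self.dealer.connect(endpoint)

    def apply(self, mm_hashes: list, mm_items: list) -> None:
        # A tuple payload is interpreted as "insert into the cache"; the handler
        # sends no reply for it, so this call is fire-and-forget.
        self.dealer.send_multipart([b"", pickle.dumps((mm_hashes, mm_items))])

    def get(self, mm_hashes: list) -> list:
        # A list payload is interpreted as "look up these hashes"; the reply is a
        # pickled list with None for every miss, in the same order as mm_hashes.
        self.dealer.send_multipart([b"", pickle.dumps(mm_hashes)])
        _, payload = self.dealer.recv_multipart()
        return pickle.loads(payload)

The leading empty frame is the DEALER/ROUTER delimiter, which is why the handler unpacks recv_multipart() into (client, _, content).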
@@ -14,8 +14,10 @@
# limitations under the License.
"""

+import hashlib
import heapq
import os
+import pickle
import subprocess
import sys
import threading
@@ -35,7 +37,7 @@ from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTree
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger

-logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
+logger = get_logger("prefix_cache_manager", "cache_manager.log")


class PrefixCacheManager:
@@ -575,31 +577,18 @@ class PrefixCacheManager:
        """
        try:
            req_id = task.request_id
-            block_tables = task.block_tables

            last_node, num_cached_tokens = self.cache_info[req_id]
-            if isinstance(task.prompt_token_ids, np.ndarray):
-                prompt_token_ids = task.prompt_token_ids.tolist()
-            else:
-                prompt_token_ids = task.prompt_token_ids
-            input_ids = prompt_token_ids + task.output_token_ids
-            can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
-            left_input_ids = input_ids[num_cached_tokens:can_cache_computed_tokens]
-            gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :]
            if req_id in self.leaf_req_map[last_node]:  # delete old leaf record, update later
                self.leaf_req_map[last_node].remove(req_id)

            with self.request_release_lock:
-                current_time = time.time()
-                leaf_node = self.build_path(
-                    req_id=req_id,
-                    current_time=current_time,
-                    input_ids=input_ids,
-                    left_input_ids=left_input_ids,
-                    gpu_block_ids=gpu_extra_block_ids,
+                leaf_node = self.mm_build_path(
+                    request=task,
+                    num_computed_tokens=num_computed_tokens,
                    block_size=block_size,
                    last_node=last_node,
-                    reverved_dec_block_num=0,
+                    num_cached_tokens=num_cached_tokens,
                )
                self.req_leaf_map[req_id] = leaf_node
                self.leaf_req_map[leaf_node].add(req_id)
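Note (not part of the diff): both the inline code removed here and mm_build_path (added in a later hunk) round the number of cacheable tokens down to a block boundary, because only full blocks become radix-tree nodes. A small worked sketch of that arithmetic, with made-up numbers:

# Illustrative numbers only; block_size and token counts are invented for this example.
block_size = 64
num_computed_tokens = 300  # tokens whose KV cache has been computed so far
num_cached_tokens = 128  # tokens already covered by the prefix tree (2 full blocks)

# Round down to a block boundary, exactly as in the code above.
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size  # 256

# mm_build_path would then append (256 - 128) // 64 = 2 new block nodes, while the
# trailing 300 - 256 = 44 tokens stay in a partial block that is not cached yet.
new_full_blocks = (can_cache_computed_tokens - num_cached_tokens) // block_size
assert (can_cache_computed_tokens, new_full_blocks) == (256, 2)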
@@ -636,10 +625,9 @@ class PrefixCacheManager:
            prompt_token_ids = task.prompt_token_ids.tolist()
        else:
            prompt_token_ids = task.prompt_token_ids
-        input_ids = prompt_token_ids + task.output_token_ids
        req_id = task.request_id
        logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
-        input_token_num = len(input_ids)
+        input_token_num = len(prompt_token_ids + task.output_token_ids)
        common_block_ids = []
        # 1. match block
        (
@@ -649,7 +637,7 @@
            match_block_node,
            gpu_match_token_num,
            cpu_match_token_num,
-        ) = self.match_block(req_id, input_ids, block_size)
+        ) = self.mm_match_block(task, block_size)

        # update matched node info
        self._update_matched_node_info(req_id, match_block_node, current_time=time.time())
@@ -1145,6 +1133,173 @@ class PrefixCacheManager:
        """
        return hash(tuple(block))

    def get_block_hash_extra_keys(self, request, start_idx, end_idx, mm_idx):
        """
        Retrieves additional hash keys for block identification.

        Args:
            request: The input request object containing the data to be processed.
            start_idx (int): The starting index of the block segment to hash.
            end_idx (int): The ending index of the block segment to hash.
            mm_idx: The multimodal index identifier for specialized content handling.

        Returns:
            mm_idx: next multimodal index
            hash_keys: A list of additional hash keys
        """
        hash_keys = []
        mm_inputs = request.multimodal_inputs
        if (
            mm_inputs is None
            or "mm_positions" not in mm_inputs
            or "mm_hashes" not in mm_inputs
            or len(mm_inputs["mm_positions"]) == 0
        ):
            return mm_idx, hash_keys

        assert start_idx < end_idx, f"start_idx {start_idx} >= end_idx {end_idx}"
        assert (
            start_idx >= 0 and start_idx < request.num_total_tokens
        ), f"start_idx {start_idx} out of range {request.num_total_tokens}"
        assert (
            end_idx >= 0 and end_idx <= request.num_total_tokens
        ), f"end_idx {end_idx} out of range {request.num_total_tokens}"
        assert len(mm_inputs["mm_positions"]) == len(
            mm_inputs["mm_hashes"]
        ), f"mm_positions {len(mm_inputs['mm_positions'])} != mm_hashes {len(mm_inputs['mm_hashes'])}"
        assert mm_idx >= 0 and mm_idx < len(
            mm_inputs["mm_hashes"]
        ), f"mm_idx {mm_idx} out of range {len(mm_inputs['mm_hashes'])}"

        if mm_inputs["mm_positions"][-1].offset + mm_inputs["mm_positions"][-1].length < start_idx:
            # non images in current block
            return mm_idx, hash_keys

        for img_idx in range(mm_idx, len(mm_inputs["mm_positions"])):
            image_offset = mm_inputs["mm_positions"][img_idx].offset
            image_length = mm_inputs["mm_positions"][img_idx].length

            if image_offset + image_length < start_idx:
                # image before block
                continue
            elif image_offset >= end_idx:
                # image after block
                return img_idx, hash_keys
            elif image_offset + image_length > end_idx:
                hash_keys.append(mm_inputs["mm_hashes"][img_idx])
                return img_idx, hash_keys
            else:
                hash_keys.append(mm_inputs["mm_hashes"][img_idx])
        return len(mm_inputs["mm_positions"]) - 1, hash_keys

    def hash_block_features(self, input_ids, extra_keys: list = []):
        """
        calculate hash value of a block with additional keys

        Args:
            input_ids: Input token IDs
            extra_keys: Additional keys for block identification
        """
        return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()

    def mm_match_block(self, request, block_size):
        """
        Match and retrieve cached blocks for multimodal requests using a radix tree structure.

        Args:
            request: The multimodal request object containing prompt and output token IDs.
            block_size (int): The size of each token block for matching and processing.

        Returns:
            tuple: A tuple containing:
                - match_gpu_block_ids (list): List of block IDs matched in GPU cache
                - match_cpu_block_ids (list): List of block IDs matched in CPU cache
                - swap_node_ids (list): List of node IDs scheduled for GPU-CPU swapping
                - current_match_node: The last matched node in the radix tree traversal
                - gpu_match_token_num (int): Total number of tokens matched in GPU cache
                - cpu_match_token_num (int): Total number of tokens matched in CPU cache
        """
        if isinstance(request.prompt_token_ids, np.ndarray):
            prompt_token_ids = request.prompt_token_ids.tolist()
        else:
            prompt_token_ids = request.prompt_token_ids
        input_ids = prompt_token_ids + request.output_token_ids
        total_token_num = len(input_ids)
        current_match_node = self.radix_tree_root  # start matching from the root node
        match_gpu_block_ids = []
        match_cpu_block_ids = []
        match_node_ids = []
        mm_idx = 0
        match_token_num = 0
        cpu_match_token_num = 0
        gpu_match_token_num = 0
        swap_node_ids = []
        matche_nodes = []
        has_modified_gpu_lru_leaf_heap = False
        has_modified_cpu_lru_leaf_heap = False

        with self.cache_status_lock:
            while match_token_num < total_token_num:
                token_block = input_ids[match_token_num : match_token_num + block_size]
                token_num = len(token_block)
                if token_num != block_size:
                    break
                mm_idx, extra_keys = self.get_block_hash_extra_keys(
                    request=request,
                    start_idx=match_token_num,
                    end_idx=match_token_num + block_size,
                    mm_idx=mm_idx,
                )
                hash_value = self.hash_block_features(token_block, extra_keys)
                if hash_value in current_match_node.children:
                    child = current_match_node.children[hash_value]
                    matche_nodes.append(child)
                    match_node_ids.append(child.node_id)
                    if child in self.gpu_lru_leaf_set:
                        self.gpu_lru_leaf_set.remove(child)
                        self.gpu_lru_leaf_heap.remove(child)
                        has_modified_gpu_lru_leaf_heap = True
                    elif child in self.cpu_lru_leaf_set:
                        self.cpu_lru_leaf_set.remove(child)
                        self.cpu_lru_leaf_heap.remove(child)
                        has_modified_cpu_lru_leaf_heap = True
                    if child.has_in_gpu:
                        match_gpu_block_ids.append(child.block_id)
                        gpu_match_token_num += block_size
                    else:
                        if child.cache_status == CacheStatus.SWAP2CPU:
                            logger.info(
                                f"match_block: req_id {request.request_id} matched node"
                                + f" {child.node_id} which is being SWAP2CPU"
                            )
                            child.cache_status = CacheStatus.GPU
                            match_gpu_block_ids.append(child.block_id)
                            gpu_match_token_num += block_size
                        elif child.cache_status == CacheStatus.CPU:
                            child.cache_status = CacheStatus.SWAP2GPU
                            match_cpu_block_ids.append(child.block_id)
                            cpu_match_token_num += block_size
                            swap_node_ids.append(child.node_id)
                    match_token_num = match_token_num + block_size
                    current_match_node = child
                else:
                    break

            if has_modified_gpu_lru_leaf_heap:
                heapq.heapify(self.gpu_lru_leaf_heap)
            if has_modified_cpu_lru_leaf_heap:
                heapq.heapify(self.cpu_lru_leaf_heap)

        logger.info(f"match_block: req_id {request.request_id} matched nodes: {match_node_ids}")
        return (
            match_gpu_block_ids,
            match_cpu_block_ids,
            swap_node_ids,
            current_match_node,
            gpu_match_token_num,
            cpu_match_token_num,
        )

    def match_block(self, req_id, input_ids, block_size):
        """
        Args:
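Note (not part of the diff): the key idea in mm_match_block and get_block_hash_extra_keys is that a block's cache key is derived from its token IDs plus the mm_hashes of any images overlapping that block, so two requests with identical placeholder tokens but different images cannot share prefix-cache entries. A standalone sketch of that hashing scheme, mirroring hash_block_features with made-up data:

import hashlib
import pickle


def hash_block_features(input_ids, extra_keys=[]):
    # Same scheme as the method above: hash the token block together with the
    # hashes of the multimodal items that overlap it.
    return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()


# Two requests whose prompts contain the same placeholder token block but different images.
token_block = [151655] * 64  # made-up image placeholder token ID repeated over one block
key_a = hash_block_features(token_block, ["image_hash_aaa"])
key_b = hash_block_features(token_block, ["image_hash_bbb"])
key_text_only = hash_block_features(token_block)

# Identical tokens, different extra keys -> different radix-tree children, no false sharing.
assert key_a != key_b and key_a != key_text_only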
@@ -1241,6 +1396,86 @@ class PrefixCacheManager:
            node.req_id_set.add(req_id)
            node = node.parent

    def mm_build_path(self, request, num_computed_tokens, block_size, last_node, num_cached_tokens):
        """
        Constructs a caching path in radix tree for multimodal requests by processing computed tokens.

        Args:
            request: The inference request object containing:
                - prompt_token_ids: Original input tokens (List[int] or np.ndarray)
                - output_token_ids: Generated tokens (List[int])
                - mm_positions: Optional image positions for multimodal content
            num_computed_tokens: Total tokens processed so far (cached + newly computed)
            block_size: Fixed size of token blocks (must match cache configuration)
            last_node: The deepest existing BlockNode in the radix tree for this request
            num_cached_tokens: Number of tokens already cached

        Returns:
            BlockNode: The new deepest node in the constructed path
        """
        if isinstance(request.prompt_token_ids, np.ndarray):
            prompt_token_ids = request.prompt_token_ids.tolist()
        else:
            prompt_token_ids = request.prompt_token_ids
        input_ids = prompt_token_ids + request.output_token_ids
        can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
        if num_cached_tokens == can_cache_computed_tokens:
            return last_node

        mm_idx = 0
        node = last_node
        unique_node_ids = []
        new_last_node = last_node
        has_unfilled_block = False
        current_time = time.time()

        input_hash_value = self.hash_block_features(input_ids)
        gpu_block_ids = request.block_tables[num_cached_tokens // block_size :].copy()
        for i in range(num_cached_tokens, can_cache_computed_tokens, block_size):
            current_block = input_ids[i : i + block_size]
            current_block_size = len(current_block)  # the last block may not be full
            if current_block_size != block_size:
                has_unfilled_block = True
            else:
                mm_idx, extra_keys = self.get_block_hash_extra_keys(
                    request=request,
                    start_idx=i,
                    end_idx=i + block_size,
                    mm_idx=mm_idx,
                )
                hash_value = self.hash_block_features(current_block, extra_keys)
                allocated_block_id = gpu_block_ids.pop(0)
                node_id = self.node_id_pool.pop()
                unique_node_ids.append(node_id)
                new_last_node = BlockNode(
                    node_id,
                    input_ids,
                    input_hash_value,
                    node.depth + 1,
                    allocated_block_id,
                    current_block_size,
                    hash_value,
                    current_time,
                    parent=node,
                    shared_count=1,
                    reverved_dec_block_ids=[],
                )
                new_last_node.req_id_set.add(request.request_id)
                self.node_map[node_id] = new_last_node
                node.children[hash_value] = new_last_node
                node = new_last_node

        reverved_dec_block_ids = []
        if has_unfilled_block is True:
            reverved_dec_block_ids.append(gpu_block_ids.pop(0))

        if new_last_node == self.radix_tree_root:
            self.unfilled_req_block_map[request.request_id] = reverved_dec_block_ids
        else:
            new_last_node.reverved_dec_block_ids.extend(reverved_dec_block_ids)
        logger.info(f"build_path: allocate unique node ids {unique_node_ids} for req_id {request.request_id}")
        return new_last_node

    def build_path(
        self,
        req_id,