mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[Feature] support mm disable_chunked (#4803)
* support mm disable_chunked
* update code
* update code
* update code
@@ -598,6 +598,21 @@ class PrefixCacheManager:
             logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
             raise e

+    def is_chunked_mm_input(self, mm_inputs, matched_token_num):
+        """
+        check if mm_inputs is chunked
+        """
+        if mm_inputs is None or "mm_positions" not in mm_inputs or len(mm_inputs["mm_positions"]) == 0:
+            return False, 0
+
+        for idx in range(len(mm_inputs["mm_positions"])):
+            position = mm_inputs["mm_positions"][idx]
+            if position.offset < matched_token_num < position.offset + position.length:
+                return True, idx
+            elif matched_token_num < position.offset:
+                break
+        return False, 0
+
     def request_match_blocks(self, task, block_size, *args):
         """
         get match blocks info for a task.
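For orientation, a minimal sketch of the boundary semantics above. Assumptions: the make_prefix_cache_manager helper from the new test file at the bottom of this diff is importable from the repo root, and SimpleNamespace stands in for ImagePosition since only .offset and .length are read.

from types import SimpleNamespace

from tests.v1.cache_manager.test_revert_blocks import make_prefix_cache_manager

cm = make_prefix_cache_manager(max_num_seqs=1, enable_mm=True)
mm_inputs = {
    "mm_positions": [SimpleNamespace(offset=5, length=10), SimpleNamespace(offset=20, length=10)]
}

assert cm.is_chunked_mm_input(mm_inputs, 8) == (True, 0)    # matched prefix ends strictly inside the first image span
assert cm.is_chunked_mm_input(mm_inputs, 25) == (True, 1)   # ends strictly inside the second image span
assert cm.is_chunked_mm_input(mm_inputs, 15) == (False, 0)  # ends exactly at a span boundary: not chunked
assert cm.is_chunked_mm_input(mm_inputs, 3) == (False, 0)   # ends before any image: not chunked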
@@ -617,9 +632,12 @@ class PrefixCacheManager:
         """
         with self.request_release_lock:
             try:
-                hit_info = {}
-                hit_info["gpu_cache_blocks"] = 0
-                hit_info["cpu_cache_blocks"] = 0
+                hit_info = {
+                    "gpu_cache_blocks": 0,
+                    "cpu_cache_blocks": 0,
+                    "gpu_match_token_num": 0,
+                    "cpu_match_token_num": 0,
+                }
                 self.metrics.req_count += 1
                 if isinstance(task.prompt_token_ids, np.ndarray):
                     prompt_token_ids = task.prompt_token_ids.tolist()
@@ -673,8 +691,10 @@ class PrefixCacheManager:
                     gpu_match_token_num,
                     input_token_num,
                 )
-                hit_info["gpu_cache_blocks"] = gpu_match_token_num // block_size
-                hit_info["cpu_cache_blocks"] = cpu_match_token_num // block_size
+                hit_info["gpu_cache_blocks"] = len(match_gpu_block_ids)
+                hit_info["cpu_cache_blocks"] = len(match_cpu_block_ids)
+                hit_info["gpu_match_token_num"] = gpu_match_token_num
+                hit_info["cpu_match_token_num"] = cpu_match_token_num
                 self.metrics._update_history_hit_metrics()
                 if self.metrics.req_count % 10000 == 0:
                     self.metrics.reset_metrics()
@@ -685,8 +705,8 @@ class PrefixCacheManager:
                 self.req_leaf_map[req_id] = match_block_node
                 self.leaf_req_map[match_block_node].add(req_id)
                 # record request cache info
-                self.cache_info[req_id] = (match_block_node, matched_token_num)
-                task.cached_block_num = matched_token_num // block_size
+                self.cache_info[req_id] = (match_block_node, len(common_block_ids) * block_size)
+                task.cached_block_num = len(common_block_ids)
                 return common_block_ids, matched_token_num, hit_info
             except Exception as e:
                 logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
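Once a revert can trim the matched prefix, the matched token counts are no longer guaranteed to be multiples of block_size, which is why hit_info now carries block counts and exact token counts side by side. A toy illustration of the accounting (values borrowed from the partial-block test at the end of this diff; not code from the commit):

block_size = 64
match_gpu_block_ids = [0, 1]   # two GPU blocks still referenced after a revert
gpu_match_token_num = 120      # but only 120 of their 128 tokens count as matched

hit_info = {
    "gpu_cache_blocks": len(match_gpu_block_ids),  # 2
    "cpu_cache_blocks": 0,
    "gpu_match_token_num": gpu_match_token_num,    # 120, not 2 * block_size
    "cpu_match_token_num": 0,
}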
@@ -1202,6 +1222,64 @@ class PrefixCacheManager:
         """
         return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()

+    def _revert_match_blocks(
+        self,
+        request,
+        matched_token_num: int,
+        block_size: int,
+        chunk_idx: int,
+        match_node_ids: list,
+        matche_nodes: list,
+        match_gpu_block_ids: list,
+        match_cpu_block_ids: list,
+        gpu_match_token_num: int,
+        cpu_match_token_num: int,
+        swap_node_ids: list,
+    ):
+        position = request.multimodal_inputs["mm_positions"][chunk_idx]
+        revert_tokens = matched_token_num - position.offset
+        match_block_ids = [node.block_id for node in matche_nodes]
+        logger.warning(
+            f"match_block: req_id {request.request_id} revert tokens: {revert_tokens} from matched nodes: {match_block_ids}"
+        )
+        while revert_tokens >= block_size:
+            if len(matche_nodes) == 0:
+                logger.error(f"req_id {request.request_id} revert nodes error, tokens: {revert_tokens}")
+                break
+            revert_tokens -= block_size
+            revert_block = matche_nodes.pop()
+            revert_block_id = revert_block.block_id
+            if revert_block_id in match_gpu_block_ids:
+                match_gpu_block_ids.remove(revert_block_id)
+                match_node_ids.remove(revert_block.node_id)
+                gpu_match_token_num -= block_size
+            elif revert_block_id in match_cpu_block_ids:
+                match_cpu_block_ids.remove(revert_block_id)
+                match_node_ids.remove(revert_block.node_id)
+                cpu_match_token_num -= block_size
+            else:
+                logger.error(
+                    f"req_id {request.request_id} revert nodes error, nodes: {revert_block_id}, "
+                    f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}"
+                )
+                break
+            if revert_block_id in swap_node_ids:
+                swap_node_ids.remove(revert_block_id)
+
+        if revert_tokens > 0:
+            last_block_id = matche_nodes[-1].block_id
+            if last_block_id in match_gpu_block_ids:
+                gpu_match_token_num -= revert_tokens
+            elif last_block_id in match_cpu_block_ids:
+                cpu_match_token_num -= revert_tokens
+            else:
+                logger.error(
+                    f"req_id {request.request_id} revert nodes error, revert_tokens: {revert_tokens}, nodes: {last_block_id}, "
+                    f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}"
+                )
+        current_node = self.radix_tree_root if len(matche_nodes) == 0 else matche_nodes[-1]
+        return gpu_match_token_num, cpu_match_token_num, current_node
+
     def mm_match_block(self, request, block_size):
         """
         Match and retrieve cached blocks for multimodal requests using a radix tree structure.
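The revert walks backwards from the tail of the matched path: every whole block past the image start is dropped, and a final partial block keeps its slot but has its token count trimmed. A worked sketch of the arithmetic (mirroring test_revert_partial_block in the new test file; plain numbers, not code from the commit):

block_size = 64
matched_token_num = 20 * block_size        # 1280 matched tokens, all on GPU
image_offset = 120                         # the image starts inside the matched prefix

revert_tokens = matched_token_num - image_offset   # 1160
full_blocks_dropped = revert_tokens // block_size  # 18 blocks popped off the matched path
partial_tail = revert_tokens % block_size          # 8 tokens trimmed from the last kept block

blocks_kept = 20 - full_blocks_dropped                          # 2
gpu_match_token_num = blocks_kept * block_size - partial_tail   # 120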
@@ -1290,6 +1368,28 @@ class PrefixCacheManager:
         if has_modified_cpu_lru_leaf_heap:
             heapq.heapify(self.cpu_lru_leaf_heap)

+        if self.cache_config.disable_chunked_mm_input:
+            matched_token_num = gpu_match_token_num + cpu_match_token_num
+            is_chunked, chunk_idx = self.is_chunked_mm_input(request.multimodal_inputs, matched_token_num)
+            if is_chunked:
+                (
+                    gpu_match_token_num,
+                    cpu_match_token_num,
+                    current_match_node,
+                ) = self._revert_match_blocks(
+                    request=request,
+                    matched_token_num=matched_token_num,
+                    block_size=block_size,
+                    chunk_idx=chunk_idx,
+                    match_node_ids=match_node_ids,
+                    matche_nodes=matche_nodes,
+                    match_gpu_block_ids=match_gpu_block_ids,
+                    match_cpu_block_ids=match_cpu_block_ids,
+                    gpu_match_token_num=gpu_match_token_num,
+                    cpu_match_token_num=cpu_match_token_num,
+                    swap_node_ids=swap_node_ids,
+                )
+
         logger.info(f"match_block: req_id {request.request_id} matched nodes: {match_node_ids}")
         return (
             match_gpu_block_ids,
@@ -1211,6 +1211,7 @@ class CacheConfig:
         self.swap_space = None
         self.max_encoder_cache = None
         self.max_processor_cache = None
+        self.disable_chunked_mm_input = False
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
@@ -314,6 +314,10 @@ class EngineArgs:
     """
     additional decode block num
     """
+    disable_chunked_mm_input: bool = False
+    """
+    Disable chunked_mm_input for multi-model inference.
+    """

     scheduler_name: str = "local"
     """
@@ -936,6 +940,13 @@ class EngineArgs:
             help="ports for rdma communication.",
         )

+        perf_group.add_argument(
+            "--disable-chunked-mm-input",
+            action="store_true",
+            default=EngineArgs.disable_chunked_mm_input,
+            help="Disable chunked mm input.",
+        )
+
         # Router parameters group
         router_group = parser.add_argument_group("Router")
         router_group.add_argument(
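The new flag reaches the cache manager through the generic attribute copy in CacheConfig.__init__ shown above. A minimal propagation sketch, assuming EngineArgs' remaining fields keep their defaults as in the test helper below:

from dataclasses import asdict

from fastdeploy.config import CacheConfig
from fastdeploy.engine.args_utils import EngineArgs

args = asdict(EngineArgs(disable_chunked_mm_input=True))
cache_cfg = CacheConfig(args)
assert cache_cfg.disable_chunked_mm_input is True  # picked up by the hasattr/setattr loop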
@@ -771,8 +771,8 @@ class ResourceManagerV1(ResourceManager):
                 )

                 request.num_cached_tokens = matched_token_num
-                request.gpu_cache_token_num = hit_info["gpu_cache_blocks"] * self.config.cache_config.block_size
-                request.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.config.cache_config.block_size
+                request.gpu_cache_token_num = hit_info["gpu_match_token_num"]
+                request.cpu_cache_token_num = hit_info["cpu_match_token_num"]
                 request.cache_info = (matched_block_num, no_cache_block_num)
                 request.block_tables = common_block_ids
                 request.skip_allocate = False
tests/v1/cache_manager/test_revert_blocks.py (new file, 300 lines)
@@ -0,0 +1,300 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from dataclasses import asdict
from types import SimpleNamespace

from fastdeploy.cache_manager.cache_data import BlockNode
from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.request import ImagePosition, Request
from fastdeploy.scheduler import SchedulerConfig


def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_override=100, max_num_batched_tokens=3200):
    engine_args = EngineArgs(
        max_num_seqs=max_num_seqs,
        num_gpu_blocks_override=num_gpu_blocks_override,
        max_num_batched_tokens=max_num_batched_tokens,
    )
    args = asdict(engine_args)
    cache_cfg = CacheConfig(args)
    model_cfg = SimpleNamespace(enable_mm=enable_mm, max_model_len=8192)
    speculative_cfg = SimpleNamespace(method=None)
    model_cfg.print = print
    cache_cfg.bytes_per_layer_per_block = 1
    parallel_cfg = ParallelConfig(args)
    scheduler_cfg = SchedulerConfig(args)
    graph_opt_cfg = engine_args.create_graph_optimization_config()
    fd_config = FDConfig(
        model_config=model_cfg,
        cache_config=cache_cfg,
        parallel_config=parallel_cfg,
        graph_opt_config=graph_opt_cfg,
        speculative_config=speculative_cfg,
        scheduler_config=scheduler_cfg,
    )
    return PrefixCacheManager(config=fd_config, tensor_parallel_size=8, splitwise_role="mixed")

class TestIsChunkedMMInput(unittest.TestCase):
    def setUp(self):
        self.cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=True, num_gpu_blocks_override=100)

    def test_is_chunked_mm_input_none_input(self):
        result, idx = self.cache_manager.is_chunked_mm_input(None, 10)
        self.assertFalse(result)
        self.assertEqual(idx, 0)

    def test_is_chunked_mm_input_no_mm_positions(self):
        mm_inputs = {"other_field": "value"}
        result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 10)
        self.assertFalse(result)
        self.assertEqual(idx, 0)

    def test_is_chunked_mm_input_empty_positions(self):
        mm_inputs = {"mm_positions": []}
        result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 10)
        self.assertFalse(result)
        self.assertEqual(idx, 0)

    def test_is_chunked_mm_input_matched_in_chunk(self):
        mm_inputs = {
            "mm_positions": [
                ImagePosition(offset=5, length=10),
                ImagePosition(offset=20, length=10),
            ]
        }
        result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 8)
        self.assertTrue(result)
        self.assertEqual(idx, 0)

    def test_is_chunked_mm_input_matched_in_second_chunk(self):
        mm_inputs = {
            "mm_positions": [
                ImagePosition(offset=5, length=10),
                ImagePosition(offset=20, length=10),
            ]
        }
        result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 25)
        self.assertTrue(result)
        self.assertEqual(idx, 1)

    def test_is_chunked_mm_input_before_first_chunk(self):
        mm_inputs = {
            "mm_positions": [
                ImagePosition(offset=5, length=10),
                ImagePosition(offset=20, length=10),
            ]
        }
        result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 3)
        self.assertFalse(result)
        self.assertEqual(idx, 0)

    def test_is_chunked_mm_input_after_last_chunk(self):
        mm_inputs = {
            "mm_positions": [
                ImagePosition(offset=5, length=10),
                ImagePosition(offset=20, length=10),
            ]
        }
        result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 35)
        self.assertFalse(result)
        self.assertEqual(idx, 0)

class TestRevertMatchBlocks(unittest.TestCase):
    def setUp(self):
        self.block_size = 64
        self.cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=True, num_gpu_blocks_override=100)

    def make_match_blocks(self, gpu_block_num, cpu_block_num):
        block_num = gpu_block_num + cpu_block_num
        matched_token_num = block_num * self.block_size
        match_node_ids = []
        matche_nodes = []
        match_gpu_block_ids = []
        match_cpu_block_ids = []
        for idx in range(block_num):
            node_id = idx + 10
            block = BlockNode(node_id, [], 0, 0, idx, 0, None, None, None)
            match_node_ids.append(node_id)
            matche_nodes.append(block)
            match_gpu_block_ids.append(idx)

        for _ in range(cpu_block_num):
            match_cpu_block_ids.append(match_gpu_block_ids.pop())

        gpu_match_token_num = len(match_gpu_block_ids) * self.block_size
        cpu_match_token_num = len(match_cpu_block_ids) * self.block_size
        return (
            matched_token_num,
            match_node_ids,
            matche_nodes,
            match_gpu_block_ids,
            match_cpu_block_ids,
            gpu_match_token_num,
            cpu_match_token_num,
        )

    def test_revert_full_blocks(self):
        # Setup test data
        multimodal_inputs = {
            "mm_positions": [ImagePosition(offset=0, length=1200)],
            "mm_hashes": ["image1"],
        }
        req_dict = {
            "request_id": "req1",
            "prompt_token_ids": [-1] * 1200 + [2] * 120,
            "prompt_token_ids_len": 1320,
            "multimodal_inputs": multimodal_inputs,
        }

        (
            matched_token_num,
            match_node_ids,
            matche_nodes,
            match_gpu_block_ids,
            match_cpu_block_ids,
            gpu_match_token_num,
            cpu_match_token_num,
        ) = self.make_match_blocks(gpu_block_num=2, cpu_block_num=0)

        # Call method
        (
            gpu_match_token_num,
            cpu_match_token_num,
            current_match_node,
        ) = self.cache_manager._revert_match_blocks(
            request=Request.from_dict(req_dict),
            matched_token_num=matched_token_num,
            block_size=self.block_size,
            chunk_idx=0,
            match_node_ids=match_node_ids,
            matche_nodes=matche_nodes,
            match_gpu_block_ids=match_gpu_block_ids,
            match_cpu_block_ids=match_cpu_block_ids,
            gpu_match_token_num=gpu_match_token_num,
            cpu_match_token_num=cpu_match_token_num,
            swap_node_ids=[],
        )

        # Assertions
        self.assertEqual(gpu_match_token_num, 0)
        self.assertEqual(cpu_match_token_num, 0)
        self.assertEqual(len(match_node_ids), 0)
        self.assertEqual(len(match_gpu_block_ids), 0)

    def test_revert_partial_block(self):
        # Setup test data
        multimodal_inputs = {
            "mm_positions": [ImagePosition(offset=120, length=1200)],
            "mm_hashes": ["image1"],
        }
        req_dict = {
            "request_id": "req1",
            "prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120,
            "prompt_token_ids_len": 1440,
            "multimodal_inputs": multimodal_inputs,
        }

        (
            matched_token_num,
            match_node_ids,
            matche_nodes,
            match_gpu_block_ids,
            match_cpu_block_ids,
            gpu_match_token_num,
            cpu_match_token_num,
        ) = self.make_match_blocks(gpu_block_num=20, cpu_block_num=0)

        # Call method
        (
            gpu_match_token_num,
            cpu_match_token_num,
            current_match_node,
        ) = self.cache_manager._revert_match_blocks(
            request=Request.from_dict(req_dict),
            matched_token_num=matched_token_num,
            block_size=self.block_size,
            chunk_idx=0,
            match_node_ids=match_node_ids,
            matche_nodes=matche_nodes,
            match_gpu_block_ids=match_gpu_block_ids,
            match_cpu_block_ids=match_cpu_block_ids,
            gpu_match_token_num=gpu_match_token_num,
            cpu_match_token_num=cpu_match_token_num,
            swap_node_ids=[],
        )

        # Assertions
        self.assertEqual(gpu_match_token_num, 120)
        self.assertEqual(cpu_match_token_num, 0)
        self.assertEqual(len(match_node_ids), 2)
        self.assertEqual(len(match_gpu_block_ids), 2)

    def test_revert_with_cpu_blocks(self):
        # Setup test data
        multimodal_inputs = {
            "mm_positions": [ImagePosition(offset=120, length=1200), ImagePosition(offset=1440, length=420)],
            "mm_hashes": ["image1", "image2"],
        }
        req_dict = {
            "request_id": "req1",
            "prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120 + [-1] * 420,
            "prompt_token_ids_len": 1860,
            "multimodal_inputs": multimodal_inputs,
        }

        (
            matched_token_num,
            match_node_ids,
            matche_nodes,
            match_gpu_block_ids,
            match_cpu_block_ids,
            gpu_match_token_num,
            cpu_match_token_num,
        ) = self.make_match_blocks(gpu_block_num=22, cpu_block_num=6)

        # Call method
        (
            gpu_match_token_num,
            cpu_match_token_num,
            current_match_node,
        ) = self.cache_manager._revert_match_blocks(
            request=Request.from_dict(req_dict),
            matched_token_num=matched_token_num,
            block_size=self.block_size,
            chunk_idx=1,
            match_node_ids=match_node_ids,
            matche_nodes=matche_nodes,
            match_gpu_block_ids=match_gpu_block_ids,
            match_cpu_block_ids=match_cpu_block_ids,
            gpu_match_token_num=gpu_match_token_num,
            cpu_match_token_num=cpu_match_token_num,
            swap_node_ids=[],
        )

        # Assertions
        self.assertEqual(gpu_match_token_num, 22 * self.block_size)
        self.assertEqual(cpu_match_token_num, 32)
        self.assertEqual(len(match_node_ids), 23)
        self.assertEqual(len(match_gpu_block_ids), 22)
        self.assertEqual(len(match_cpu_block_ids), 1)


if __name__ == "__main__":
    unittest.main()
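For reference, the expectations in test_revert_with_cpu_blocks follow from the same arithmetic as the GPU-only cases, with the tail of the matched path living on CPU because make_match_blocks moves the last cpu_block_num block ids to the CPU list. A worked sketch (plain numbers, not part of the test file):

block_size = 64
gpu_blocks, cpu_blocks = 22, 6
matched_token_num = (gpu_blocks + cpu_blocks) * block_size  # 1792
image_offset = 1440                                         # second image, chunk_idx=1

revert_tokens = matched_token_num - image_offset   # 352
full_blocks_dropped = revert_tokens // block_size  # 5, all popped from the CPU tail
partial_tail = revert_tokens % block_size          # 32, trimmed from the last remaining CPU block

# expected results asserted above:
#   gpu_match_token_num = 22 * 64 = 1408 (untouched)
#   cpu_match_token_num = 6 * 64 - 5 * 64 - 32 = 32
#   match_node_ids: 28 - 5 = 23 nodes remain (22 GPU block ids, 1 CPU block id)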