mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] support mm disable_chunked (#4803)
* support mm disable_chunked * update code * update code * update code
This commit is contained in:
@@ -598,6 +598,21 @@ class PrefixCacheManager:
|
||||
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
|
||||
def is_chunked_mm_input(self, mm_inputs, matched_token_num):
|
||||
"""
|
||||
check if mm_inputs is chunked
|
||||
"""
|
||||
if mm_inputs is None or "mm_positions" not in mm_inputs or len(mm_inputs["mm_positions"]) == 0:
|
||||
return False, 0
|
||||
|
||||
for idx in range(len(mm_inputs["mm_positions"])):
|
||||
position = mm_inputs["mm_positions"][idx]
|
||||
if position.offset < matched_token_num < position.offset + position.length:
|
||||
return True, idx
|
||||
elif matched_token_num < position.offset:
|
||||
break
|
||||
return False, 0
|
||||
|
||||
def request_match_blocks(self, task, block_size, *args):
|
||||
"""
|
||||
get match blocks info for a task.
|
||||
@@ -617,9 +632,12 @@ class PrefixCacheManager:
|
||||
"""
|
||||
with self.request_release_lock:
|
||||
try:
|
||||
hit_info = {}
|
||||
hit_info["gpu_cache_blocks"] = 0
|
||||
hit_info["cpu_cache_blocks"] = 0
|
||||
hit_info = {
|
||||
"gpu_cache_blocks": 0,
|
||||
"cpu_cache_blocks": 0,
|
||||
"gpu_match_token_num": 0,
|
||||
"cpu_match_token_num": 0,
|
||||
}
|
||||
self.metrics.req_count += 1
|
||||
if isinstance(task.prompt_token_ids, np.ndarray):
|
||||
prompt_token_ids = task.prompt_token_ids.tolist()
|
||||
@@ -673,8 +691,10 @@ class PrefixCacheManager:
|
||||
gpu_match_token_num,
|
||||
input_token_num,
|
||||
)
|
||||
hit_info["gpu_cache_blocks"] = gpu_match_token_num // block_size
|
||||
hit_info["cpu_cache_blocks"] = cpu_match_token_num // block_size
|
||||
hit_info["gpu_cache_blocks"] = len(match_gpu_block_ids)
|
||||
hit_info["cpu_cache_blocks"] = len(match_cpu_block_ids)
|
||||
hit_info["gpu_match_token_num"] = gpu_match_token_num
|
||||
hit_info["cpu_match_token_num"] = cpu_match_token_num
|
||||
self.metrics._update_history_hit_metrics()
|
||||
if self.metrics.req_count % 10000 == 0:
|
||||
self.metrics.reset_metrics()
|
||||
@@ -685,8 +705,8 @@ class PrefixCacheManager:
|
||||
self.req_leaf_map[req_id] = match_block_node
|
||||
self.leaf_req_map[match_block_node].add(req_id)
|
||||
# record request cache info
|
||||
self.cache_info[req_id] = (match_block_node, matched_token_num)
|
||||
task.cached_block_num = matched_token_num // block_size
|
||||
self.cache_info[req_id] = (match_block_node, len(common_block_ids) * block_size)
|
||||
task.cached_block_num = len(common_block_ids)
|
||||
return common_block_ids, matched_token_num, hit_info
|
||||
except Exception as e:
|
||||
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
|
||||
@@ -1202,6 +1222,64 @@ class PrefixCacheManager:
|
||||
"""
|
||||
return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()
|
||||
|
||||
def _revert_match_blocks(
|
||||
self,
|
||||
request,
|
||||
matched_token_num: int,
|
||||
block_size: int,
|
||||
chunk_idx: int,
|
||||
match_node_ids: list,
|
||||
matche_nodes: list,
|
||||
match_gpu_block_ids: list,
|
||||
match_cpu_block_ids: list,
|
||||
gpu_match_token_num: int,
|
||||
cpu_match_token_num: int,
|
||||
swap_node_ids: list,
|
||||
):
|
||||
position = request.multimodal_inputs["mm_positions"][chunk_idx]
|
||||
revert_tokens = matched_token_num - position.offset
|
||||
match_block_ids = [node.block_id for node in matche_nodes]
|
||||
logger.warning(
|
||||
f"match_block: req_id {request.request_id} revert tokens: {revert_tokens} from matched nodes: {match_block_ids}"
|
||||
)
|
||||
while revert_tokens >= block_size:
|
||||
if len(matche_nodes) == 0:
|
||||
logger.error(f"req_id {request.request_id} revert nodes error, tokens: {revert_tokens}")
|
||||
break
|
||||
revert_tokens -= block_size
|
||||
revert_block = matche_nodes.pop()
|
||||
revert_block_id = revert_block.block_id
|
||||
if revert_block_id in match_gpu_block_ids:
|
||||
match_gpu_block_ids.remove(revert_block_id)
|
||||
match_node_ids.remove(revert_block.node_id)
|
||||
gpu_match_token_num -= block_size
|
||||
elif revert_block_id in match_cpu_block_ids:
|
||||
match_cpu_block_ids.remove(revert_block_id)
|
||||
match_node_ids.remove(revert_block.node_id)
|
||||
cpu_match_token_num -= block_size
|
||||
else:
|
||||
logger.error(
|
||||
f"req_id {request.request_id} revert nodes error, nodes: {revert_block_id}, "
|
||||
f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}"
|
||||
)
|
||||
break
|
||||
if revert_block_id in swap_node_ids:
|
||||
swap_node_ids.remove(revert_block_id)
|
||||
|
||||
if revert_tokens > 0:
|
||||
last_block_id = matche_nodes[-1].block_id
|
||||
if last_block_id in match_gpu_block_ids:
|
||||
gpu_match_token_num -= revert_tokens
|
||||
elif last_block_id in match_cpu_block_ids:
|
||||
cpu_match_token_num -= revert_tokens
|
||||
else:
|
||||
logger.error(
|
||||
f"req_id {request.request_id} revert nodes error, revert_tokens: {revert_tokens}, nodes: {last_block_id}, "
|
||||
f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}"
|
||||
)
|
||||
current_node = self.radix_tree_root if len(matche_nodes) == 0 else matche_nodes[-1]
|
||||
return gpu_match_token_num, cpu_match_token_num, current_node
|
||||
|
||||
def mm_match_block(self, request, block_size):
|
||||
"""
|
||||
Match and retrieve cached blocks for multimodal requests using a radix tree structure.
|
||||
@@ -1290,6 +1368,28 @@ class PrefixCacheManager:
|
||||
if has_modified_cpu_lru_leaf_heap:
|
||||
heapq.heapify(self.cpu_lru_leaf_heap)
|
||||
|
||||
if self.cache_config.disable_chunked_mm_input:
|
||||
matched_token_num = gpu_match_token_num + cpu_match_token_num
|
||||
is_chunked, chunk_idx = self.is_chunked_mm_input(request.multimodal_inputs, matched_token_num)
|
||||
if is_chunked:
|
||||
(
|
||||
gpu_match_token_num,
|
||||
cpu_match_token_num,
|
||||
current_match_node,
|
||||
) = self._revert_match_blocks(
|
||||
request=request,
|
||||
matched_token_num=matched_token_num,
|
||||
block_size=block_size,
|
||||
chunk_idx=chunk_idx,
|
||||
match_node_ids=match_node_ids,
|
||||
matche_nodes=matche_nodes,
|
||||
match_gpu_block_ids=match_gpu_block_ids,
|
||||
match_cpu_block_ids=match_cpu_block_ids,
|
||||
gpu_match_token_num=gpu_match_token_num,
|
||||
cpu_match_token_num=cpu_match_token_num,
|
||||
swap_node_ids=swap_node_ids,
|
||||
)
|
||||
|
||||
logger.info(f"match_block: req_id {request.request_id} matched nodes: {match_node_ids}")
|
||||
return (
|
||||
match_gpu_block_ids,
|
||||
|
||||
Reference in New Issue
Block a user