Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
[Feature] block scheduler v1 support prefix caching (#3061)

* block scheduler v1 support prefix cache
* update code
* update code
* fix code bug
* add timeout time

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
@@ -93,6 +93,7 @@ class PrefixCacheManager:
         self.req_leaf_map = {}  # {request_id: leaf node}
         self.leaf_req_map = defaultdict(set)
         self.unfilled_req_block_map = defaultdict(list)
+        self.cache_info = {}

         self.executor_pool = ThreadPoolExecutor(max_workers=1)
         self.free_gpu_executor_pool = ThreadPoolExecutor(max_workers=1)
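The new `self.cache_info` dict carries per-request prefix-cache state between the match and update phases added below. A minimal sketch of an entry's life cycle, assuming a made-up request id and stand-in node objects (only the dict shape comes from this commit):

```python
# Illustrative only: the life cycle of a cache_info entry as used in this commit.
cache_info = {}

match_block_node = object()          # stand-in for the radix-tree node from match_block()
prompt_token_ids = [101, 102, 103]   # made-up prompt token ids

# request_match_blocks() records the matched node and the prompt ids:
cache_info["req-0"] = (match_block_node, prompt_token_ids)

# update_cache_blocks() later reads the entry, extends the radix tree with the
# newly filled blocks, and stores the resulting leaf node:
last_node, input_ids = cache_info["req-0"]
leaf_node = object()                 # stand-in for the node returned by build_path()
cache_info["req-0"] = (leaf_node, input_ids)

# the release path finally drops the entry when the request is released:
cache_info.pop("req-0", None)
```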
@@ -425,6 +426,135 @@ class PrefixCacheManager:

         return gpu_recv_block_ids, gpu_extra_block_ids

+    def get_required_block_num(self, input_token_num, block_size):
+        """
+        get required block num by input token num and block size
+        """
+        return (input_token_num + block_size - 1) // block_size
+
+    def update_cache_blocks(self, task, block_size):
+        """
+        update cache blocks for a task.
+        # TODO(chengyanfu): support async update
+
+        Parameters:
+        - task: Task
+        - block_size: Size per block (in tokens)
+        """
+        try:
+            req_id = task.request_id
+            num_cached_tokens = task.num_cached_tokens
+            block_tables = task.block_tables
+
+            last_node, input_ids = self.cache_info[req_id]
+            left_input_ids = input_ids[num_cached_tokens:]
+            gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :]
+
+            with self.request_release_lock:
+                current_time = time.time()
+                leaf_node = self.build_path(
+                    req_id=req_id,
+                    current_time=current_time,
+                    input_ids=input_ids,
+                    left_input_ids=left_input_ids,
+                    gpu_block_ids=gpu_extra_block_ids,
+                    block_size=block_size,
+                    last_node=last_node,
+                    reverved_dec_block_num=0,
+                )
+                self.req_leaf_map[req_id] = leaf_node
+                self.leaf_req_map[leaf_node].add(req_id)
+                self.cache_info[req_id] = (leaf_node, input_ids)
+        except Exception as e:
+            logger.error(f"update_cache_blocks, error: {type(e)} {e}")
+            raise e
+
+    def request_match_blocks(self, task, block_size, *args):
+        """
+        get match blocks info for a task.
+        This is a synchronous interface. If CPU-to-GPU data transfer occurs,
+        it will block until synchronization completes.
+        Callers requiring asynchronous behavior should invoke this via a thread pool.
+
+        Note: This function may allocate GPU blocks for matched CPU Cache
+
+        Parameters:
+        - task: Task dictionary
+        - block_size: Size per block (in tokens)
+
+        Returns:
+        - common_block_ids: List of matched shared blocks
+        - unique_block_ids: List of exclusively allocated blocks
+        """
+        with self.request_release_lock:
+            try:
+                hit_info = {}
+                hit_info["gpu_cache_blocks"] = 0
+                hit_info["cpu_cache_blocks"] = 0
+                self.metrics.req_count += 1
+                input_ids = task.prompt_token_ids
+                req_id = task.request_id
+                logger.info(f"request_block_ids: start to allocate blocks for req_id {req_id}")
+                input_token_num = len(input_ids)
+                common_block_ids = []
+                # 1. match block
+                (
+                    match_gpu_block_ids,
+                    match_cpu_block_ids,
+                    swap_node_ids,
+                    match_block_node,
+                    gpu_match_token_num,
+                    cpu_match_token_num,
+                ) = self.match_block(req_id, input_ids, block_size)
+
+                # update matched node info
+                self._update_matched_node_info(req_id, match_block_node, current_time=time.time())
+
+                # 2. prepare cache
+                # allocate gpu cache for matched cpu blocks
+                gpu_recv_block_ids = []
+                match_cpu_blocks_num = len(match_cpu_block_ids)
+                if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num):
+                    if match_cpu_blocks_num > 0:
+                        gpu_recv_block_ids = self.allocate_gpu_blocks(match_cpu_blocks_num)
+                    if len(gpu_recv_block_ids) > 0:
+                        self._prepare_cpu_cache(
+                            req_id=req_id,
+                            swap_node_ids=swap_node_ids,
+                            gpu_recv_block_ids=gpu_recv_block_ids,
+                            match_cpu_block_ids=match_cpu_block_ids,
+                            cpu_recv_block_ids=[],
+                        )
+                else:
+                    raise Exception("Not enough GPU memory to allocate cache for matched CPU Cache")
+
+                # record request cache info
+                self.cache_info[req_id] = (match_block_node, input_ids)
+
+                # 3. update metrics
+                matched_token_num = gpu_match_token_num + cpu_match_token_num
+                common_block_ids = match_gpu_block_ids + gpu_recv_block_ids
+                if matched_token_num > 0:
+                    self.metrics.hit_req_count += 1
+                    self.metrics.calculate_hit_metrics(
+                        req_id,
+                        cpu_match_token_num,
+                        gpu_match_token_num,
+                        input_token_num,
+                    )
+                hit_info["gpu_cache_blocks"] = gpu_match_token_num // block_size
+                hit_info["cpu_cache_blocks"] = cpu_match_token_num // block_size
+                self.metrics._update_history_hit_metrics()
+                if self.metrics.req_count % 10000 == 0:
+                    self.metrics.reset_metrics()
+                logger.info(
+                    f"request_block_ids: request block for req_id {req_id}: common_block_ids {common_block_ids}"
+                )
+                return common_block_ids, matched_token_num, hit_info
+            except Exception as e:
+                logger.error(f"request_block_ids: error: {type(e)} {e}")
+                raise e
+
     def request_block_ids(self, task, block_size, dec_token_num, *args):
         """
         Allocate blocks for a task.
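`get_required_block_num` is plain ceiling division, and `request_match_blocks` returns the matched block ids together with the matched token count and a `hit_info` dict of GPU/CPU hit block counts. A standalone check of the block-count arithmetic (the token counts and block size are made-up examples, not values from the diff):

```python
def get_required_block_num(input_token_num: int, block_size: int) -> int:
    # Ceiling division: how many fixed-size KV-cache blocks cover the tokens.
    return (input_token_num + block_size - 1) // block_size


assert get_required_block_num(0, 64) == 0    # nothing left to cache
assert get_required_block_num(64, 64) == 1   # exactly one full block
assert get_required_block_num(65, 64) == 2   # one extra token needs a second block
assert get_required_block_num(130, 64) == 3
```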
@@ -463,12 +593,10 @@ class PrefixCacheManager:
                 cpu_match_token_num,
             ) = self.match_block(req_id, input_ids, block_size)
             match_gpu_blocks_num = len(match_gpu_block_ids)
-            match_cpu_blocks_num = len(match_cpu_block_ids)
-            matched_block_num = match_gpu_blocks_num + match_cpu_blocks_num
             matched_token_num_in_cpu_and_gpu = gpu_match_token_num + cpu_match_token_num
             # check enough gpu memory to allocate cache
             block_num = (input_token_num + block_size - 1 + dec_token_num) // block_size
-            self._check_validity(req_id, matched_block_num, block_num)
+            self._check_validity(req_id, match_gpu_blocks_num, block_num)
             # update matched node info
             current_time = time.time()
             self._update_matched_node_info(req_id, match_block_node, current_time)
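This hunk stops counting CPU-matched blocks toward the validity check, presumably because CPU hits still need fresh GPU blocks for the swap-in, so only GPU-resident matches reduce demand on the allocator. The surrounding `block_num` formula, checked with made-up numbers:

```python
# Made-up values to illustrate the block_num formula in this hunk.
block_size = 64
input_token_num = 500   # prompt length in tokens
dec_token_num = 64      # decode slots reserved up front

block_num = (input_token_num + block_size - 1 + dec_token_num) // block_size
assert block_num == 9   # ceil((500 + 64) / 64) = 9 blocks requested for this task
```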
@@ -557,6 +685,9 @@ class PrefixCacheManager:
                 node.decrement_shared_count()
                 node = node.parent

+            if req_id in self.cache_info:
+                del self.cache_info[req_id]
+
             logger.info(f"release_block_ids: req_id {req_id} leaf_node {leaf_node}")

             if leaf_node == self.radix_tree_root:
@@ -373,6 +373,8 @@ class LLMEngine:
                     int(self.resource_manager.available_batch()),
                     self.cfg.max_prefill_batch,
                 )

+                self.resource_manager.check_and_free_block_tables()
+
                 tasks = self.scheduler.get_requests(
                     available_blocks=self.resource_manager.available_block_num(),
                     block_size=self.cfg.cache_config.block_size,
@@ -80,6 +80,7 @@ class ResourceManagerV1(ResourceManager):
         return len(request.block_tables) * self.config.cache_config.block_size

     def get_new_block_nums(self, request: Request, num_new_tokens: int):
+        self.check_and_free_block_tables()
         return (
             request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
         ) // self.config.cache_config.block_size - len(request.block_tables)
@@ -103,6 +104,7 @@ class ResourceManagerV1(ResourceManager):
             preempted_req = self.running.pop()
             preempted_req.status = RequestStatus.PREEMPTED
             preempted_req.num_computed_tokens = 0
+            preempted_req.prefill_block_num = 0
             self._free_blocks(preempted_req)
             self.waiting.appendleft(preempted_req)
             preempted_reqs.append(preempted_req)
@@ -212,6 +214,14 @@ class ResourceManagerV1(ResourceManager):
             if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
                 if request.num_total_tokens > request.need_prefill_tokens:  # has generated tokens
                     request.num_computed_tokens = request.num_total_tokens - 1
+                else:  # prefill finished
+                    if (
+                        self.config.cache_config.enable_prefix_caching
+                        and request.get("prefill_block_num", None) is None
+                    ):
+                        # update prefill cache blocks for prefix caching
+                        request.prefill_block_num = len(request.block_tables)
+                        self.cache_manager.update_cache_blocks(request, self.config.cache_config.block_size)
                 if (
                     self.allocated_slots(request) - request.num_total_tokens
                     <= self.config.cache_config.prealloc_dec_block_slot_num_threshold
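The added branch uses `prefill_block_num` as a once-only flag: when prefill finishes and prefix caching is on, the request's current block table length is recorded and `update_cache_blocks` inserts those blocks into the radix tree; the `is None` guard keeps this from running again for the same request. A minimal sketch of that guard in isolation (the dict-backed request object and helper name are illustrative, not the project's classes):

```python
# Illustrative only: the once-only update pattern used in this hunk.
class FakeRequest(dict):
    """Stands in for the Request object's dict-style .get() access."""


def maybe_record_prefill_blocks(request, block_tables, update_cache_blocks):
    # Record the prefill blocks only the first time prefill completes.
    if request.get("prefill_block_num", None) is None:
        request["prefill_block_num"] = len(block_tables)
        update_cache_blocks(request)


calls = []
req = FakeRequest()
maybe_record_prefill_blocks(req, [3, 7, 9], calls.append)
maybe_record_prefill_blocks(req, [3, 7, 9], calls.append)  # second call is a no-op
assert len(calls) == 1 and req["prefill_block_num"] == 3
```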
@@ -271,11 +281,18 @@ class ResourceManagerV1(ResourceManager):
                 break
             request = self.waiting[0]
             if request.status == RequestStatus.WAITING:
+                # Enable prefix caching
+                if self.config.cache_config.enable_prefix_caching:
+                    success = self.get_prefix_cached_blocks(request)
+                    if not success:
+                        break
+
                 num_new_tokens = self._get_num_new_tokens(request, token_budget)
                 num_new_block = self.get_new_block_nums(request, num_new_tokens)
                 # Allocate blocks to prefill
                 if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
-                    request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
+                    if not request.get("skip_allocate", False):
+                        request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
                     self.waiting.popleft()
                     self.running.append(request)
                     scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
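The `skip_allocate` flag is set by `get_prefix_cached_blocks` (see the hunk further down) when every prompt token is already cached: the scheduler then skips the `allocate_gpu_blocks` call but still schedules the request, and `num_computed_tokens` is rolled back by one so that, presumably, at least one prefill token still runs through the model and produces the first output logits. A small sketch of that full-hit case with made-up numbers:

```python
# Illustrative only: handling of a full prefix-cache hit.
prompt_len = 128
matched_token_num = 128   # every prompt token was found in the prefix cache

skip_allocate = matched_token_num == prompt_len
num_computed_tokens = matched_token_num - 1 if skip_allocate else matched_token_num

# One token is deliberately left "uncomputed" so a prefill step still runs
# and yields logits for the first generated token.
num_new_tokens = prompt_len - num_computed_tokens
assert skip_allocate and num_new_tokens == 1
```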
@@ -292,7 +309,9 @@ class ResourceManagerV1(ResourceManager):
                 else:
                     break
             elif request.status == RequestStatus.PREEMPTED:
-                request.need_prefill_tokens = request.num_total_tokens  # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct
+                request.need_prefill_tokens = (
+                    request.num_total_tokens
+                )  # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct
                 num_new_tokens = self._get_num_new_tokens(request, token_budget)
                 num_new_block = self.get_new_block_nums(request, num_new_tokens)
                 # Allocate blocks to prefill
@@ -327,12 +346,51 @@ class ResourceManagerV1(ResourceManager):
                 break
         return self.real_bsz

+    def get_prefix_cached_blocks(self, request: Request):
+        """
+        set prefix cached information for the given request
+        """
+        try:
+            cache_prepare_time = time.time()
+            (common_block_ids, matched_token_num, hit_info) = self.cache_manager.request_match_blocks(
+                request, self.config.cache_config.block_size
+            )
+
+            matched_block_num = len(common_block_ids)
+            no_cache_block_num = self.cache_manager.get_required_block_num(
+                request.prompt_token_ids_len - matched_token_num,
+                self.config.cache_config.block_size,
+            )
+
+            request.num_cached_tokens = matched_token_num
+            request.gpu_cache_token_num = hit_info["gpu_cache_blocks"] * self.config.cache_config.block_size
+            request.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.config.cache_config.block_size
+            request.cache_info = (matched_block_num, no_cache_block_num)
+            request.block_tables = common_block_ids
+            request.skip_allocate = False
+
+            if matched_token_num == request.prompt_token_ids_len:
+                request.num_computed_tokens = matched_token_num - 1
+                request.skip_allocate = True
+            else:
+                request.num_computed_tokens = matched_token_num
+            request.cache_prepare_time = time.time() - cache_prepare_time
+            return True
+        except Exception as e:
+            llm_logger.error(f"prefix match blocks error: {e}, waiting reschedule...")
+            return False
+
     def add_request(self, request: Request) -> None:
         self.waiting.append(request)
         self.requests[request.request_id] = request

     def _free_blocks(self, request: Request):
-        self.cache_manager.recycle_gpu_blocks(request.block_tables)
+        if self.config.cache_config.enable_prefix_caching:
+            # TODO(chengyanfu): support cache ouput blocks for prefix caching
+            self.cache_manager.release_block_ids_async(request)
+            self.cache_manager.recycle_gpu_blocks(request.block_tables[request.prefill_block_num :])
+        else:
+            self.cache_manager.recycle_gpu_blocks(request.block_tables)
         request.block_tables = []

     def finish_requests_async(self, request_ids: Union[str, Iterable[str]]):
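With prefix caching enabled, `_free_blocks` now splits a finished request's blocks in two: everything after `prefill_block_num` goes straight back to the GPU allocator, while the prefill blocks stay referenced by the radix tree and are released through `release_block_ids_async`. A small sketch of that split with made-up block ids:

```python
# Illustrative only: how _free_blocks partitions a request's blocks when
# prefix caching is enabled. Block ids and counts are made up.
block_tables = [4, 9, 12, 20, 21]   # all GPU blocks held by the request
prefill_block_num = 3               # blocks recorded as prefix cache after prefill

recycled_immediately = block_tables[prefill_block_num:]   # decode-time blocks
kept_for_radix_tree = block_tables[:prefill_block_num]    # released asynchronously

assert recycled_immediately == [20, 21]
assert kept_for_radix_tree == [4, 9, 12]
```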
@@ -27,7 +27,7 @@ for subdir in "$run_path"*/; do
        echo "------------------------------------------------------------"

        set +e
-       timeout 360 python -m pytest --disable-warnings -sv "$file"
+       timeout 600 python -m pytest --disable-warnings -sv "$file"
        exit_code=$?
        set -e