Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[Optimize] optimize prefix cache in develop (#3890)
* optimize prefix cache in release22
* fix
* fix
* fix
* add ci for v1
* add unit test

Co-authored-by: xiegegege <46314656+xiegegege@users.noreply.github.com>
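The key interface change: update_cache_blocks() now receives the number of tokens actually computed for the request, and cache_info[req_id] stores (leaf_node, cached_token_count) instead of the raw input ids. A minimal sketch of the scheduler-side call pattern as read from the hunks below; the helper function itself is illustrative, not FastDeploy API:

    # Hypothetical wrapper for illustration only; the real call sites are in the
    # ResourceManagerV1 scheduling loop (see the hunks further down).
    def cache_prefix_after_step(cache_manager, cache_config, request):
        if cache_config.enable_prefix_caching:
            # New third argument: tokens computed so far. Internally the count is
            # floored to a block boundary (can_cache_computed_tokens) before the
            # radix tree is extended and request.cached_block_num is updated.
            cache_manager.update_cache_blocks(
                request, cache_config.block_size, request.num_computed_tokens
            )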
@@ -257,7 +257,12 @@ class PrefixCacheManager:
         Check if num_blocks gpu blocks can be allocated.
         """
         if len(self.gpu_free_block_list) < num_blocks:
-            return False
+            if self.cache_config.enable_prefix_caching:
+                self.free_block_ids(num_blocks)
+            if len(self.gpu_free_block_list) < num_blocks:
+                return False
+            else:
+                return True
         else:
             return True

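With prefix caching enabled, a shortage of free GPU blocks no longer fails immediately: the manager first reclaims cached blocks via the new free_block_ids() helper and re-checks. A hedged usage sketch; the caller code here is illustrative:

    # cache_manager is assumed to be a PrefixCacheManager instance.
    need = 4
    if cache_manager.can_allocate_gpu_blocks(need):
        # The check above may have evicted reusable cached blocks back into
        # gpu_free_block_list before succeeding.
        block_ids = cache_manager.allocate_gpu_blocks(need)
    else:
        # Still not enough room even after eviction; defer this request.
        block_ids = []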
@@ -448,7 +453,7 @@ class PrefixCacheManager:
         """
         return (input_token_num + block_size - 1) // block_size

-    def update_cache_blocks(self, task, block_size):
+    def update_cache_blocks(self, task, block_size, num_computed_tokens):
         """
         update cache blocks for a task.
         # TODO(chengyanfu): support async update
@@ -459,12 +464,19 @@ class PrefixCacheManager:
         """
         try:
             req_id = task.request_id
-            num_cached_tokens = task.num_cached_tokens
             block_tables = task.block_tables

-            last_node, input_ids = self.cache_info[req_id]
-            left_input_ids = input_ids[num_cached_tokens:]
+            last_node, num_cached_tokens = self.cache_info[req_id]
+            if isinstance(task.prompt_token_ids, np.ndarray):
+                prompt_token_ids = task.prompt_token_ids.tolist()
+            else:
+                prompt_token_ids = task.prompt_token_ids
+            input_ids = prompt_token_ids + task.output_token_ids
+            can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
+            left_input_ids = input_ids[num_cached_tokens:can_cache_computed_tokens]
             gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :]
+            if req_id in self.leaf_req_map[last_node]:  # delete old leaf record, update later
+                self.leaf_req_map[last_node].remove(req_id)

             with self.request_release_lock:
                 current_time = time.time()
@@ -480,7 +492,8 @@ class PrefixCacheManager:
                 )
                 self.req_leaf_map[req_id] = leaf_node
                 self.leaf_req_map[leaf_node].add(req_id)
-                self.cache_info[req_id] = (leaf_node, input_ids)
+                self.cache_info[req_id] = (leaf_node, can_cache_computed_tokens)
+            task.cached_block_num = can_cache_computed_tokens // block_size
         except Exception as e:
             logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
             raise e
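The flooring above guarantees that only whole blocks enter the radix tree. A worked example of the arithmetic, using the 64-token block size from the new unit test (the concrete numbers are illustrative):

    block_size = 64
    num_cached_tokens = 1600      # 25 blocks already recorded in the tree
    num_computed_tokens = 1700    # progress stops mid-block
    can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size  # 1664
    newly_cacheable = can_cache_computed_tokens - num_cached_tokens                     # 64 tokens -> 1 block
    cached_block_num = can_cache_computed_tokens // block_size                          # 26
    assert (can_cache_computed_tokens, newly_cacheable, cached_block_num) == (1664, 64, 26)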
@@ -508,7 +521,11 @@ class PrefixCacheManager:
         hit_info["gpu_cache_blocks"] = 0
         hit_info["cpu_cache_blocks"] = 0
         self.metrics.req_count += 1
-        input_ids = task.prompt_token_ids
+        if isinstance(task.prompt_token_ids, np.ndarray):
+            prompt_token_ids = task.prompt_token_ids.tolist()
+        else:
+            prompt_token_ids = task.prompt_token_ids
+        input_ids = prompt_token_ids + task.output_token_ids
         req_id = task.request_id
         logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
         input_token_num = len(input_ids)
@@ -546,9 +563,6 @@ class PrefixCacheManager:
                         "request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache"
                     )

-            # record request cache info
-            self.cache_info[req_id] = (match_block_node, input_ids)
-
             # 3. update metrics
             matched_token_num = gpu_match_token_num + cpu_match_token_num
             common_block_ids = match_gpu_block_ids + gpu_recv_block_ids
@@ -571,6 +585,9 @@ class PrefixCacheManager:
             # set leaf node temporarily, then update it in update_cache_blocks
             self.req_leaf_map[req_id] = match_block_node
             self.leaf_req_map[match_block_node].add(req_id)
+            # record request cache info
+            self.cache_info[req_id] = (match_block_node, matched_token_num)
+            task.cached_block_num = matched_token_num // block_size
             return common_block_ids, matched_token_num, hit_info
         except Exception as e:
             logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
@@ -687,6 +704,11 @@ class PrefixCacheManager:
         """
         return self.executor_pool.submit(self.release_block_ids, task)

+    def free_block_ids(self, need_block_num):
+        self.free_block_ids_async(need_block_num)
+        while (self.gpu_free_task_future is not None) and (not self.gpu_free_task_future.done()):
+            time.sleep(0.001)
+
     def release_block_ids(self, task):
         """
         release block ids
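The free_block_ids() helper added above turns the existing asynchronous eviction into a blocking call by polling the free-task future. The same wait pattern in a self-contained sketch; everything except the pattern itself is an illustrative name, not FastDeploy code:

    import time
    from concurrent.futures import Future, ThreadPoolExecutor

    def wait_for_async_free(executor: ThreadPoolExecutor, evict_fn, need_block_num: int) -> None:
        # Submit the eviction work, then nap in 1 ms steps until it completes,
        # mirroring free_block_ids() polling gpu_free_task_future.
        future: Future = executor.submit(evict_fn, need_block_num)
        while not future.done():
            time.sleep(0.001)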
@@ -1108,15 +1130,6 @@ class PrefixCacheManager:
                 node.req_id_set.add(req_id)
                 node = node.parent

-    def decrease_request_share_count(self, req_id):
-        """
-        Decrease node shared count
-        """
-        node, input_ids = self.cache_info[req_id]
-        while node != self.radix_tree_root:
-            node.decrement_shared_count()
-            node = node.parent
-
     def build_path(
         self,
         req_id,

@@ -525,10 +525,13 @@ class EngineService:
                 int(self.resource_manager.available_batch()),
                 self.cfg.max_prefill_batch,
             )
-            self.resource_manager.check_and_free_block_tables()
+            if self.cfg.model_config.enable_mm:
+                available_blocks = self.resource_manager.available_block_num()
+            else:
+                available_blocks = self.cfg.cache_config.max_block_num_per_seq
+
             tasks = self.scheduler.get_requests(
-                available_blocks=self.resource_manager.available_block_num(),
+                available_blocks=available_blocks,
                 block_size=self.cfg.cache_config.block_size,
                 reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
                 max_num_batched_tokens=self.cfg.max_model_len,

@@ -98,7 +98,6 @@ class ResourceManagerV1(ResourceManager):
         return len(request.block_tables) * self.config.cache_config.block_size

     def get_new_block_nums(self, request: Request, num_new_tokens: int):
-        self.check_and_free_block_tables()
         block_num = (
             request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
         ) // self.config.cache_config.block_size - len(request.block_tables)
@@ -137,7 +136,7 @@ class ResourceManagerV1(ResourceManager):
             preempted_req.status = RequestStatus.PREEMPTED
             preempted_req.num_computed_tokens = 0
             self._free_blocks(preempted_req)
-            preempted_req.prefill_block_num = None
+            preempted_req.cached_block_num = 0
             self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
             preempted_reqs.append(preempted_req)
             scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
@@ -300,14 +299,6 @@ class ResourceManagerV1(ResourceManager):
                 if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
                     if request.num_total_tokens > request.need_prefill_tokens:  # has generated tokens
                         request.num_computed_tokens = request.num_total_tokens - 1
-                    else:  # prefill finished
-                        if (
-                            self.config.cache_config.enable_prefix_caching
-                            and request.get("prefill_block_num", None) is None
-                        ):
-                            # update prefill cache blocks for prefix caching
-                            request.prefill_block_num = len(request.block_tables)
-                            self.cache_manager.update_cache_blocks(request, self.config.cache_config.block_size)
                     if (
                         self.allocated_slots(request) - request.num_total_tokens
                         <= self.config.cache_config.prealloc_dec_block_slot_num_threshold
@@ -357,6 +348,10 @@ class ResourceManagerV1(ResourceManager):
                     scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                     token_budget -= num_new_tokens
                     request.num_computed_tokens += num_new_tokens
+                    if self.config.cache_config.enable_prefix_caching:
+                        self.cache_manager.update_cache_blocks(
+                            request, self.config.cache_config.block_size, request.num_computed_tokens
+                        )
                     req_index += 1
         # schedule the WAITING requests.
         if not preempted_reqs:
@@ -371,6 +366,15 @@ class ResourceManagerV1(ResourceManager):
                 if request.status == RequestStatus.WAITING:
                     # Enable prefix caching
                     if self.config.cache_config.enable_prefix_caching:
+                        if (
+                            self.config.cache_config.enable_hierarchical_cache
+                            and self.cache_manager.num_cpu_blocks > 0
+                        ):
+                            if not self.cache_manager.can_allocate_gpu_blocks(
+                                (request.need_prefill_tokens + self.config.cache_config.block_size - 1)
+                                // self.config.cache_config.block_size
+                            ):  # to prevent block allocation for matching in hierarchical cache and cause dead lock
+                                break
                         success = self.get_prefix_cached_blocks(request)
                         if not success:
                             self._free_blocks(request)
@@ -389,6 +393,10 @@ class ResourceManagerV1(ResourceManager):
                         request.schedule_start_time = time.time()
                     token_budget -= num_new_tokens
                     request.num_computed_tokens += num_new_tokens
+                    if self.config.cache_config.enable_prefix_caching:
+                        self.cache_manager.update_cache_blocks(
+                            request, self.config.cache_config.block_size, request.num_computed_tokens
+                        )
                     request.status = RequestStatus.RUNNING
                     main_process_metrics.num_requests_waiting.dec(1)
                     main_process_metrics.num_requests_running.inc(1)
@@ -406,6 +414,15 @@ class ResourceManagerV1(ResourceManager):
                         request.num_total_tokens
                     )  # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct
                     if self.config.cache_config.enable_prefix_caching:
+                        if (
+                            self.config.cache_config.enable_hierarchical_cache
+                            and self.cache_manager.num_cpu_blocks > 0
+                        ):
+                            if not self.cache_manager.can_allocate_gpu_blocks(
+                                (request.need_prefill_tokens + self.config.cache_config.block_size - 1)
+                                // self.config.cache_config.block_size
+                            ):  # to prevent block allocation for matching in hierarchical cache and cause dead lock
+                                break
                         success = self.get_prefix_cached_blocks(request)
                         if not success:
                             self._free_blocks(request)
@@ -421,6 +438,10 @@ class ResourceManagerV1(ResourceManager):
                     scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                     token_budget -= num_new_tokens
                     request.num_computed_tokens += num_new_tokens
+                    if self.config.cache_config.enable_prefix_caching:
+                        self.cache_manager.update_cache_blocks(
+                            request, self.config.cache_config.block_size, request.num_computed_tokens
+                        )
                     request.status = RequestStatus.RUNNING
                     main_process_metrics.num_requests_waiting.dec(1)
                     main_process_metrics.num_requests_running.inc(1)
@@ -516,7 +537,7 @@ class ResourceManagerV1(ResourceManager):

         matched_block_num = len(common_block_ids)
         no_cache_block_num = self.cache_manager.get_required_block_num(
-            request.prompt_token_ids_len - matched_token_num,
+            request.need_prefill_tokens - matched_token_num,
             self.config.cache_config.block_size,
         )

@@ -532,7 +553,7 @@ class ResourceManagerV1(ResourceManager):
         main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
         main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)

-        if matched_token_num == request.prompt_token_ids_len:
+        if matched_token_num == request.need_prefill_tokens:
             request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
             request.skip_allocate = True
         else:
@@ -550,16 +571,8 @@ class ResourceManagerV1(ResourceManager):

     def _free_blocks(self, request: Request):
         if self.config.cache_config.enable_prefix_caching:
-            # TODO(chengyanfu): support cache output blocks for prefix caching
-            if request.get("prefill_block_num", None) is None:
-                leaf_node = self.cache_manager.req_leaf_map[request.request_id]
-                self.cache_manager.decrease_request_share_count(request.request_id)
-                self.cache_manager.free_nodes_directly(leaf_node)
-                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cache_info[0] :])
-
-            else:
-                self.cache_manager.release_block_ids_async(request)
-                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.prefill_block_num :])
+            self.cache_manager.release_block_ids(request)
+            self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cached_block_num :])
         else:
             self.cache_manager.recycle_gpu_blocks(request.block_tables)
         request.block_tables = []
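The new unit tests below exercise this bookkeeping. The block arithmetic behind their assertions can be checked by hand; a standalone sketch using the numbers from the test itself:

    block_size = 64
    shared_prefix = 1600    # req2/req3 share the first 1600 tokens with req1
    full_prompt = 3200

    matched_blocks = shared_prefix // block_size                        # 25 cached blocks reused
    matched_tokens = matched_blocks * block_size                        # 1600 tokens
    new_blocks_needed = (full_prompt - matched_tokens) // block_size    # 25 fresh blocks
    assert (matched_blocks, matched_tokens, new_blocks_needed) == (25, 1600, 25)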
tests/v1/test_prefix_cache.py (new file, 71 lines)
@@ -0,0 +1,71 @@
+from dataclasses import asdict
+from types import SimpleNamespace
+
+from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
+from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig
+from fastdeploy.engine.args_utils import EngineArgs
+from fastdeploy.engine.request import Request
+
+
+def test_normal_case():
+    max_num_seqs = 3
+    block_size = 64
+    engine_args = EngineArgs(max_num_seqs=max_num_seqs, num_gpu_blocks_override=100, max_num_batched_tokens=3200)
+    args = asdict(engine_args)
+    cache_cfg = CacheConfig(args)
+    model_cfg = SimpleNamespace(enable_mm=False)
+    speculative_cfg = SimpleNamespace(method=None)
+    model_cfg.print = print
+    cache_cfg.bytes_per_layer_per_block = 1
+    parallel_cfg = ParallelConfig(args)
+    graph_opt_cfg = engine_args.create_graph_optimization_config()
+    fd_config = FDConfig(
+        model_config=model_cfg,
+        cache_config=cache_cfg,
+        parallel_config=parallel_cfg,
+        graph_opt_config=graph_opt_cfg,
+        speculative_config=speculative_cfg,
+        max_num_batched_tokens=engine_args.max_num_batched_tokens,
+    )
+    cache_manager = PrefixCacheManager(config=fd_config, tensor_parallel_size=8, splitwise_role="mixed")
+    req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
+    req2 = Request.from_dict(
+        {"request_id": "req2", "prompt_token_ids": [1] * 1600 + [2] * 1600, "prompt_token_ids_len": 3200}
+    )
+    req3 = Request.from_dict(
+        {"request_id": "req3", "prompt_token_ids": [1] * 1600 + [3] * 1600, "prompt_token_ids_len": 3200}
+    )
+    (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req1, block_size)
+    assert len(common_block_ids) == 0
+    assert matched_token_num == 0
+    assert len(cache_manager.gpu_free_block_list) == 100
+    req1.block_tables.extend(common_block_ids)
+    # allocate for req1 inputs
+    num_new_block = 50
+    req1.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
+    req1.num_computed_tokens += 50 * block_size
+    cache_manager.update_cache_blocks(req1, block_size, req1.num_computed_tokens)
+    assert len(cache_manager.gpu_free_block_list) == 50
+    # allocate for req2 inputs
+    (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req2, block_size)
+    assert len(common_block_ids) == 25
+    assert matched_token_num == 25 * block_size
+    req2.num_cached_tokens = matched_token_num
+    req2.num_computed_tokens = 25 * block_size
+    num_new_block = 25
+    req2.block_tables.extend(common_block_ids)
+    req2.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
+    cache_manager.update_cache_blocks(req2, block_size, req2.num_computed_tokens)
+    # allocate for req3 input
+    (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req3, block_size)
+    assert len(common_block_ids) == 25
+    assert matched_token_num == 25 * block_size
+    req3.num_cached_tokens = matched_token_num
+    req3.num_computed_tokens = 25 * block_size
+    assert len(cache_manager.gpu_free_block_list) == 25
+    req3.block_tables.extend(common_block_ids)
+    num_new_block = 25
+    assert cache_manager.can_allocate_gpu_blocks(num_new_block)
+    req3.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
+    cache_manager.update_cache_blocks(req3, block_size, req3.num_computed_tokens)
+    assert len(cache_manager.gpu_free_block_list) == 0
tests/v1/test_schedule_output.py (new file, 128 lines)
@@ -0,0 +1,128 @@
+from dataclasses import asdict
+from types import SimpleNamespace
+
+from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig
+from fastdeploy.engine.args_utils import EngineArgs
+from fastdeploy.engine.request import Request
+from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
+
+
+def test_normal_schedule():
+    max_num_seqs = 3
+    engine_args = EngineArgs(max_num_seqs=max_num_seqs, num_gpu_blocks_override=160, max_num_batched_tokens=3200)
+    args = asdict(engine_args)
+    cache_cfg = CacheConfig(args)
+    model_cfg = SimpleNamespace(enable_mm=False)
+    speculative_cfg = SimpleNamespace(method=None)
+    model_cfg.print = print
+    cache_cfg.bytes_per_layer_per_block = 1
+    parallel_cfg = ParallelConfig(args)
+    graph_opt_cfg = engine_args.create_graph_optimization_config()
+    fd_config = FDConfig(
+        model_config=model_cfg,
+        cache_config=cache_cfg,
+        parallel_config=parallel_cfg,
+        speculative_config=speculative_cfg,
+        graph_opt_config=graph_opt_cfg,
+        max_num_batched_tokens=engine_args.max_num_batched_tokens,
+    )
+    resource_manager_v1 = ResourceManagerV1(
+        max_num_seqs=max_num_seqs, config=fd_config, tensor_parallel_size=8, splitwise_role="mixed"
+    )
+    req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3199, "prompt_token_ids_len": 3199})
+    req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [1] * 3201, "prompt_token_ids_len": 3201})
+    req3 = Request.from_dict({"request_id": "req3", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
+    resource_manager_v1.add_request(req1)
+    resource_manager_v1.add_request(req2)
+    resource_manager_v1.add_request(req3)
+    # step 1
+    assert len(resource_manager_v1.waiting) == 3
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 2
+    assert scheduler_reqs[0].request_id == "req1"
+    assert scheduler_reqs[1].request_id == "req2"
+    assert scheduler_reqs[0].prefill_start_index == 0
+    assert scheduler_reqs[1].prefill_start_index == 0
+    assert scheduler_reqs[0].prefill_end_index == 3199
+    assert scheduler_reqs[1].prefill_end_index == 1
+    assert len(resource_manager_v1.running) == 2
+    assert len(resource_manager_v1.waiting) == 1
+    # step 2
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 2
+    assert scheduler_reqs[0].request_id == "req1"
+    assert len(scheduler_reqs[0].block_tables) == 52
+    assert scheduler_reqs[1].request_id == "req2"
+    assert scheduler_reqs[1].prefill_start_index == 1
+    assert scheduler_reqs[1].prefill_end_index == 3200
+    assert len(resource_manager_v1.running) == 2
+    assert len(resource_manager_v1.waiting) == 1
+    # step 3
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 2
+    assert scheduler_reqs[0].request_id == "req2"
+    assert scheduler_reqs[0].prefill_start_index == 3200
+    assert scheduler_reqs[0].prefill_end_index == 3201
+    assert scheduler_reqs[1].request_id == "req3"
+    assert scheduler_reqs[1].prefill_start_index == 0
+    assert scheduler_reqs[1].prefill_end_index == 3199
+    assert len(resource_manager_v1.running) == 3
+    assert len(resource_manager_v1.waiting) == 0
+
+
+def test_preempted_request():
+    max_num_seqs = 2
+    engine_args = EngineArgs(max_num_seqs=max_num_seqs, num_gpu_blocks_override=52, max_num_batched_tokens=3200)
+    args = asdict(engine_args)
+    cache_cfg = CacheConfig(args)
+    model_cfg = SimpleNamespace(enable_mm=False)
+    speculative_cfg = SimpleNamespace(method=None)
+    model_cfg.print = print
+    cache_cfg.bytes_per_layer_per_block = 1
+    parallel_cfg = ParallelConfig(args)
+    graph_opt_cfg = engine_args.create_graph_optimization_config()
+    fd_config = FDConfig(
+        model_config=model_cfg,
+        cache_config=cache_cfg,
+        parallel_config=parallel_cfg,
+        graph_opt_config=graph_opt_cfg,
+        speculative_config=speculative_cfg,
+        max_num_batched_tokens=engine_args.max_num_batched_tokens,
+    )
+    resource_manager_v1 = ResourceManagerV1(
+        max_num_seqs=max_num_seqs, config=fd_config, tensor_parallel_size=8, splitwise_role="mixed"
+    )
+    req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
+    req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
+    resource_manager_v1.add_request(req1)
+    resource_manager_v1.add_request(req2)
+    # step 1
+    assert len(resource_manager_v1.waiting) == 2
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 1
+    assert scheduler_reqs[0].request_id == "req1"
+    assert scheduler_reqs[0].prefill_start_index == 0
+    assert scheduler_reqs[0].prefill_end_index == 3200
+    assert len(resource_manager_v1.running) == 1
+    assert len(resource_manager_v1.waiting) == 1
+    # step 2
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 1
+    assert scheduler_reqs[0].request_id == "req1"
+    assert len(scheduler_reqs[0].block_tables) == 52
+    # step 3
+    req1.output_token_ids.extend([1] * 128)
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 1
+    assert scheduler_reqs[0].request_id == "req1"
+    assert len(resource_manager_v1.running) == 0
+    # to be added into waiting queue
+    assert len(resource_manager_v1.waiting) == 1
+    # mock token_processor to add into waiting
+    resource_manager_v1.waiting.appendleft(req1)
+    # step 4
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 1
+    assert scheduler_reqs[0].request_id == "req1"
+    assert len(resource_manager_v1.running) == 1
+    assert len(resource_manager_v1.waiting) == 1