support extend block tables (#3824)

RichardWooSJTU
2025-09-04 14:39:04 +08:00
committed by GitHub
parent 6ef3b611b0
commit 0989788b29
2 changed files with 78 additions and 0 deletions


@@ -40,6 +40,7 @@ class RequestType(Enum):
     PREFILL = 0
     DECODE = 1
     PREEMPTED = 2
+    EXTEND = 3

 @dataclass
@@ -141,6 +142,9 @@ class Request:
         self.task_type = RequestType.PREFILL
         self.idx = None
         self.need_prefill_tokens = self.prompt_token_ids_len
+        # extend block tables
+        self.use_extend_tables = False
+        self.extend_block_tables = []

     @classmethod
     def from_dict(cls, d: dict):
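For orientation, here is a minimal sketch (not part of the diff) of the new per-request state: `use_extend_tables` is the opt-in switch a caller flips on a live request, and `extend_block_tables` holds the alternate table the scheduler builds for it. The class shape below is an assumption, reduced to just the fields this hunk touches.

from typing import List, Optional


class Request:
    def __init__(self, prompt_token_ids_len: int):
        self.prompt_token_ids_len = prompt_token_ids_len
        self.idx: Optional[int] = None
        self.need_prefill_tokens = self.prompt_token_ids_len
        # extend block tables
        self.use_extend_tables = False              # opt-in switch checked by the scheduler
        self.extend_block_tables: List[int] = []    # prompt-cache prefix + extra decode blocks


req = Request(prompt_token_ids_len=512)
req.use_extend_tables = True  # the scheduler reacts on its next schedule() pass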


@@ -55,6 +55,18 @@ class ScheduledPreemptTask:
     task_type: RequestType = RequestType.PREEMPTED

+
+@dataclass
+class ScheduledExtendBlocksTask:
+    """
+    Task for allocating new blocks to extend a running request.
+    """
+
+    idx: int
+    request_id: str
+    extend_block_tables: list[int]
+    task_type: RequestType = RequestType.EXTEND
+

 class ResourceManagerV1(ResourceManager):
     """
     Resource manager for scheduler v1.
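A hedged sketch of how the new task might be consumed downstream. Only `ScheduledExtendBlocksTask` and `RequestType.EXTEND` come from this commit; the slot map and `apply_task` are hypothetical stand-ins for whatever the model runner does with scheduler output.

from dataclasses import dataclass
from enum import Enum


class RequestType(Enum):
    PREFILL = 0
    DECODE = 1
    PREEMPTED = 2
    EXTEND = 3


@dataclass
class ScheduledExtendBlocksTask:
    idx: int
    request_id: str
    extend_block_tables: list[int]
    task_type: RequestType = RequestType.EXTEND


# Hypothetical consumer state: one block table per batch slot.
block_tables_by_slot: dict[int, list[int]] = {}


def apply_task(task: ScheduledExtendBlocksTask) -> None:
    if task.task_type is RequestType.EXTEND:
        # Repoint the batch slot at the extended table: a prompt-cache
        # prefix followed by freshly allocated decode blocks.
        block_tables_by_slot[task.idx] = task.extend_block_tables


apply_task(ScheduledExtendBlocksTask(idx=0, request_id="req-0", extend_block_tables=[100, 101, 21, 22]))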
@@ -80,6 +92,8 @@ class ResourceManagerV1(ResourceManager):
         self.to_be_rescheduled_request_id_set = set()
         main_process_metrics.max_batch_size.set(max_num_seqs)
+
+        self.using_extend_tables_req_id = set()

     def allocated_slots(self, request: Request):
         return len(request.block_tables) * self.config.cache_config.block_size
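`allocated_slots` converts a request's block count into token capacity, which the scheduling hunk below compares against `num_total_tokens` to decide when to top up. A quick worked check of that arithmetic, with assumed numbers:

# Assumed values, for arithmetic only: block_size=64, request holds 5 blocks.
block_size = 64
block_tables = [12, 13, 14, 15, 16]
allocated_slots = len(block_tables) * block_size     # 5 * 64 = 320 token slots
num_total_tokens = 290
threshold = 32                                       # prealloc_dec_block_slot_num_threshold
needs_more_blocks = allocated_slots - num_total_tokens <= threshold
assert needs_more_blocks                             # 30 <= 32, so the scheduler would top up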
@@ -405,6 +419,57 @@ class ResourceManagerV1(ResourceManager):
                break
            else:
                llm_logger.error("Unknown request status type")
+        # schedule when extend block tables are needed
+        for req in self.running:
+            num_prefill_blocks = req.need_prefill_tokens // self.config.cache_config.block_size
+            # allocate
+            if req.use_extend_tables and req.request_id not in self.using_extend_tables_req_id:
+                llm_logger.info(
+                    f"req {req.request_id} at batch id {req.idx} with num_prefill_blocks {num_prefill_blocks} is going to enable extend tables"
+                )
+                self.using_extend_tables_req_id.add(req.request_id)
+                if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
+                    req.extend_block_tables = req.block_tables[:num_prefill_blocks]  # copy prompt cache
+                    req.extend_block_tables.extend(
+                        self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num)
+                    )
+                    scheduled_reqs.append(
+                        ScheduledExtendBlocksTask(
+                            idx=req.idx, request_id=req.request_id, extend_block_tables=req.extend_block_tables
+                        )
+                    )
+                    llm_logger.info(f"extend blocks are {req.extend_block_tables}")
+                else:
+                    continue
+            # recycle
+            elif not req.use_extend_tables and req.request_id in self.using_extend_tables_req_id:
+                llm_logger.info(f"req {req.request_id} is going to disable extend tables")
+                self.using_extend_tables_req_id.remove(req.request_id)
+                self.cache_manager.recycle_gpu_blocks(req.extend_block_tables[num_prefill_blocks:])
+                req.extend_block_tables = []
+            # allocate more extend blocks when the current ones are about to be exhausted
+            elif req.request_id in self.using_extend_tables_req_id:
+                if (
+                    self.allocated_slots(req) - req.num_total_tokens
+                    <= self.config.cache_config.prealloc_dec_block_slot_num_threshold
+                ):
+                    llm_logger.info(
+                        f"req {req.request_id} is going to allocate more extend blocks: "
+                        f"allocated_slots={self.allocated_slots(req)}, "
+                        f"prealloc_dec_block_slot_num_threshold={self.config.cache_config.prealloc_dec_block_slot_num_threshold}, "
+                        f"num_total_tokens={req.num_total_tokens}"
+                    )
+                    if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
+                        req.extend_block_tables.extend(
+                            self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num)
+                        )
+                        scheduled_reqs.append(
+                            ScheduledExtendBlocksTask(
+                                idx=req.idx, request_id=req.request_id, extend_block_tables=req.extend_block_tables
+                            )
+                        )
+                    else:
+                        continue
         if scheduled_reqs:
             task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
             main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
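Read as a whole, the loop above is a three-branch lifecycle per running request: enable (alias the prompt-cache prefix, append `enc_dec_block_num` fresh blocks), grow (append more once free capacity dips below `prealloc_dec_block_slot_num_threshold`), and disable (recycle only the appended tail). A condensed, self-contained restatement of that logic against a toy cache manager; every type here is a simplified stand-in, not FastDeploy's:

from types import SimpleNamespace


class ToyCacheManager:
    def __init__(self, num_free_blocks: int):
        self.free = list(range(num_free_blocks))

    def can_allocate_gpu_blocks(self, n: int) -> bool:
        return len(self.free) >= n

    def allocate_gpu_blocks(self, n: int) -> list:
        out, self.free = self.free[:n], self.free[n:]
        return out

    def recycle_gpu_blocks(self, blocks: list) -> None:
        self.free.extend(blocks)


def step_extend_tables(req, cache, cfg, using: set) -> None:
    num_prefill_blocks = req.need_prefill_tokens // cfg["block_size"]
    if req.use_extend_tables and req.request_id not in using:
        # enable: alias the prompt-cache blocks, then append fresh ones
        using.add(req.request_id)
        if cache.can_allocate_gpu_blocks(cfg["enc_dec_block_num"]):
            req.extend_block_tables = req.block_tables[:num_prefill_blocks]
            req.extend_block_tables += cache.allocate_gpu_blocks(cfg["enc_dec_block_num"])
    elif not req.use_extend_tables and req.request_id in using:
        # disable: free only the appended tail, never the shared prefix
        using.discard(req.request_id)
        cache.recycle_gpu_blocks(req.extend_block_tables[num_prefill_blocks:])
        req.extend_block_tables = []
    elif req.request_id in using:
        # grow: top up once remaining capacity drops under the threshold
        allocated = len(req.block_tables) * cfg["block_size"]
        if allocated - req.num_total_tokens <= cfg["prealloc_dec_block_slot_num_threshold"]:
            if cache.can_allocate_gpu_blocks(cfg["enc_dec_block_num"]):
                req.extend_block_tables += cache.allocate_gpu_blocks(cfg["enc_dec_block_num"])


cfg = {"block_size": 64, "enc_dec_block_num": 2, "prealloc_dec_block_slot_num_threshold": 64}
cache = ToyCacheManager(num_free_blocks=8)
req = SimpleNamespace(request_id="req-0", need_prefill_tokens=128, num_total_tokens=130,
                      use_extend_tables=True, block_tables=[100, 101], extend_block_tables=[])
using = set()
step_extend_tables(req, cache, cfg, using)  # enable: [100, 101] plus two fresh blocks
assert req.extend_block_tables == [100, 101, 0, 1]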
@@ -488,6 +553,15 @@ class ResourceManagerV1(ResourceManager):
             self.cache_manager.recycle_gpu_blocks(request.block_tables)
             request.block_tables = []
+
+            if request.request_id in self.using_extend_tables_req_id:
+                num_prefill_blocks = request.need_prefill_tokens // self.config.cache_config.block_size
+                self.using_extend_tables_req_id.remove(request.request_id)
+                self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[num_prefill_blocks:])
+                llm_logger.info(
+                    f"req {request.request_id} recycled extend blocks {request.extend_block_tables[num_prefill_blocks:]}"
+                )
+                request.extend_block_tables = []

     def finish_requests_async(self, request_ids: Union[str, Iterable[str]]):
         return self.finish_execution_pool.submit(self.finish_requests, request_ids)
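One detail of the finish-time hunk worth spelling out: `request.block_tables` is recycled first, and `extend_block_tables` shares its first `num_prefill_blocks` entries with it, so only the tail goes back to the cache manager; returning the prefix as well would double-free those blocks. A toy check of the slicing, with assumed numbers:

need_prefill_tokens, block_size = 256, 64
num_prefill_blocks = need_prefill_tokens // block_size   # 4
block_tables = [7, 8, 9, 10]                             # already recycled above
extend_block_tables = [7, 8, 9, 10, 21, 22]              # shares the 4-block prefix
tail = extend_block_tables[num_prefill_blocks:]
assert tail == [21, 22]                                  # only these are recycled again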