【Feature】ResourceManagerV1 support need block num notifying (#4220)

* support need block num notifying

* adapt t2i

* fix unexpected change
RichardWooSJTU
2025-09-29 11:11:51 +08:00
committed by GitHub
parent 70633c6641
commit 3740e33fea
3 changed files with 211 additions and 61 deletions


@@ -29,6 +29,7 @@ import paddle
from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.platforms import current_platform
from fastdeploy.utils import llm_logger
@@ -69,6 +70,69 @@ class ScheduledExtendBlocksTask:
task_type: RequestType = RequestType.EXTEND
class SignalConsumer:
"""
A class that consumes a signal value up to a specified limit.
This class maintains an internal signal value and allows controlled consumption
of that signal. The signal can be watched at any time, but can only be consumed
a limited number of times before being reset to zero.
"""
def __init__(self, signal, consume_limit):
"""
Initialize the SignalConsumer with a signal value and consumption limit.
Args:
signal: The initial signal value to be consumed.
consume_limit (int): The maximum number of times the signal can be consumed
before being reset to 0. Must be a positive integer.
Raises:
AssertionError: If consume_limit is not greater than 0.
"""
assert consume_limit > 0
self._signal = signal
self._consume_limit = consume_limit
def watch(self):
"""
Get the current signal value without consuming it.
This method allows reading the signal value any number of times without
affecting the consumption limit or the signal value itself.
Returns:
The current signal value.
"""
return self._signal
def consume(self):
"""
Consume the signal value, decrementing the consumption limit.
This method returns the current signal value and decrements the consumption
counter. When the consumption limit reaches zero, the signal is automatically
reset to 0. The consumption happens in a finally block to ensure the limit is
decremented even if an exception occurs while processing the signal.
Returns:
The current signal value before consumption.
Note:
After the consumption limit is reached, this method will continue to
return 0 on subsequent calls.
"""
try:
return self._signal
finally:
if self._consume_limit > 0:
self._consume_limit -= 1
if self._consume_limit == 0:
self._signal = 0
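# Usage sketch (hypothetical values): watch() can be read any number of
# times, while consume() hands the value out at most `consume_limit` times
# and then pins the signal to 0.
#     _consumer = SignalConsumer(signal=8, consume_limit=1)
#     assert _consumer.watch() == 8    # peeking does not use up the limit
#     assert _consumer.consume() == 8  # the single consume returns the value
#     assert _consumer.watch() == 0    # the signal has been reset to 0
#     assert _consumer.consume() == 0  # later consumes keep returning 0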
class ResourceManagerV1(ResourceManager):
"""
Resource manager for scheduler v1.
@@ -95,6 +159,19 @@ class ResourceManagerV1(ResourceManager):
main_process_metrics.max_batch_size.set(max_num_seqs)
self.using_extend_tables_req_id = set()
self.reuse_block_num_map = dict()
# need block nums
need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32)
self.need_block_num_signal = IPCSignal(
name="need_block_num_signal",
array=need_block_num_data,
dtype=np.int32,
suffix=local_data_parallel_id,
create=True,
)
self.need_block_num_map = dict()
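# Producer-side sketch (assumed, not shown in this diff): the worker that
# decides a request needs extra blocks is expected to attach to the same
# "need_block_num_signal" region (same suffix) and write the count into the
# slot for that request's batch index, e.g.
#     need_block_num_signal.value[request.idx] = 4
# schedule() below reads the slot once, wraps the value in a SignalConsumer
# with consume_limit=1, and zeroes the shared slot again.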
def allocated_slots(self, request: Request):
return len(request.block_tables) * self.config.cache_config.block_size
@@ -127,14 +204,35 @@ class ResourceManagerV1(ResourceManager):
self.waiting.appendleft(request)
self.to_be_rescheduled_request_id_set.remove(request_id)
def _info_each_block(self):
"""
Log the number of block tables and extend block tables occupied by each running request.
"""
for req in self.running:
llm_logger.debug(
f"req idx {req.idx} occupy {len(req.block_tables)} block_tables and {len(req.extend_block_tables)} extend_block_tables"
)
def _can_preempt(self):
"""
Return True if any running request can be preempted; requests that use extend block tables cannot be preempted.
"""
for req in self.running:
if not req.use_extend_tables:
return True
return False
def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
"""
If the request cannot be scheduled, preempt running requests one by one (last in, first out) until it can be scheduled.
"""
can_schedule = True
while True:
can_schedule = False
while self._can_preempt():
if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
preempted_req = self.running.pop()
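# Requests that already use extend block tables are never preempted: they are
# rotated to the head of the running deque so the LIFO pop() from the tail
# keeps selecting other candidates.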
if preempted_req.use_extend_tables:
self.running.insert(0, preempted_req)
continue
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
if self.config.scheduler_config.splitwise_role == "decode":
@@ -156,6 +254,13 @@ class ResourceManagerV1(ResourceManager):
main_process_metrics.num_requests_running.dec(1)
preempted_reqs.append(preempted_req)
scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
llm_logger.debug(
f"preempt {preempted_req.request_id} in idx {preempted_req.idx} with generated ids {preempted_req.output_token_ids}"
)
llm_logger.debug(self.info())
self._info_each_block()
if preempted_req == request:
# No more request to preempt.
can_schedule = False
@@ -314,6 +419,11 @@ class ResourceManagerV1(ResourceManager):
num_decoding_req_nums = 0
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
need_block_num = self.need_block_num_signal.value[request.idx]
if need_block_num != 0:
self.need_block_num_map[request.request_id] = SignalConsumer(need_block_num, 1)
self.need_block_num_signal.value[request.idx] = 0
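# The worker-reported block count for this slot is captured exactly once:
# wrapping it in a SignalConsumer with consume_limit=1 lets the extend-tables
# branch below watch() it repeatedly across schedule() passes, while the value
# is consumed (and reset to 0) only when the blocks are actually allocated;
# the shared slot itself is cleared immediately.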
if request.num_computed_tokens >= request.need_prefill_tokens: # to be decoding
if (
self.config.scheduler_config.splitwise_role == "prefill"
@@ -351,6 +461,60 @@ class ResourceManagerV1(ResourceManager):
scheduled_reqs.append(self._prepare_decode_task(request))
num_decoding_req_nums += 1
token_budget -= 1
if (
request.use_extend_tables
and request.request_id not in self.using_extend_tables_req_id
and self.need_block_num_map[request.request_id].watch() > 0
):
def _allocate_decode_and_extend():
allocate_block_num = self.need_block_num_map[request.request_id].consume()
# Prepare decoding task
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(allocate_block_num))
scheduled_reqs.append(self._prepare_decode_task(request))
# Prepare extend task
reuse_block_num = request.num_total_tokens // self.config.cache_config.block_size
llm_logger.info(
f"req {request.request_id} at batch id {request.idx} with reuse_block_num {reuse_block_num} is going to enable extend tables,"
f"need_block_num {allocate_block_num}"
)
self.using_extend_tables_req_id.add(request.request_id)
self.reuse_block_num_map[request.request_id] = reuse_block_num
request.extend_block_tables = request.block_tables[:reuse_block_num] # copy prompt cache
request.extend_block_tables.extend(
self.cache_manager.allocate_gpu_blocks(allocate_block_num)
)
scheduled_reqs.append(
ScheduledExtendBlocksTask(
idx=request.idx,
request_id=request.request_id,
extend_block_tables=request.extend_block_tables,
)
)
llm_logger.debug(f"extend blocks is {request.extend_block_tables}")
if self.cache_manager.can_allocate_gpu_blocks(
2 * self.need_block_num_map[request.request_id].watch()
):
_allocate_decode_and_extend()
else:
llm_logger.info(
f"{request.idx} using extend block need {2 * self.need_block_num_map[request.request_id].watch()} blocks but got not enough blocks, ready to preempt"
)
can_schedule = self._trigger_preempt(
request,
2 * self.need_block_num_map[request.request_id].watch(),
preempted_reqs,
scheduled_reqs,
)
if can_schedule:
_allocate_decode_and_extend()
else:
break
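# Note on the 2x factor above (hypothetical numbers): if the worker reported
# need_block_num = 3, _allocate_decode_and_extend() takes 3 fresh blocks for
# request.block_tables (decode path) and another 3 for
# request.extend_block_tables, so 2 * 3 = 6 free GPU blocks must be available,
# or reclaimed through preemption, before it can run.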
else: # need to prefill
llm_logger.debug(
f"scheduler prefill task: {request} request.need_prefill_tokens {request.need_prefill_tokens} request.num_computed_tokens {request.num_computed_tokens}"
@@ -476,56 +640,6 @@ class ResourceManagerV1(ResourceManager):
else:
llm_logger.error("Unknown request status type")
# schedule when extend block tables is needed
for req in self.running:
num_prefill_blocks = req.need_prefill_tokens // self.config.cache_config.block_size
# allocate
if req.use_extend_tables and req.request_id not in self.using_extend_tables_req_id:
llm_logger.info(
f"req {req.request_id} at batch id {req.idx} with num_prefill_blocks {num_prefill_blocks} is going to enable extend tables"
)
self.using_extend_tables_req_id.add(req.request_id)
if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
req.extend_block_tables = req.block_tables[:num_prefill_blocks] # copy prompt cache
req.extend_block_tables.extend(
self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num)
)
scheduled_reqs.append(
ScheduledExtendBlocksTask(
idx=req.idx, request_id=req.request_id, extend_block_tables=req.extend_block_tables
)
)
llm_logger.info(f"extend blocks is {req.extend_block_tables}")
else:
continue
# recycle
elif not req.use_extend_tables and req.request_id in self.using_extend_tables_req_id:
llm_logger.info(f"req {req.request_id} is going to disable extend tables")
self.using_extend_tables_req_id.remove(req.request_id)
self.cache_manager.recycle_gpu_blocks(req.extend_block_tables[num_prefill_blocks:])
req.extend_block_tables = []
# allocate extend blocks when blocks is going to exhaust
elif req.request_id in self.using_extend_tables_req_id:
if (
self.allocated_slots(req) - req.num_total_tokens
<= self.config.cache_config.prealloc_dec_block_slot_num_threshold
):
llm_logger.info(
f"req {req.request_id} is going to allocate more extend tables because allocated_slots {self.allocated_slots(req)} and prealloc_dec_block_slot_num_threshold {self.config.cache_config.prealloc_dec_block_slot_num_threshold} req.num_total_tokens {req.num_total_tokens}"
)
if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
req.extend_block_tables.extend(
self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num)
)
scheduled_reqs.append(
ScheduledExtendBlocksTask(
idx=req.idx, request_id=req.request_id, extend_block_tables=req.extend_block_tables
)
)
else:
continue
if scheduled_reqs:
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
@@ -725,13 +839,16 @@ class ResourceManagerV1(ResourceManager):
request.block_tables = []
if request.request_id in self.using_extend_tables_req_id:
num_prefill_blocks = request.need_prefill_tokens // self.config.cache_config.block_size
reuse_block_num = self.reuse_block_num_map[request.request_id]
self.using_extend_tables_req_id.remove(request.request_id)
self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[num_prefill_blocks:])
self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[reuse_block_num:])
llm_logger.info(
f"req {request.request_id} recycle extend blocks {request.extend_block_tables[num_prefill_blocks:]}"
f"req {request.request_id} recycle extend blocks {request.extend_block_tables[reuse_block_num:]}"
)
request.extend_block_tables = []
del self.reuse_block_num_map[request.request_id]
del self.need_block_num_map[request.request_id]
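# Recycle sketch (hypothetical numbers): if extend tables were enabled when
# num_total_tokens was 200 and block_size is 64, reuse_block_num == 3; the
# first 3 entries of extend_block_tables are the prompt-cache blocks shared
# with block_tables (freed via the normal block-table path), so only
# extend_block_tables[3:] need to be returned to the cache manager here.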
def finish_requests_async(self, request_ids: Union[str, Iterable[str]]):
return self.finish_execution_pool.submit(self.finish_requests, request_ids)