[BUG FIX] Fix bug when a preempted request is rescheduled (#3080)

* Fix bug when a preempted request is rescheduled
Author: chenjian
Date: 2025-07-30 22:25:47 +08:00
Committed by: GitHub
Parent: 0616c208d2
Commit: fe0e3f508b

2 changed files with 9 additions and 10 deletions


@@ -118,6 +118,7 @@ class Request:
         self.status = RequestStatus.WAITING
         self.task_type = RequestType.PREFILL
         self.idx = None
+        self.need_prefill_tokens = self.prompt_token_ids_len
 
     @classmethod
     def from_dict(cls, d: dict):
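
For orientation, a minimal sketch of what the new field means (an illustrative stub, not FastDeploy's real Request class): a fresh request only needs its prompt prefilled, so need_prefill_tokens starts at prompt_token_ids_len.

class Request:  # illustrative stub, not the real class
    def __init__(self, prompt_token_ids):
        self.prompt_token_ids_len = len(prompt_token_ids)
        self.num_computed_tokens = 0
        # Tokens that must pass through prefill before decoding starts.
        # Fresh request: just the prompt. After preemption the scheduler
        # bumps this to num_total_tokens (see the second file below).
        self.need_prefill_tokens = self.prompt_token_ids_len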


@@ -117,11 +117,8 @@ class ResourceManagerV1(ResourceManager):
                 break
         return can_schedule
 
-    def _get_num_new_tokens(self, request, token_budget, schedule_waiting=False):
-        if schedule_waiting:
-            num_new_tokens = request.num_total_tokens - request.num_computed_tokens
-        else:
-            num_new_tokens = request.prompt_token_ids_len - request.num_computed_tokens
+    def _get_num_new_tokens(self, request, token_budget):
+        num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
         num_new_tokens = min(num_new_tokens, token_budget)
         if not self.config.enable_mm:
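
Read on its own, the patched helper reduces to the sketch below (mocked request object; the real method also has a multimodal branch behind config.enable_mm). The point of the change: the prefill target now travels with the request instead of being selected by a schedule_waiting flag at each call site.

def get_num_new_tokens(request, token_budget):
    # Remaining prefill work for this request, capped by the step budget.
    remaining = request.need_prefill_tokens - request.num_computed_tokens
    return min(remaining, token_budget)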
@@ -212,8 +209,8 @@ class ResourceManagerV1(ResourceManager):
             num_decoding_req_nums = 0
             while req_index < len(self.running) and token_budget > 0:
                 request = self.running[req_index]
-                if request.num_computed_tokens >= request.prompt_token_ids_len:  # to be decoding
-                    if request.num_total_tokens > request.prompt_token_ids_len:  # has generated tokens
+                if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
+                    if request.num_total_tokens > request.need_prefill_tokens:  # has generated tokens
                         request.num_computed_tokens = request.num_total_tokens - 1
                     if (
                         self.allocated_slots(request) - request.num_total_tokens
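
A worked example with assumed numbers shows why this check must use need_prefill_tokens: take a 100-token prompt that was preempted after generating 20 tokens and then rescheduled, with only its prompt re-prefilled so far.

from types import SimpleNamespace

request = SimpleNamespace(
    prompt_token_ids_len=100,
    num_total_tokens=120,     # prompt + 20 generated before preemption
    need_prefill_tokens=120,  # bumped to num_total_tokens on reschedule
    num_computed_tokens=100,  # only the prompt recomputed so far
)

# Old check (>= prompt_token_ids_len) would already flip to decoding and
# skip recomputing the 20 generated tokens; the new check keeps the
# request in prefill until all 120 tokens are recomputed.
assert request.num_computed_tokens >= request.prompt_token_ids_len  # old: decode
assert request.num_computed_tokens < request.need_prefill_tokens    # new: prefill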
@@ -246,7 +243,7 @@ class ResourceManagerV1(ResourceManager):
                             token_budget -= 1
                 else:  # need to prefill
                     llm_logger.debug(
-                        f"scheduler prefill task: {request} request.prompt_token_ids_len {request.prompt_token_ids_len} request.num_computed_tokens {request.num_computed_tokens}"
+                        f"scheduler prefill task: {request} request.need_prefill_tokens {request.need_prefill_tokens} request.num_computed_tokens {request.num_computed_tokens}"
                     )
                     num_new_tokens = self._get_num_new_tokens(request, token_budget)
                     num_new_block = self.get_new_block_nums(request, num_new_tokens)
@@ -274,7 +271,7 @@ class ResourceManagerV1(ResourceManager):
                 break
             request = self.waiting[0]
             if request.status == RequestStatus.WAITING:
-                num_new_tokens = self._get_num_new_tokens(request, token_budget, True)
+                num_new_tokens = self._get_num_new_tokens(request, token_budget)
                 num_new_block = self.get_new_block_nums(request, num_new_tokens)
                 # Allocate blocks to prefill
                 if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
@@ -295,7 +292,8 @@ class ResourceManagerV1(ResourceManager):
                 else:
                     break
             elif request.status == RequestStatus.PREEMPTED:
-                num_new_tokens = self._get_num_new_tokens(request, token_budget, True)
+                request.need_prefill_tokens = request.num_total_tokens  # a preempted request produces no new tokens until it is rescheduled, so num_total_tokens is static and correct here
+                num_new_tokens = self._get_num_new_tokens(request, token_budget)
                 num_new_block = self.get_new_block_nums(request, num_new_tokens)
                 # Allocate blocks to prefill
                 if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
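
Putting the pieces together, a self-contained simulation of the fixed reschedule path (assumed sizes and the simplified helper from above, not the real ResourceManagerV1 flow):

from types import SimpleNamespace

def get_num_new_tokens(request, token_budget):
    return min(request.need_prefill_tokens - request.num_computed_tokens,
               token_budget)

# 8-token prompt, preempted after 3 generated tokens; the KV cache was
# freed on preemption, so everything must be recomputed.
req = SimpleNamespace(prompt_token_ids_len=8, num_total_tokens=11,
                      need_prefill_tokens=8, num_computed_tokens=0)

# Reschedule: fix the prefill target once. It stays static because the
# preempted request produces no new tokens while it waits.
req.need_prefill_tokens = req.num_total_tokens  # 11

# Two scheduler steps with a budget of 6 tokens each re-prefill all 11,
# after which the decoding check in the running loop takes over.
for _ in range(2):
    req.num_computed_tokens += get_num_new_tokens(req, token_budget=6)
assert req.num_computed_tokens == req.need_prefill_tokens == 11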