diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index 69f8516c3..a3c2dbbb9 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -118,6 +118,7 @@ class Request:
         self.status = RequestStatus.WAITING
         self.task_type = RequestType.PREFILL
         self.idx = None
+        self.need_prefill_tokens = self.prompt_token_ids_len
 
     @classmethod
     def from_dict(cls, d: dict):
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 5dc878895..051e985db 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -117,11 +117,8 @@ class ResourceManagerV1(ResourceManager):
                     break
         return can_schedule
 
-    def _get_num_new_tokens(self, request, token_budget, schedule_waiting=False):
-        if schedule_waiting:
-            num_new_tokens = request.num_total_tokens - request.num_computed_tokens
-        else:
-            num_new_tokens = request.prompt_token_ids_len - request.num_computed_tokens
+    def _get_num_new_tokens(self, request, token_budget):
+        num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
         num_new_tokens = min(num_new_tokens, token_budget)
 
         if not self.config.enable_mm:
@@ -212,8 +209,8 @@ class ResourceManagerV1(ResourceManager):
             num_decoding_req_nums = 0
             while req_index < len(self.running) and token_budget > 0:
                 request = self.running[req_index]
-                if request.num_computed_tokens >= request.prompt_token_ids_len:  # to be decoding
-                    if request.num_total_tokens > request.prompt_token_ids_len:  # has generated tokens
+                if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
+                    if request.num_total_tokens > request.need_prefill_tokens:  # has generated tokens
                         request.num_computed_tokens = request.num_total_tokens - 1
                     if (
                         self.allocated_slots(request) - request.num_total_tokens
@@ -246,7 +243,7 @@ class ResourceManagerV1(ResourceManager):
                     token_budget -= 1
                 else:  # need to prefill
                     llm_logger.debug(
-                        f"scheduler prefill task: {request} request.prompt_token_ids_len {request.prompt_token_ids_len} request.num_computed_tokens {request.num_computed_tokens}"
+                        f"scheduler prefill task: {request} request.need_prefill_tokens {request.need_prefill_tokens} request.num_computed_tokens {request.num_computed_tokens}"
                     )
                     num_new_tokens = self._get_num_new_tokens(request, token_budget)
                     num_new_block = self.get_new_block_nums(request, num_new_tokens)
@@ -274,7 +271,7 @@ class ResourceManagerV1(ResourceManager):
                     break
                 request = self.waiting[0]
                 if request.status == RequestStatus.WAITING:
-                    num_new_tokens = self._get_num_new_tokens(request, token_budget, True)
+                    num_new_tokens = self._get_num_new_tokens(request, token_budget)
                     num_new_block = self.get_new_block_nums(request, num_new_tokens)
                     # Allocate blocks to prefill
                     if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
@@ -295,7 +292,8 @@ class ResourceManagerV1(ResourceManager):
                     else:
                         break
                 elif request.status == RequestStatus.PREEMPTED:
-                    num_new_tokens = self._get_num_new_tokens(request, token_budget, True)
+                    request.need_prefill_tokens = request.num_total_tokens  # Before the preempted task is rescheduled, it has already been sent to the engine and no more tokens are output, so num_total_tokens is static and correct here
+                    num_new_tokens = self._get_num_new_tokens(request, token_budget)
                     num_new_block = self.get_new_block_nums(request, num_new_tokens)
                     # Allocate blocks to prefill
                     if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
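
For reviewers, a minimal self-contained sketch of the accounting this patch introduces. `Req` and `get_num_new_tokens` below are hypothetical stand-ins for `fastdeploy.engine.request.Request` and `ResourceManagerV1._get_num_new_tokens` (the multimodal branch and block allocation are omitted); the point is that a single `need_prefill_tokens` field replaces the old `schedule_waiting` flag for both fresh and preempted requests.

```python
from dataclasses import dataclass, field


@dataclass
class Req:
    """Hypothetical stand-in for Request; only the fields this patch touches."""

    prompt_token_ids_len: int
    num_computed_tokens: int = 0
    num_total_tokens: int = 0
    # New in this patch: defaults to the prompt length, which is the
    # prefill target for a request scheduled from WAITING.
    need_prefill_tokens: int = field(init=False)

    def __post_init__(self):
        self.need_prefill_tokens = self.prompt_token_ids_len


def get_num_new_tokens(req: Req, token_budget: int) -> int:
    """Sketch of the patched _get_num_new_tokens: one formula, no mode flag."""
    return min(req.need_prefill_tokens - req.num_computed_tokens, token_budget)


# Fresh WAITING request: the prefill target is the prompt itself.
req = Req(prompt_token_ids_len=100, num_total_tokens=100)
assert get_num_new_tokens(req, token_budget=64) == 64

# Request preempted after generating 20 tokens: its blocks were reclaimed,
# so prompt + generated tokens must all be recomputed on reschedule.
req.num_computed_tokens = 0
req.num_total_tokens = 120
req.need_prefill_tokens = req.num_total_tokens  # what the PREEMPTED branch now does
assert get_num_new_tokens(req, token_budget=512) == 120
```

Presumably the benefit is that the prefill target becomes per-request state, pinned once when the PREEMPTED branch reschedules the request, instead of a boolean every caller must thread through `_get_num_new_tokens`.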