diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 764e71de7..3b666acc6 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -75,6 +75,7 @@ class ResourceManagerV1(ResourceManager):
         self.running: list[Request] = []
         self.finish_execution_pool = ThreadPoolExecutor(max_workers=1)
         self.lock = threading.Lock()
+        self.to_be_rescheduled_request_id_set = set()

     def allocated_slots(self, request: Request):
         return len(request.block_tables) * self.config.cache_config.block_size
@@ -96,6 +97,13 @@ class ResourceManagerV1(ResourceManager):

     def _prepare_preempt_task(self, request):
         return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)
+
+    def reschedule_preempt_task(self, request_id):
+        with self.lock:
+            if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests:
+                request = self.requests[request_id]
+                self.waiting.appendleft(request)
+                self.to_be_rescheduled_request_id_set.remove(request_id)

     def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
         can_schedule = True
@@ -106,7 +114,7 @@ class ResourceManagerV1(ResourceManager):
                 preempted_req.num_computed_tokens = 0
                 preempted_req.prefill_block_num = 0
                 self._free_blocks(preempted_req)
-                self.waiting.appendleft(preempted_req)
+                self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
                 preempted_reqs.append(preempted_req)
                 scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
                 if preempted_req == request:
@@ -381,8 +389,9 @@ class ResourceManagerV1(ResourceManager):
         return False

     def add_request(self, request: Request) -> None:
-        self.waiting.append(request)
-        self.requests[request.request_id] = request
+        with self.lock:
+            self.waiting.append(request)
+            self.requests[request.request_id] = request

     def _free_blocks(self, request: Request):
         if self.config.cache_config.enable_prefix_caching:
@@ -409,9 +418,15 @@ class ResourceManagerV1(ResourceManager):
                     if request is None:
                         # Invalid request ID.
                         continue
-                    request.status = RequestStatus.FINISHED
-                    self.running.remove(request)
-                    self._free_blocks(request)
+                    if request in self.running:  # the request ran normally and finished
+                        self.running.remove(request)
+                        request.status = RequestStatus.FINISHED
+                        self._free_blocks(request)
+                    if request.request_id in self.to_be_rescheduled_request_id_set:  # finished after being preempted; its blocks were already recycled
+                        self.to_be_rescheduled_request_id_set.remove(request.request_id)  # just drop it from the reschedule set
+                    if request in self.waiting:  # a finished request must never be moved from preempted back to waiting
+                        raise RuntimeError(f"request {request.request_id} was scheduled into the waiting list after it finished")
+
                     self.tasks_list[request.idx] = None
                     self.stop_flags[request.idx] = True
                     del self.requests[req_id]
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index bb8e5c447..000c4c0dc 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -431,8 +431,13 @@ class TokenProcessor:
             else:
                 batch = self.output_tokens[1, 0]
                 tokens = tokens[2 : batch + 2]
-
+
         batch_result = list()
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            need_to_be_reschedule_req_ids = list(self.resource_manager.to_be_rescheduled_request_id_set)
+            for request_id in need_to_be_reschedule_req_ids:
+                if self.resource_manager.requests[request_id].idx >= (batch - 1):  # no more tokens will be generated for the preempted request
+                    self.resource_manager.reschedule_preempt_task(request_id)
         for i in range(batch):
             if self.resource_manager.stop_flags[i]:
                 continue
@@ -459,6 +464,8 @@ class TokenProcessor:
                 if recovery_stop:
                     llm_logger.info(f"recovery stop signal found at task {task_id}")
                 if not recovery_stop and token_id < 0:
+                    if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
+                        self.resource_manager.reschedule_preempt_task(task_id)
                     continue

                 if task.get("prefill_chunk_info", None) is not None:
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 7a149f83d..590c1e2e7 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -215,11 +215,11 @@ class GPUModelRunner(ModelRunnerBase):

         req_len = len(req_dicts)
         has_prefill_task = False
+        has_decode_task = False
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
             if request.task_type.value == RequestType.PREFILL.value:  # prefill task
-                logger.debug(f"Handle prefill request {request} at idx {idx}")
                 prefill_start_index = request.prefill_start_index
                 prefill_end_index = request.prefill_end_index
                 length = prefill_end_index - prefill_start_index
@@ -265,6 +265,7 @@ class GPUModelRunner(ModelRunnerBase):
                 )

                 input_ids = request.prompt_token_ids + request.output_token_ids
+                logger.debug(f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}")
                 self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
                     input_ids[prefill_start_index:prefill_end_index]
                 )
@@ -293,6 +294,8 @@ class GPUModelRunner(ModelRunnerBase):
                 self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                     request.block_tables, dtype="int32"
                 )
+                if self.share_inputs["is_block_step"][idx]:  # there is still a decode task to continue at this slot
+                    has_decode_task = True
                 continue
             else:  # preempted task
                 logger.debug(f"Handle preempted request {request} at idx {idx}")
@@ -338,7 +341,7 @@ class GPUModelRunner(ModelRunnerBase):
             else:
                 self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0

-        if has_prefill_task:
+        if has_prefill_task or has_decode_task:
             self.share_inputs["not_need_stop"][0] = True
         self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]

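To make the new flow easier to follow, below is a minimal, self-contained sketch of the intended lifecycle (illustrative only, not part of the patch): _trigger_preempt now only records the preempted request id in to_be_rescheduled_request_id_set, and the token processor later calls reschedule_preempt_task to move the request back to waiting under the lock. MiniResourceManager, its preempt helper, and the dict-based request are hypothetical stand-ins for ResourceManagerV1 and Request; only the attribute and method names taken from the diff are real.

# Illustrative sketch only -- not part of the patch.
import threading
from collections import deque


class MiniResourceManager:
    def __init__(self):
        self.lock = threading.Lock()
        self.waiting = deque()
        self.requests = {}  # request_id -> request
        self.to_be_rescheduled_request_id_set = set()

    def preempt(self, request):
        # Mirrors the change in _trigger_preempt: after freeing blocks, the request
        # is only marked for rescheduling instead of being pushed back to waiting.
        self.to_be_rescheduled_request_id_set.add(request["request_id"])

    def reschedule_preempt_task(self, request_id):
        # Mirrors the new method: re-queue under the lock, and only if the request
        # is still tracked (it may have finished and been cleaned up meanwhile).
        with self.lock:
            if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests:
                self.waiting.appendleft(self.requests[request_id])
                self.to_be_rescheduled_request_id_set.remove(request_id)


rm = MiniResourceManager()
req = {"request_id": "req-0"}
rm.requests["req-0"] = req
rm.preempt(req)                       # preempted: recorded, not yet re-queued
rm.reschedule_preempt_task("req-0")   # token processor sees no more output for this slot
assert rm.waiting[0] is req and not rm.to_be_rescheduled_request_id_set

If the request instead finishes while its id is still in the set, finish_requests only removes the id, so a completed request is never re-queued.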