From 85db9d5e56f9a026d548549f4e4f9229878c0881 Mon Sep 17 00:00:00 2001
From: ming1753 <61511741+ming1753@users.noreply.github.com>
Date: Tue, 23 Dec 2025 20:45:52 +0800
Subject: [PATCH] [Others] reschedule preempt task support optional func (#5649)

* [Others] reschedule preempt task support optional func

* fix bug

* fix bug
---
 fastdeploy/engine/request.py                 |  1 +
 .../engine/sched/resource_manager_v1.py      | 30 +++++++++++++++----
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index 1c02a9ac9..ddc849dab 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -158,6 +158,7 @@ class Request:
         self.task_type = RequestType.PREFILL
         self.idx = None
         self.need_prefill_tokens = self.prompt_token_ids_len
+        self.audio_output_token_ids = []
         # extend block tables
         self.use_extend_tables = False
         self.extend_block_tables = []
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index b25f90b47..3d08d7389 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -224,10 +224,12 @@ class ResourceManagerV1(ResourceManager):
     def _prepare_preempt_task(self, request):
         return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)
 
-    def reschedule_preempt_task(self, request_id):
+    def reschedule_preempt_task(self, request_id, process_func=None):
         with self.lock:
             if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests:
                 request = self.requests[request_id]
+                if process_func is not None:
+                    process_func(request)
                 self.waiting.appendleft(request)
                 self.to_be_rescheduled_request_id_set.remove(request_id)
 
@@ -368,7 +370,25 @@ class ResourceManagerV1(ResourceManager):
             new_end_idx = pre_end_idx + num_new_tokens
 
             prompt_token_ids_len = len(request.prompt_token_ids)
-            assert prompt_token_ids_len == len(inputs["patch_idx"]), (prompt_token_ids_len, len(inputs["patch_idx"]))
+            if not inputs.get("tts", False):
+                assert prompt_token_ids_len == len(inputs["patch_idx"]), (
+                    prompt_token_ids_len,
+                    len(inputs["patch_idx"]),
+                )
+
+            def _compute_audio_prefix_count(end_idx, end_patch_idx):
+                audio_prefix_count = 0
+                pre_patch_end_idx = 0
+                for patch_idx in range(end_patch_idx + 1):
+                    patch_map = inputs["patch_map"][patch_idx]
+                    modal_id = patch_map["modal_id"]
+                    if modal_id == IDS_TYPE_FLAG["audio"]:
+                        if patch_idx != end_patch_idx:
+                            audio_prefix_count += patch_map["end_idx"] - pre_patch_end_idx
+                        else:
+                            audio_prefix_count += end_idx - pre_patch_end_idx
+                    pre_patch_end_idx = patch_map["end_idx"]
+                return audio_prefix_count
 
             # start
             if pre_end_idx >= prompt_token_ids_len:
@@ -378,7 +398,7 @@ class ResourceManagerV1(ResourceManager):
             start_patch_map = inputs["patch_map"][start_patch_idx]
             request.image_start = start_patch_map["image_num"]
             request.video_start = start_patch_map["video_num"]
-            request.audio_start = start_patch_map["audio_num"]
+            request.audio_start = _compute_audio_prefix_count(pre_end_idx, start_patch_idx)
 
             # end
             if new_end_idx >= prompt_token_ids_len:
@@ -393,7 +413,7 @@ class ResourceManagerV1(ResourceManager):
                     end_patch_idx -= 1
             end_patch_map = inputs["patch_map"][end_patch_idx]
             end_modal_id = end_patch_map["modal_id"]
-            if end_modal_id > 0 and end_modal_id != IDS_TYPE_FLAG["video"]:
+            if end_modal_id == IDS_TYPE_FLAG["image"]:
                 new_end_idx = end_patch_map["end_idx"]  # end position of the current modality
             if end_modal_id == IDS_TYPE_FLAG["video"] and "can_split_idx_list" in inputs:
@@ -406,7 +426,7 @@ class ResourceManagerV1(ResourceManager):
             request.image_end = end_patch_map["image_num"]
             request.video_end = end_patch_map["video_num"]
-            request.audio_end = end_patch_map["audio_num"]
+            request.audio_end = _compute_audio_prefix_count(new_end_idx, end_patch_idx)
         elif (
             inputs.get("images", None) is not None
             and inputs.get("image_patch_id", None) is not None
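
Note (not part of the patch itself): a minimal caller-side sketch of the new optional process_func hook; the callback name and the surrounding setup are hypothetical. reschedule_preempt_task now invokes the callback on the preempted Request just before pushing it back onto the waiting queue, so a caller can reset per-request state such as the newly added audio_output_token_ids list.

    # Hypothetical usage; `resource_manager` is assumed to be an existing
    # ResourceManagerV1 instance and `request_id` the id of a request that
    # was previously preempted.
    def _reset_audio_state(request):
        # Assumption for illustration: drop partially produced audio tokens so
        # the request re-enters the waiting queue in a clean state.
        request.audio_output_token_ids = []

    resource_manager.reschedule_preempt_task(request_id, process_func=_reset_audio_state)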
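
Note (not part of the patch itself): a self-contained sketch of the counting scheme behind the added _compute_audio_prefix_count helper, using made-up IDS_TYPE_FLAG values and a made-up patch_map; the real mapping and patch_map structure are defined elsewhere in FastDeploy. The helper walks patch_map entries up to end_patch_idx, sums the token spans whose modal_id is the audio flag, and truncates the final patch at end_idx.

    # Illustrative values only.
    IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3}

    # Each patch spans prompt tokens [previous end_idx, end_idx).
    patch_map = [
        {"modal_id": IDS_TYPE_FLAG["text"], "end_idx": 4},
        {"modal_id": IDS_TYPE_FLAG["audio"], "end_idx": 10},
        {"modal_id": IDS_TYPE_FLAG["text"], "end_idx": 12},
        {"modal_id": IDS_TYPE_FLAG["audio"], "end_idx": 20},
    ]

    def compute_audio_prefix_count(end_idx, end_patch_idx):
        """Count audio tokens among the first end_idx prompt tokens."""
        audio_prefix_count = 0
        pre_patch_end_idx = 0
        for patch_idx in range(end_patch_idx + 1):
            patch = patch_map[patch_idx]
            if patch["modal_id"] == IDS_TYPE_FLAG["audio"]:
                if patch_idx != end_patch_idx:
                    # The whole audio patch lies before the cut-off.
                    audio_prefix_count += patch["end_idx"] - pre_patch_end_idx
                else:
                    # The cut-off falls inside this audio patch; count only the covered part.
                    audio_prefix_count += end_idx - pre_patch_end_idx
            pre_patch_end_idx = patch["end_idx"]
        return audio_prefix_count

    # Tokens 4-9 are audio (6 tokens) and tokens 12-14 fall inside the second
    # audio patch (3 tokens), so the expected count is 9.
    print(compute_audio_prefix_count(15, 3))  # -> 9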