[Others] Support an optional processing function when rescheduling a preempted task (#5649)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

* [Others] Support an optional processing function when rescheduling a preempted task

* fix bug

* fix bug
This commit is contained in:
ming1753
2025-12-23 20:45:52 +08:00
committed by GitHub
parent 5cec66adb8
commit 85db9d5e56
2 changed files with 26 additions and 5 deletions

View File

@@ -158,6 +158,7 @@ class Request:
self.task_type = RequestType.PREFILL
self.idx = None
self.need_prefill_tokens = self.prompt_token_ids_len
self.audio_output_token_ids = []
# extend block tables
self.use_extend_tables = False
self.extend_block_tables = []

View File

@@ -224,10 +224,12 @@ class ResourceManagerV1(ResourceManager):
def _prepare_preempt_task(self, request):
    """Build the preempt-task descriptor for ``request``.

    The descriptor carries the request's ``idx`` and ``request_id`` and
    nothing else.
    """
    preempt_task = ScheduledPreemptTask(
        idx=request.idx,
        request_id=request.request_id,
    )
    return preempt_task
def reschedule_preempt_task(self, request_id):
def reschedule_preempt_task(self, request_id, process_func=None):
    """Move a preempted request back to the front of the waiting queue.

    Nothing happens unless ``request_id`` is both marked in
    ``self.to_be_rescheduled_request_id_set`` and present in
    ``self.requests``.

    Args:
        request_id: Identifier of the request to re-queue.
        process_func: Optional callable invoked with the request object
            before it is re-queued. Skipped when ``None``.
    """
    with self.lock:
        # Guard clauses keep the same check order as a combined `and`:
        # the reschedule set is consulted first, then the request table.
        if request_id not in self.to_be_rescheduled_request_id_set:
            return
        if request_id not in self.requests:
            return

        preempted = self.requests[request_id]
        if process_func is not None:
            process_func(preempted)

        # Re-queue at the head, then clear the pending-reschedule mark.
        self.waiting.appendleft(preempted)
        self.to_be_rescheduled_request_id_set.remove(request_id)
@@ -368,7 +370,25 @@ class ResourceManagerV1(ResourceManager):
new_end_idx = pre_end_idx + num_new_tokens
prompt_token_ids_len = len(request.prompt_token_ids)
assert prompt_token_ids_len == len(inputs["patch_idx"]), (prompt_token_ids_len, len(inputs["patch_idx"]))
if not inputs.get("tts", False):
assert prompt_token_ids_len == len(inputs["patch_idx"]), (
prompt_token_ids_len,
len(inputs["patch_idx"]),
)
def _compute_audio_prefix_count(end_idx, end_patch_idx):
# Walk patches 0..end_patch_idx (inclusive) of inputs["patch_map"] and, for each
# patch whose modal_id equals IDS_TYPE_FLAG["audio"], add the distance from the
# previously recorded patch end to this patch's end; return the accumulated total.
# `inputs` is not a parameter: it is read from the enclosing scope.
#   end_idx:       position used instead of the patch's own end_idx when the
#                  audio patch is the final one (patch_idx == end_patch_idx).
#   end_patch_idx: index into inputs["patch_map"] of the last patch to visit.
audio_prefix_count = 0
# Last recorded patch end position; starts at 0 before any patch is visited.
pre_patch_end_idx = 0
for patch_idx in range(end_patch_idx + 1):
patch_map = inputs["patch_map"][patch_idx]
modal_id = patch_map["modal_id"]
if modal_id == IDS_TYPE_FLAG["audio"]:
if patch_idx != end_patch_idx:
# Audio patch before the final one: count up to its own end_idx.
audio_prefix_count += patch_map["end_idx"] - pre_patch_end_idx
else:
# Final patch is audio: count only up to the caller-supplied end_idx.
audio_prefix_count += end_idx - pre_patch_end_idx
# NOTE(review): indentation is lost in this view. This assignment presumably
# runs for every patch (loop level), not only for audio patches -- confirm
# against the real file.
pre_patch_end_idx = patch_map["end_idx"]
return audio_prefix_count
# start
if pre_end_idx >= prompt_token_ids_len:
@@ -378,7 +398,7 @@ class ResourceManagerV1(ResourceManager):
start_patch_map = inputs["patch_map"][start_patch_idx]
request.image_start = start_patch_map["image_num"]
request.video_start = start_patch_map["video_num"]
request.audio_start = start_patch_map["audio_num"]
request.audio_start = _compute_audio_prefix_count(pre_end_idx, start_patch_idx)
# end
if new_end_idx >= prompt_token_ids_len:
@@ -393,7 +413,7 @@ class ResourceManagerV1(ResourceManager):
end_patch_idx -= 1
end_patch_map = inputs["patch_map"][end_patch_idx]
end_modal_id = end_patch_map["modal_id"]
if end_modal_id > 0 and end_modal_id != IDS_TYPE_FLAG["video"]:
if end_modal_id == IDS_TYPE_FLAG["image"]:
new_end_idx = end_patch_map["end_idx"] # 当前模态结束位置
if end_modal_id == IDS_TYPE_FLAG["video"] and "can_split_idx_list" in inputs:
@@ -406,7 +426,7 @@ class ResourceManagerV1(ResourceManager):
request.image_end = end_patch_map["image_num"]
request.video_end = end_patch_map["video_num"]
request.audio_end = end_patch_map["audio_num"]
request.audio_end = _compute_audio_prefix_count(new_end_idx, end_patch_idx)
elif (
inputs.get("images", None) is not None
and inputs.get("image_patch_id", None) is not None