diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index f5bdcceb0..201273e1e 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -28,6 +28,7 @@ import paddle from fastdeploy.engine.request import Request, RequestStatus, RequestType from fastdeploy.engine.resource_manager import ResourceManager +from fastdeploy.input.utils import IDS_TYPE_FLAG from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.utils import llm_logger @@ -151,6 +152,24 @@ class ResourceManagerV1(ResourceManager): new_end_idx = pre_end_idx + num_new_tokens prompt_token_ids_len = len(request.prompt_token_ids) + + if new_end_idx >= prompt_token_ids_len: + return num_new_tokens + + if inputs.get("can_split_idx_list") is not None: + if new_end_idx >= prompt_token_ids_len: + return num_new_tokens + patch_idx = inputs["patch_idx"][new_end_idx] + patch_map = inputs["patch_map"][patch_idx] + modal_id = patch_map["modal_id"] + if modal_id == IDS_TYPE_FLAG["text"]: + return num_new_tokens + elif modal_id == IDS_TYPE_FLAG["video"]: + can_split_idx_list = inputs["can_split_idx_list"] + for i in range(len(can_split_idx_list)): + if can_split_idx_list[i] >= new_end_idx: + return can_split_idx_list[i] - pre_end_idx + assert prompt_token_ids_len == len(inputs["patch_idx"]), (prompt_token_ids_len, len(inputs["patch_idx"])) # start