[Feature] Enable guided decoding when ENABLE_V1_KVCACHE_SCHEDULER=1 (#5140)

* Enable guided decoding when ENABLE_V1_KVCACHE_SCHEDULER=1

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Author: Daci
Committed by: GitHub
Date: 2025-11-26 10:22:35 +08:00
parent 2d787590c4
commit f25ee3a26f
3 changed files with 38 additions and 5 deletions
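Per the commit title, guided decoding now works together with the V1 KV-cache scheduler. A minimal sketch of turning the flag on, assuming the `envs` module referenced in the diff reads it from the process environment (so it must be set before the engine starts):

import os

# Assumption: the flag is consumed from the environment at startup;
# the exact consumption path is inside the project's envs module.
os.environ["ENABLE_V1_KVCACHE_SCHEDULER"] = "1"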


@@ -521,7 +521,28 @@ class GPUModelRunner(ModelRunnerBase):
            if hasattr(request, "pooling_params") and request.pooling_params is not None:
                batch_pooling_params.append(request.pooling_params)
            logits_info = None
            prefill_tokens = []
            if request.task_type.value == RequestType.PREFILL.value:  # prefill task
                # guided decoding
                if (
                    request.guided_json is not None
                    or request.guided_regex is not None
                    or request.structural_tag is not None
                    or request.guided_grammar is not None
                ):
                    logits_info, schemata_key = self._init_logits_processor(request)
                    request.schemata_key = schemata_key
                if self.scheduler_config.splitwise_role == "decode":
                    if (
                        hasattr(request, "prefill_end_index")
                        and hasattr(request, "prompt_token_ids")
                        and request.prefill_end_index > len(request.prompt_token_ids)
                    ):
                        if hasattr(request, "output_token_ids"):
                            prefill_tokens.extend(request.output_token_ids)
                prefill_start_index = request.prefill_start_index
                prefill_end_index = request.prefill_end_index
                length = prefill_end_index - prefill_start_index
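For readers outside the diff context: the new branch initializes a logits processor only when the request carries one of the four structured-output constraints. A minimal standalone sketch of that gate, where GuidedRequest and wants_guided_decoding are hypothetical stand-ins rather than FastDeploy APIs:

from dataclasses import dataclass
from typing import Optional

@dataclass
class GuidedRequest:
    # Hypothetical stand-in for the runner's request object; only the
    # four structured-output fields checked in the diff are modeled.
    guided_json: Optional[str] = None
    guided_regex: Optional[str] = None
    structural_tag: Optional[str] = None
    guided_grammar: Optional[str] = None

def wants_guided_decoding(req: GuidedRequest) -> bool:
    # Mirrors the diff's condition: any single constraint is enough to
    # trigger logits-processor initialization.
    return any(
        v is not None
        for v in (req.guided_json, req.guided_regex, req.structural_tag, req.guided_grammar)
    )

assert wants_guided_decoding(GuidedRequest(guided_regex=r"\d+"))
assert not wants_guided_decoding(GuidedRequest())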
@@ -657,6 +678,8 @@ class GPUModelRunner(ModelRunnerBase):
                # For logits processors
                self.share_inputs["logits_processors_args"][idx] = request.get("logits_processors_args") or {}
                self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens)
        if len(multi_vision_inputs["images_lst"]) > 0:
            self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
@@ -2059,6 +2082,21 @@ class GPUModelRunner(ModelRunnerBase):
            if self.share_inputs["step_idx"][idx] == 0:
                prefill_done_idxs.append(idx)
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            if model_forward_batch is None:
                return prefill_done_idxs
            for task in model_forward_batch:
                if task.task_type.value != RequestType.PREFILL.value:
                    continue
                # still in chunked prefill: the prompt is not fully consumed yet
                if self.cache_config.enable_chunked_prefill:
                    if hasattr(task, "prefill_end_index") and hasattr(task, "prompt_token_ids"):
                        if len(task.prompt_token_ids) > task.prefill_end_index and task.idx in prefill_done_idxs:
                            prefill_done_idxs.remove(task.idx)
            return prefill_done_idxs
        if self.cache_config.enable_chunked_prefill:
            if model_forward_batch is not None:
                for task in model_forward_batch:
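The excerpt above cuts off inside the pre-existing non-V1 branch; the new V1-scheduler logic is the filtering block before it. The idea, as far as the code shows: under chunked prefill a slot can be picked up as "prefill done" even though its prompt is not fully consumed, so any task whose prefill_end_index has not yet reached the prompt length is dropped from the done list. A self-contained sketch under those assumptions (Task and filter_done are hypothetical stand-ins, not the runner's types):

from dataclasses import dataclass
from typing import List

@dataclass
class Task:
    # Hypothetical stand-in carrying only the fields the diff inspects.
    idx: int
    prompt_token_ids: List[int]
    prefill_end_index: int

def filter_done(prefill_done_idxs: List[int], batch: List[Task]) -> List[int]:
    # Mirrors the V1-scheduler branch: a slot whose prompt is not fully
    # consumed (prefill_end_index < prompt length) is still mid chunked
    # prefill and must not be reported as done.
    done = list(prefill_done_idxs)
    for task in batch:
        if len(task.prompt_token_ids) > task.prefill_end_index and task.idx in done:
            done.remove(task.idx)
    return done

batch = [
    Task(idx=0, prompt_token_ids=list(range(8)), prefill_end_index=4),
    Task(idx=1, prompt_token_ids=list(range(8)), prefill_end_index=8),
]
print(filter_done([0, 1], batch))  # -> [1]: slot 0 still has prompt chunks left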