[XPU] Fixed the issue of performance degradation caused by enabling ENABLE_V1_KVCACHE_SCHEDULER (#3393)

* fix v1 scheduler OOM bug
Author: yinwei
Date: 2025-08-14 17:41:40 +08:00 (committed by GitHub)
Parent: 28918702c2
Commit: 101605869c
4 changed files with 22 additions and 9 deletions


@@ -15,10 +15,12 @@
""" """
import json import json
import os
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import os
import paddle
from fastdeploy.config import ( from fastdeploy.config import (
CacheConfig, CacheConfig,
@@ -866,7 +868,10 @@ class EngineArgs:
         if self.enable_chunked_prefill:
             self.max_num_batched_tokens = 2048
         else:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 self.max_num_batched_tokens = self.max_model_len
             else:
-                self.max_num_batched_tokens = 8192
+                if paddle.is_compiled_with_xpu():
+                    self.max_num_batched_tokens = self.max_model_len
+                else:
+                    self.max_num_batched_tokens = 8192
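
The same branch is mirrored in Config below. Restated as a minimal standalone sketch for readability (the helper name and signature are illustrative only, not part of FastDeploy):

```python
import os

import paddle


def default_max_num_batched_tokens(enable_chunked_prefill: bool, max_model_len: int) -> int:
    """Illustrative restatement of the branch above; not a FastDeploy API."""
    if enable_chunked_prefill:
        # Chunked prefill keeps a small, fixed per-step token budget.
        return 2048
    if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
        # Legacy scheduler: the whole prompt must fit into one batch.
        return max_model_len
    if paddle.is_compiled_with_xpu():
        # New in this commit: XPU builds running the v1 KV-cache scheduler
        # also get the full max_model_len budget.
        return max_model_len
    # v1 KV-cache scheduler on other backends keeps the 8192 default.
    return 8192
```

Before this commit, an XPU run with ENABLE_V1_KVCACHE_SCHEDULER=1 fell into the 8192 branch; it now receives max_model_len, the same budget as the legacy scheduler path.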


@@ -236,7 +236,10 @@ class Config:
         if self.cache_config.enable_chunked_prefill:
             self.max_num_batched_tokens = 2048
         else:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 self.max_num_batched_tokens = self.max_model_len
             else:
-                self.max_num_batched_tokens = 8192
+                if paddle.is_compiled_with_xpu():
+                    self.max_num_batched_tokens = self.max_model_len
+                else:
+                    self.max_num_batched_tokens = 8192
@@ -287,7 +290,7 @@ class Config:
             )
         if not self.cache_config.enable_chunked_prefill:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 assert self.max_num_batched_tokens >= self.max_model_len, (
                     f"max_num_batched_tokens: {self.max_num_batched_tokens} "
                     f"should be larger than or equal to max_model_len: {self.max_model_len}"


@@ -289,7 +289,7 @@ class ResourceManagerV1(ResourceManager):
         while self.waiting and token_budget > 0:
             if len(self.running) == self.max_num_seqs:
                 break
-            if self.config.enable_mm and self.exist_prefill(scheduled_reqs):
+            if (self.config.enable_mm or paddle.is_compiled_with_xpu()) and self.exist_prefill(scheduled_reqs):
                 break
             request = self.waiting[0]
             if request.status == RequestStatus.WAITING:
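
The widened break condition can be read as a small predicate; only the names taken from the diff (config.enable_mm, exist_prefill, paddle.is_compiled_with_xpu) are real, the function itself is an illustrative sketch:

```python
import paddle


def stop_admitting_prefills(config, scheduled_reqs, exist_prefill) -> bool:
    """Illustrative predicate for the break above; not actual FastDeploy code."""
    # Multimodal models already admit at most one prefill per scheduling step.
    # This commit applies the same rule to XPU builds: once a prefill is in
    # the current step, stop pulling more requests from the waiting queue.
    one_prefill_per_step = config.enable_mm or paddle.is_compiled_with_xpu()
    return one_prefill_per_step and exist_prefill(scheduled_reqs)
```

Together with the larger max_num_batched_tokens default above, an XPU scheduling step therefore admits at most one prefill at a time.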


@@ -383,15 +383,18 @@ class XPUModelRunner(ModelRunnerBase):
         req_len = len(req_dicts)
         has_prefill_task = False
+        has_decode_task = False
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
             if request.task_type.value == RequestType.PREFILL.value:  # prefill task
-                logger.debug(f"Handle prefill request {request} at idx {idx}")
                 prefill_start_index = request.prefill_start_index
                 prefill_end_index = request.prefill_end_index
                 length = prefill_end_index - prefill_start_index
                 input_ids = request.prompt_token_ids + request.output_token_ids
+                logger.debug(
+                    f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}"
+                )
                 self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
                     input_ids[prefill_start_index:prefill_end_index]
                 )
@@ -420,6 +423,8 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32" request.block_tables, dtype="int32"
) )
if self.share_inputs["is_block_step"][idx]: # has tasks to continue to decode
has_decode_task = True
continue continue
else: # preempted task else: # preempted task
logger.debug(f"Handle preempted request {request} at idx {idx}") logger.debug(f"Handle preempted request {request} at idx {idx}")
@@ -460,7 +465,7 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array( self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
request.get("stop_token_ids"), dtype="int64" request.get("stop_token_ids"), dtype="int64"
) )
if has_prefill_task: if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True self.share_inputs["not_need_stop"][0] = True
def process_prefill_inputs(self, req_dicts: List[Request]): def process_prefill_inputs(self, req_dicts: List[Request]):
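
The new has_decode_task flag feeds the final not_need_stop check: previously only freshly inserted prefill work could keep the runner stepping, so a batch containing nothing but resumed block-step decodes did not count as pending work. A condensed, self-contained sketch of the aggregation (helper name and arguments are illustrative, not FastDeploy API):

```python
from typing import Sequence


def any_pending_work(is_prefill_task: Sequence[bool], is_block_step: Sequence[bool]) -> bool:
    """Illustrative summary of the flag handling above; not FastDeploy code."""
    has_prefill_task = any(is_prefill_task)
    # New in this commit: a slot paused in a block step still has tokens to
    # decode, so it also counts as pending work.
    has_decode_task = any(is_block_step)
    return has_prefill_task or has_decode_task
```

In the runner itself, has_decode_task is set from share_inputs["is_block_step"][idx] in the branch shown above, and either flag now sets share_inputs["not_need_stop"][0] to True.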