diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index c254aaa1a..835d3eb4d 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -15,10 +15,12 @@
 """

 import json
+import os
 from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
 from typing import Any, Dict, List, Optional
-import os
+
+import paddle

 from fastdeploy.config import (
     CacheConfig,
@@ -866,10 +868,13 @@ class EngineArgs:
             if self.enable_chunked_prefill:
                 self.max_num_batched_tokens = 2048
             else:
-                if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+                if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                     self.max_num_batched_tokens = self.max_model_len
                 else:
-                    self.max_num_batched_tokens = 8192
+                    if paddle.is_compiled_with_xpu():
+                        self.max_num_batched_tokens = self.max_model_len
+                    else:
+                        self.max_num_batched_tokens = 8192

         all_dict = asdict(self)
         all_dict["model_cfg"] = model_cfg
diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py
index fb57884bf..f6303d7b3 100644
--- a/fastdeploy/engine/config.py
+++ b/fastdeploy/engine/config.py
@@ -236,10 +236,13 @@ class Config:
             if self.cache_config.enable_chunked_prefill:
                 self.max_num_batched_tokens = 2048
             else:
-                if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+                if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                     self.max_num_batched_tokens = self.max_model_len
                 else:
-                    self.max_num_batched_tokens = 8192
+                    if paddle.is_compiled_with_xpu():
+                        self.max_num_batched_tokens = self.max_model_len
+                    else:
+                        self.max_num_batched_tokens = 8192

         if self.long_prefill_token_threshold == 0:
             self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
@@ -287,7 +290,7 @@ class Config:
         )

         if not self.cache_config.enable_chunked_prefill:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 assert self.max_num_batched_tokens >= self.max_model_len, (
                     f"max_num_batched_tokens: {self.max_num_batched_tokens} "
                     f"should be larger than or equal to max_model_len: {self.max_model_len}"
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 4aecabcd5..ba0197a90 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -289,7 +289,7 @@ class ResourceManagerV1(ResourceManager):
         while self.waiting and token_budget > 0:
             if len(self.running) == self.max_num_seqs:
                 break
-            if self.config.enable_mm and self.exist_prefill(scheduled_reqs):
+            if (self.config.enable_mm or paddle.is_compiled_with_xpu()) and self.exist_prefill(scheduled_reqs):
                 break
             request = self.waiting[0]
             if request.status == RequestStatus.WAITING:
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index a5558ac47..3c76b9a2c 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -383,15 +383,18 @@ class XPUModelRunner(ModelRunnerBase):

         req_len = len(req_dicts)
         has_prefill_task = False
+        has_decode_task = False
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
             if request.task_type.value == RequestType.PREFILL.value:  # prefill task
-                logger.debug(f"Handle prefill request {request} at idx {idx}")
                 prefill_start_index = request.prefill_start_index
                 prefill_end_index = request.prefill_end_index
                 length = prefill_end_index - prefill_start_index
                 input_ids = request.prompt_token_ids + request.output_token_ids
+                logger.debug(
+                    f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}"
+                )
                 self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
                     input_ids[prefill_start_index:prefill_end_index]
                 )
@@ -420,6 +423,8 @@ class XPUModelRunner(ModelRunnerBase):
                 self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                     request.block_tables, dtype="int32"
                 )
+                if self.share_inputs["is_block_step"][idx]:  # has tasks to continue to decode
+                    has_decode_task = True
                 continue
             else:  # preempted task
                 logger.debug(f"Handle preempted request {request} at idx {idx}")
@@ -460,7 +465,7 @@ class XPUModelRunner(ModelRunnerBase):
                 self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
                     request.get("stop_token_ids"), dtype="int64"
                 )
-        if has_prefill_task:
+        if has_prefill_task or has_decode_task:
             self.share_inputs["not_need_stop"][0] = True

     def process_prefill_inputs(self, req_dicts: List[Request]):
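
Taken together, the args_utils.py and config.py hunks change only the default chosen when max_num_batched_tokens is left unset: with the V1 KV cache scheduler enabled and chunked prefill disabled, XPU builds fall back to max_model_len instead of the fixed 8192. A minimal sketch of that decision logic, assuming a standalone helper name (pick_max_num_batched_tokens is not part of the diff, it only restates the branching above):

    import os

    import paddle


    def pick_max_num_batched_tokens(enable_chunked_prefill: bool, max_model_len: int) -> int:
        """Hypothetical helper mirroring the default selection in this diff."""
        if enable_chunked_prefill:
            # Chunked prefill caps each scheduling step at a fixed token budget.
            return 2048
        if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
            # Legacy scheduler: the whole prompt must fit into one batch.
            return max_model_len
        # V1 KV cache scheduler: XPU builds keep the full-model-length budget,
        # other backends use the fixed 8192-token default.
        return max_model_len if paddle.is_compiled_with_xpu() else 8192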