diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 7ad205266..73f318db7 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -1237,7 +1237,10 @@ class FDConfig:
 
         if self.max_num_batched_tokens is None:
             if int(envs.ENABLE_V1_KVCACHE_SCHEDULER):
-                self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
+                if paddle.is_compiled_with_xpu():
+                    self.max_num_batched_tokens = self.max_model_len
+                else:
+                    self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
             else:
                 if self.cache_config.enable_chunked_prefill:
                     self.max_num_batched_tokens = 2048
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 664b2b36d..f553ad2d2 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -19,6 +19,8 @@ from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
 from typing import Any, Dict, List, Optional
 
+import paddle
+
 from fastdeploy import envs
 from fastdeploy.config import (
     CacheConfig,
@@ -1005,7 +1007,10 @@ class EngineArgs:
 
         if self.max_num_batched_tokens is None:
             if int(envs.ENABLE_V1_KVCACHE_SCHEDULER):
-                self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
+                if paddle.is_compiled_with_xpu():
+                    self.max_num_batched_tokens = self.max_model_len
+                else:
+                    self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
             else:
                 if self.enable_chunked_prefill:
                     self.max_num_batched_tokens = 2048
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 5ea7f094a..339f18f32 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -363,7 +363,9 @@ class ResourceManagerV1(ResourceManager):
         while self.waiting and token_budget > 0:
             if len(self.running) == self.max_num_seqs:
                 break
-            if self.config.model_config.enable_mm and self.exist_prefill(scheduled_reqs):
+            if (self.config.model_config.enable_mm or paddle.is_compiled_with_xpu()) and self.exist_prefill(
+                scheduled_reqs
+            ):
                 break
             request = self.waiting[0]
             if request.status == RequestStatus.WAITING:
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index cee71415b..09ec0ee1a 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -383,6 +383,7 @@ class XPUModelRunner(ModelRunnerBase):
 
         req_len = len(req_dicts)
         has_prefill_task = False
+        has_decode_task = False
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
@@ -392,6 +393,9 @@ class XPUModelRunner(ModelRunnerBase):
                 prefill_end_index = request.prefill_end_index
                 length = prefill_end_index - prefill_start_index
                 input_ids = request.prompt_token_ids + request.output_token_ids
+                logger.debug(
+                    f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}"
+                )
                 self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
                     input_ids[prefill_start_index:prefill_end_index]
                 )
@@ -401,6 +405,8 @@ class XPUModelRunner(ModelRunnerBase):
                 self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                     request.block_tables, dtype="int32"
                 )
+                if self.share_inputs["is_block_step"][idx]:  # has tasks to continue to decode
+                    has_decode_task = True
                 self.share_inputs["stop_flags"][idx : idx + 1] = False
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length @@ -474,7 +480,7 @@ class XPUModelRunner(ModelRunnerBase): self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array( request.get("stop_token_ids"), dtype="int64" ) - if has_prefill_task: + if has_prefill_task or has_decode_task: self.share_inputs["not_need_stop"][0] = True def process_prefill_inputs(self, req_dicts: List[Request]):