diff --git a/.github/workflows/_base_test.yml b/.github/workflows/_base_test.yml index 1174bbafc..be3696320 100644 --- a/.github/workflows/_base_test.yml +++ b/.github/workflows/_base_test.yml @@ -134,7 +134,6 @@ jobs: -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \ -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \ -e "FLASK_PORT=${FLASK_PORT}" \ - -e "FD_FORCE_CHUNKED_PREFILL=1" \ -v "${MODEL_CACHE_DIR}:/MODELDATA" \ -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ -v "${CACHE_DIR}/.cache:/root/.cache" \ diff --git a/fastdeploy/config.py b/fastdeploy/config.py index e4182e6c9..768cf6af4 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1233,23 +1233,14 @@ class FDConfig: self.paddle_commit_id = paddle.version.commit - if self.cache_config.enable_chunked_prefill: - self.force_chunked_prefill = int(envs.FD_FORCE_CHUNKED_PREFILL) - if ( - self.speculative_config is not None - and self.speculative_config.method in ["mtp"] - and not self.force_chunked_prefill - ): - self.cache_config.enable_chunked_prefill = False - if self.max_num_batched_tokens is None: - if self.cache_config.enable_chunked_prefill: - self.max_num_batched_tokens = 2048 + if int(envs.ENABLE_V1_KVCACHE_SCHEDULER): + self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM else: - if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")): - self.max_num_batched_tokens = self.max_model_len + if self.cache_config.enable_chunked_prefill: + self.max_num_batched_tokens = 2048 else: - self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM + self.max_num_batched_tokens = self.max_model_len if self.long_prefill_token_threshold == 0: self.long_prefill_token_threshold = int(self.max_model_len * 0.04) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 10ed83525..b2be0953d 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -15,11 +15,11 @@ """ import json -import os from dataclasses import asdict, dataclass from dataclasses import fields as dataclass_fields from typing import Any, Dict, List, Optional +from fastdeploy import envs from fastdeploy.config import ( CacheConfig, EarlyStopConfig, @@ -243,7 +243,7 @@ class EngineArgs: Ports for rdma communication. """ - enable_chunked_prefill: bool = True + enable_chunked_prefill: bool = False """ Flag to enable chunked prefilling. """ @@ -981,14 +981,29 @@ class EngineArgs: if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"): self.tensor_parallel_size = model_cfg.tensor_parallel_size + + speculative_cfg = self.create_speculative_config() + if not self.enable_chunked_prefill: + if ( + current_platform.is_cuda() + and self.splitwise_role == "mixed" + and (speculative_cfg is None or speculative_cfg.method not in ["mtp"]) + ): + # default enable chunked prefill + self.enable_chunked_prefill = True + + self.disable_chunked_prefill = int(envs.FD_DISABLE_CHUNKED_PREFILL) + if self.disable_chunked_prefill: + self.enable_chunked_prefill = False + if self.max_num_batched_tokens is None: - if self.enable_chunked_prefill: - self.max_num_batched_tokens = 2048 + if int(envs.ENABLE_V1_KVCACHE_SCHEDULER): + self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM else: - if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")): - self.max_num_batched_tokens = self.max_model_len + if self.enable_chunked_prefill: + self.max_num_batched_tokens = 2048 else: - self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM + self.max_num_batched_tokens = self.max_model_len all_dict = asdict(self) all_dict["model_cfg"] = model_cfg @@ -996,7 +1011,6 @@ class EngineArgs: load_cfg = LoadConfig(all_dict) parallel_cfg = ParallelConfig(all_dict) scheduler_cfg = self.create_scheduler_config() - speculative_cfg = self.create_speculative_config() graph_opt_cfg = self.create_graph_optimization_config() graph_opt_cfg.update_use_cudagraph(self.use_cudagraph) moba_attention_config = self.create_moba_attention_config() diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 9ee6656e3..f6515f061 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -93,8 +93,8 @@ environment_variables: dict[str, Callable[[], Any]] = { # enable multi api server "FD_ENABLE_MULTI_API_SERVER": lambda: bool(int(os.getenv("FD_ENABLE_MULTI_API_SERVER", "0"))), "FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))), - # force enable chunked prefill - "FD_FORCE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_FORCE_CHUNKED_PREFILL", "0"))), + # force disable default chunked prefill + "FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))), }