diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py
index 02df10328..3326e8321 100644
--- a/fastdeploy/engine/config.py
+++ b/fastdeploy/engine/config.py
@@ -854,6 +854,11 @@ class Config:
                 self.max_num_batched_tokens >= self.max_model_len
             ), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
                 f"should be larger than or equal to max_model_len: {self.max_model_len}"
+        else:
+            assert (
+                self.max_num_batched_tokens >= self.cache_config.block_size
+            ), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
+                f"should be larger than or equal to block_size: {self.cache_config.block_size}"
 
         if self.max_num_partial_prefills > 1:
             assert (self.cache_config.enable_chunked_prefill is True), \
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 6a5d30d21..414a7b209 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -134,6 +134,7 @@ class LLMEngine(object):
         for idx in range(1, self.cfg.max_num_partial_prefills + 1):
             self.partial_chunked_tokens[idx] = (self.cfg.max_num_batched_tokens // idx) \
                 // self.cfg.cache_config.block_size * self.cfg.cache_config.block_size
+            self.partial_chunked_tokens[idx] = max(1, self.partial_chunked_tokens[idx])
 
         self._finalizer = weakref.finalize(self, self._exit_sub_services)
 
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index e273d0714..6123a37b4 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -394,6 +394,18 @@ class PaddleDisWorkerProc():
                 time.sleep(0.01)
             num_blocks_global = self.get_profile_block_num_signal.value.min(
             ).item()
+
+            if num_blocks_global < 0:
+                logger.error(
+                    f"The total number of blocks cannot be less than zero. "
+                    f"Please increase gpu_memory_utilization "
+                    f"or decrease max_num_batched_tokens (max model length).")
+                raise ValueError(
+                    f"The total number of blocks cannot be less than zero. "
+                    f"Please increase gpu_memory_utilization "
+                    f"or decrease max_num_batched_tokens (max model length).")
+
+
             self.get_profile_block_num_signal.value[
                 self.local_rank] = num_blocks_global
         else:
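
For context, a minimal standalone sketch (not the engine code itself) of the arithmetic that the new `block_size` assertion and the `max(1, ...)` guard address. The function name and example values below are illustrative only; the real fields live on the engine/cache config objects:

```python
# Sketch of how partial_chunked_tokens is derived per prefill slot, and why
# it could previously collapse to 0 when max_num_batched_tokens < block_size.

def partial_chunked_tokens(max_num_batched_tokens: int,
                           max_num_partial_prefills: int,
                           block_size: int) -> list:
    """Round each slot's token budget down to a multiple of block_size,
    clamping to at least 1 token as the engine.py change now does."""
    tokens = [0] * (max_num_partial_prefills + 1)
    for idx in range(1, max_num_partial_prefills + 1):
        budget = (max_num_batched_tokens // idx) // block_size * block_size
        tokens[idx] = max(1, budget)  # guard added by this patch
    return tokens

# With max_num_batched_tokens=48 and block_size=64, the floor division yields
# 0 for every slot; the guard keeps the budget at 1, and the added Config
# assertion would reject this setup up front (48 < block_size=64).
print(partial_chunked_tokens(48, 2, 64))    # [0, 1, 1]
print(partial_chunked_tokens(8192, 2, 64))  # [0, 8192, 4096]
```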