diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 5ec3df934..afe37f076 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1417,7 +1417,6 @@ class StructuredOutputsConfig: # disable any whitespace for guided decoding self.disable_any_whitespace: bool = True self.logits_processors: Optional[list[str]] = None - for key, value in args.items(): if hasattr(self, key) and value != "None": setattr(self, key, value) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 9d6b597b1..23030d6a8 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -159,6 +159,7 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")), "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")), "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")), + "FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")), "FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")), } diff --git a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py index f7c246edd..e0f195e47 100644 --- a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py +++ b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py @@ -29,7 +29,6 @@ from fastdeploy.model_executor.guided_decoding import ( BaseChecker, LogitsProcessorBase, ) -from fastdeploy.platforms import current_platform from fastdeploy.utils import llm_logger try: @@ -451,6 +450,7 @@ def apply_token_mask( logits: paddle.Tensor, token_bitmask: torch.Tensor, indices: Optional[List[int]] = None, + is_cuda_platform: bool = True, ) -> paddle.Tensor: """ Apply the token mask to the logits, modifying probabilities of invalid tokens. @@ -463,17 +463,16 @@ def apply_token_mask( Returns: paddle.Tensor: The modified logits tensor """ - - if current_platform.is_cuda(): + skip_out_indices = len(indices) == logits.shape[0] + if is_cuda_platform: dlpack = paddle.utils.dlpack.to_dlpack(logits) t_logits = torch.from_dlpack(dlpack) apply_token_bitmask_inplace( logits=t_logits, bitmask=token_bitmask.to(t_logits.device, non_blocking=True), - indices=indices, + indices=indices if not skip_out_indices else None, ) - dlpack2 = torch.utils.dlpack.to_dlpack(t_logits) - return paddle.utils.dlpack.from_dlpack(dlpack2) + return logits else: origin_place = logits.place origin_dtype = logits.dtype @@ -483,7 +482,7 @@ def apply_token_mask( apply_token_bitmask_inplace( logits=logits, bitmask=token_bitmask.to(logits.device, non_blocking=True), - indices=indices, + indices=indices if not skip_out_indices else None, ) return paddle.to_tensor( diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index b03454268..0e931bb43 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -14,8 +14,9 @@ # limitations under the License. 
""" +import multiprocessing import time -from concurrent.futures import Future +from concurrent.futures import Future, ThreadPoolExecutor from typing import Any, List, Optional import paddle @@ -24,6 +25,7 @@ from paddle import nn from paddleformers.utils.log import logger from fastdeploy.config import FDConfig +from fastdeploy.envs import FD_FILL_BITMASK_BATCH from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase from fastdeploy.model_executor.layers.sample.early_stopper import ( get_early_stopper_cls_from_stragegy, @@ -69,11 +71,26 @@ class GuidedDecoding: def __init__(self, fd_config: FDConfig): self.token_bitmask = None - self.logits_processors: List[Any] = [None] * fd_config.scheduler_config.max_num_seqs + self.max_num_seqs: int = int( + fd_config.scheduler_config.max_num_seqs if fd_config.scheduler_config is not None else 1 + ) + self.logits_processors: List[Any] = [None] * self.max_num_seqs self.reasoning_parser = None - self._prefill_done_idxs: List[bool] = [False] * fd_config.scheduler_config.max_num_seqs + self._prefill_done_idxs: List[bool] = [False] * self.max_num_seqs # for pd - self._tokens_to_acc: List[None | List[int]] = [None] * fd_config.scheduler_config.max_num_seqs + self._tokens_to_acc: List[None | List[int]] = [None] * self.max_num_seqs + + self.fill_bitmask_parallel_batch_size: int = FD_FILL_BITMASK_BATCH + max_workers = max( + 1, + min(multiprocessing.cpu_count() // 2, int(self.max_num_seqs) / int(self.fill_bitmask_parallel_batch_size)), + ) + self.executor_for_fillmask = ThreadPoolExecutor(max_workers=int(max_workers)) + self._fillmask_futures: List[Future] = [None] * self.max_num_seqs + self.is_cuda_platform = current_platform.is_cuda() + logger.info( + f"GuidedDecoding max_num_seqs={self.max_num_seqs} fill_bitmask_parallel_batch_size={self.fill_bitmask_parallel_batch_size} is_cuda_platform={self.is_cuda_platform} max_workers={max_workers}" + ) def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): self.reasoning_parser = reasoning_parser @@ -140,6 +157,7 @@ class GuidedDecoding: if isinstance(self.logits_processors[idx], Future): continue + idxs = [] for idx, processor in enumerate(self.logits_processors): if processor is None or not self._prefill_done_idxs[idx]: continue @@ -156,7 +174,47 @@ class GuidedDecoding: self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask() if self.should_fill_bitmask(idx): - processor.fill_token_bitmask(self.token_bitmask, idx) + idxs.append(idx) + self._async_batch_fill_token_bitmask(idxs) + + def batch_fill_token_bitmask(self, batch: List[int]): + """ + Fills the token bitmask for a batch of logits processor indices. + + This method is typically called asynchronously via a thread pool executor + to parallelize the bitmask filling operation. It is important that any + shared data structures accessed within this method (such as + `self.token_bitmask` and `self.logits_processors`) are thread-safe or + properly synchronized to avoid race conditions. + + Args: + batch (List[int]): List of indices for which to fill the token bitmask. 
+        """
+        for idx in batch:
+            self.logits_processors[idx].fill_token_bitmask(self.token_bitmask, idx)
+
+    def _async_batch_fill_token_bitmask(self, idxs: List[int]):
+        """Launch async bitmask fills, one executor task per batch of indices."""
+        batch: List[int] = []
+        for idx in idxs:
+            batch.append(idx)
+            if len(batch) == self.fill_bitmask_parallel_batch_size:
+                promise = self.executor_for_fillmask.submit(self.batch_fill_token_bitmask, batch[:])
+                self._fillmask_futures[idx] = promise
+                batch = []
+        if batch:
+            promise = self.executor_for_fillmask.submit(self.batch_fill_token_bitmask, batch[:])
+            self._fillmask_futures[batch[-1]] = promise
+
+    def join_async_fillmask(self):
+        """Wait for all pending bitmask-fill futures and clear them."""
+        for idx, future in enumerate(self._fillmask_futures):
+            if future is not None:
+                try:
+                    future.result()
+                except Exception as e:
+                    logger.error(f"Exception in async fillmask future at idx {idx}: {e}", exc_info=True)
+                self._fillmask_futures[idx] = None
 
     def accept_tokens_from_prefill_node(self, idx: int):
         """accept prefill token, not future"""
@@ -204,15 +262,17 @@
             if self.token_bitmask is None:
                 self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()
 
-            self.logits_processors[idx].fill_token_bitmask(self.token_bitmask, idx)
+            # launch async fill
+            self._async_batch_fill_token_bitmask([idx])
 
         if len(indices) == 0:
             return logits
 
+        self.join_async_fillmask()
         from fastdeploy.model_executor.guided_decoding.xgrammar_backend import (
             apply_token_mask,
         )
-        return apply_token_mask(logits, self.token_bitmask, indices=indices)
+        return apply_token_mask(logits, self.token_bitmask, indices=indices, is_cuda_platform=self.is_cuda_platform)
 
     def _accept_token(self, idx: int, token: int):
         """accept token"""
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 5016c176e..b30db63a7 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -726,7 +726,7 @@ def parse_args():
     )
     parser.add_argument(
         "--disable_any_whitespace",
-        action="store_false",
+        action="store_true",
         help="Disable any whitespace for guided decoding.",
     )
     parser.add_argument(
diff --git a/tests/layers/test_guided_decoding.py b/tests/layers/test_guided_decoding.py
index 1f71fcb34..964ad1dc0 100644
--- a/tests/layers/test_guided_decoding.py
+++ b/tests/layers/test_guided_decoding.py
@@ -26,6 +26,7 @@ class TestGuidedDecoding(unittest.TestCase):
         """Setup for each test case."""
         # Create a basic FDConfig object
         self.fd_config = Mock()
+        self.fd_config.scheduler_config = Mock()
         self.fd_config.scheduler_config.max_num_seqs = 5
         # Create the GuidedDecoding object
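
A note on the new async bitmask filling: `_async_batch_fill_token_bitmask` chunks the pending slot indices into groups of `fill_bitmask_parallel_batch_size`, submits each group to the `ThreadPoolExecutor`, and `join_async_fillmask` blocks on the outstanding futures right before the mask is consumed. The snippet below is a minimal standalone sketch of that chunk-submit-join pattern, not FastDeploy code; `fill_one`, `fill_batch`, `BATCH`, and the list-based bitmask are hypothetical stand-ins.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Optional

BATCH = 4  # plays the role of FD_FILL_BITMASK_BATCH


def fill_one(bitmask: List[int], idx: int) -> None:
    # Stand-in for processor.fill_token_bitmask(token_bitmask, idx).
    bitmask[idx] = 1


def fill_batch(bitmask: List[int], batch: List[int]) -> None:
    # One executor task fills every slot in its batch sequentially.
    for idx in batch:
        fill_one(bitmask, idx)


def async_batch_fill(
    executor: ThreadPoolExecutor,
    futures: List[Optional[Future]],
    bitmask: List[int],
    idxs: List[int],
) -> None:
    batch: List[int] = []
    for idx in idxs:
        batch.append(idx)
        if len(batch) == BATCH:
            # Key the future under the last index of the batch, as the PR does.
            futures[idx] = executor.submit(fill_batch, bitmask, batch[:])
            batch = []
    if batch:  # trailing partial batch
        futures[batch[-1]] = executor.submit(fill_batch, bitmask, batch[:])


def join(futures: List[Optional[Future]]) -> None:
    # Block on every outstanding fill before the bitmask is consumed.
    for i, fut in enumerate(futures):
        if fut is not None:
            fut.result()
            futures[i] = None


if __name__ == "__main__":
    max_num_seqs = 10
    bitmask = [0] * max_num_seqs
    futures: List[Optional[Future]] = [None] * max_num_seqs
    with ThreadPoolExecutor(max_workers=2) as ex:
        async_batch_fill(ex, futures, bitmask, list(range(max_num_seqs)))
        join(futures)
    print(bitmask)  # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
```

Because each slot index lands in exactly one batch, at most `ceil(max_num_seqs / BATCH)` futures are live at a time, which appears to be what motivates the `max_workers` bound in the `GuidedDecoding` constructor.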
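On the `apply_token_mask` change: `torch.from_dlpack(paddle.utils.dlpack.to_dlpack(logits))` yields a Torch view over the same device memory as the Paddle tensor, so once `apply_token_bitmask_inplace` mutates the Torch view, `logits` itself already holds the masked values and the old Torch-to-Paddle DLPack round-trip is redundant. A rough sanity check of that zero-copy assumption (illustrative only, not part of the PR):

```python
import paddle
import torch

x = paddle.ones([2, 4], dtype="float32")
# Torch view over the same buffer via DLPack (zero-copy).
t = torch.from_dlpack(paddle.utils.dlpack.to_dlpack(x))
t[0, 0] = -float("inf")  # in-place edit through the Torch view
# The Paddle tensor observes the change, so returning `logits` directly is enough.
print(float(x[0, 0]))  # -inf
```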
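`FD_FILL_BITMASK_BATCH` (default 4) controls how many slots each fill task covers. It is resolved through `fastdeploy.envs` when `sampler.py` is imported, so it should already be set in the environment when the worker process starts. A hypothetical launcher-side override:

```python
import os

# Assumption: this runs before the FastDeploy worker imports sampler.py,
# which binds FD_FILL_BITMASK_BATCH once at import time.
os.environ["FD_FILL_BITMASK_BATCH"] = "8"  # group 8 slots per fill task instead of 4
```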