[Feature] ThreadPoolExecutor async fill_token_bitmask (#5083)

* ThreadPoolExecutor async fill_token_bitmask

* ThreadPoolExecutor async fill_token_bitmask logging

* fix test_guided_decoding

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* add fill_bitmask_parallel_batch_size ENV

* define FD_FILL_BITMASK_BATCH in fastdeploy.envs
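
A minimal sketch of what the new knob amounts to, assuming a plain os.environ-backed definition (the exact helper used in fastdeploy.envs may differ, and the default of 4 here is illustrative, not the shipped value):

    import os

    # Hypothetical definition; fastdeploy.envs may wrap this in its own helper.
    FD_FILL_BITMASK_BATCH: int = int(os.getenv("FD_FILL_BITMASK_BATCH", "4"))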

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Author: Daci
Date: 2025-11-19 10:04:16 +08:00
Committed by: GitHub
Parent: 4a7739ec0b
Commit: eab8384da6
6 changed files with 76 additions and 16 deletions

fastdeploy/model_executor/guided_decoding/xgrammar_backend.py

@@ -29,7 +29,6 @@ from fastdeploy.model_executor.guided_decoding import (
     BaseChecker,
     LogitsProcessorBase,
 )
-from fastdeploy.platforms import current_platform
 from fastdeploy.utils import llm_logger
 
 try:
@@ -451,6 +450,7 @@ def apply_token_mask(
     logits: paddle.Tensor,
     token_bitmask: torch.Tensor,
     indices: Optional[List[int]] = None,
+    is_cuda_platform: bool = True,
 ) -> paddle.Tensor:
     """
     Apply the token mask to the logits, modifying probabilities of invalid tokens.
@@ -463,17 +463,16 @@
     Returns:
         paddle.Tensor: The modified logits tensor
     """
-    if current_platform.is_cuda():
+    skip_out_indices = len(indices) == logits.shape[0]
+    if is_cuda_platform:
         dlpack = paddle.utils.dlpack.to_dlpack(logits)
         t_logits = torch.from_dlpack(dlpack)
         apply_token_bitmask_inplace(
             logits=t_logits,
             bitmask=token_bitmask.to(t_logits.device, non_blocking=True),
-            indices=indices,
+            indices=indices if not skip_out_indices else None,
         )
-        dlpack2 = torch.utils.dlpack.to_dlpack(t_logits)
-        return paddle.utils.dlpack.from_dlpack(dlpack2)
+        return logits
     else:
         origin_place = logits.place
         origin_dtype = logits.dtype
@@ -483,7 +482,7 @@
         apply_token_bitmask_inplace(
             logits=logits,
             bitmask=token_bitmask.to(logits.device, non_blocking=True),
-            indices=indices,
+            indices=indices if not skip_out_indices else None,
         )
 
         return paddle.to_tensor(

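The CUDA branch above hinges on two details. First, DLPack lets the Paddle logits tensor be viewed as a Torch tensor without a copy, so the in-place bitmask kernel's writes land directly in the Paddle tensor's memory and the old DLPack round-trip back to Paddle can be dropped in favor of a plain `return logits`. Second, `skip_out_indices` passes `indices=None` whenever the index list already covers every row, letting the kernel mask the whole batch without per-row gathering. A minimal sketch of the zero-copy bridge, using a CPU tensor for simplicity and an illustrative vocab size:

    import paddle
    import torch

    logits = paddle.rand([4, 32000], dtype="float32")  # [batch, vocab]

    # Paddle tensor -> DLPack capsule -> Torch tensor sharing the same memory.
    t_logits = torch.from_dlpack(paddle.utils.dlpack.to_dlpack(logits))

    # An in-place write through the Torch view (standing in here for
    # apply_token_bitmask_inplace) is visible to the original Paddle tensor.
    t_logits[:, 0] = float("-inf")
    assert logits.numpy()[0, 0] == float("-inf")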

@@ -14,8 +14,9 @@
 # limitations under the License.
 """
 
+import multiprocessing
 import time
-from concurrent.futures import Future
+from concurrent.futures import Future, ThreadPoolExecutor
 from typing import Any, List, Optional
 
 import paddle
@@ -24,6 +25,7 @@ from paddle import nn
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
+from fastdeploy.envs import FD_FILL_BITMASK_BATCH
 from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
 from fastdeploy.model_executor.layers.sample.early_stopper import (
     get_early_stopper_cls_from_stragegy,
@@ -69,11 +71,26 @@ class GuidedDecoding:
     def __init__(self, fd_config: FDConfig):
         self.token_bitmask = None
-        self.logits_processors: List[Any] = [None] * fd_config.scheduler_config.max_num_seqs
+        self.max_num_seqs: int = int(
+            fd_config.scheduler_config.max_num_seqs if fd_config.scheduler_config is not None else 1
+        )
+        self.logits_processors: List[Any] = [None] * self.max_num_seqs
         self.reasoning_parser = None
-        self._prefill_done_idxs: List[bool] = [False] * fd_config.scheduler_config.max_num_seqs
+        self._prefill_done_idxs: List[bool] = [False] * self.max_num_seqs
         # for pd
-        self._tokens_to_acc: List[None | List[int]] = [None] * fd_config.scheduler_config.max_num_seqs
+        self._tokens_to_acc: List[None | List[int]] = [None] * self.max_num_seqs
 
+        self.fill_bitmask_parallel_batch_size: int = FD_FILL_BITMASK_BATCH
+        max_workers = max(
+            1,
+            min(multiprocessing.cpu_count() // 2, self.max_num_seqs // self.fill_bitmask_parallel_batch_size),
+        )
+        self.executor_for_fillmask = ThreadPoolExecutor(max_workers=max_workers)
+        self._fillmask_futures: List[Future] = [None] * self.max_num_seqs
+        self.is_cuda_platform = current_platform.is_cuda()
+
+        logger.info(
+            f"GuidedDecoding max_num_seqs={self.max_num_seqs} fill_bitmask_parallel_batch_size={self.fill_bitmask_parallel_batch_size} is_cuda_platform={self.is_cuda_platform} max_workers={max_workers}"
+        )
 
     def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
         self.reasoning_parser = reasoning_parser
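
For reference, the worker count above scales with the batch knob: with illustrative values max_num_seqs=256, FD_FILL_BITMASK_BATCH=4, and a 16-core host, up to 256 // 4 = 64 batches can be in flight at once, but the pool is capped at cpu_count() // 2 = 8 threads and never drops below 1:

    import multiprocessing

    def fillmask_workers(max_num_seqs: int, batch_size: int) -> int:
        # Mirrors the sizing in __init__: half the cores, but no more threads
        # than there are batches, and always at least one.
        return max(1, min(multiprocessing.cpu_count() // 2, max_num_seqs // batch_size))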
@@ -140,6 +157,7 @@
             if isinstance(self.logits_processors[idx], Future):
                 continue
 
+        idxs = []
         for idx, processor in enumerate(self.logits_processors):
             if processor is None or not self._prefill_done_idxs[idx]:
                 continue
@@ -156,7 +174,47 @@
                 self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()
 
             if self.should_fill_bitmask(idx):
-                processor.fill_token_bitmask(self.token_bitmask, idx)
+                idxs.append(idx)
+
+        self._async_batch_fill_token_bitmask(idxs)
+
+    def batch_fill_token_bitmask(self, batch: List[int]):
+        """
+        Fills the token bitmask for a batch of logits processor indices.
+
+        This method is typically called asynchronously via a thread pool
+        executor to parallelize the bitmask-filling operation. Any shared
+        state it touches (such as `self.token_bitmask` and
+        `self.logits_processors`) must be thread-safe or properly
+        synchronized to avoid race conditions.
+
+        Args:
+            batch (List[int]): List of indices for which to fill the token bitmask.
+        """
+        for idx in batch:
+            self.logits_processors[idx].fill_token_bitmask(self.token_bitmask, idx)
+
+    def _async_batch_fill_token_bitmask(self, idxs: List[int]):
+        """Launch async bitmask fills, one future per batch of indices."""
+        batch: List[int] = []
+        for idx in idxs:
+            batch.append(idx)
+            if len(batch) == self.fill_bitmask_parallel_batch_size:
+                promise = self.executor_for_fillmask.submit(self.batch_fill_token_bitmask, batch[:])
+                self._fillmask_futures[idx] = promise
+                batch = []
+        if batch:
+            promise = self.executor_for_fillmask.submit(self.batch_fill_token_bitmask, batch[:])
+            self._fillmask_futures[batch[-1]] = promise
+
+    def join_async_fillmask(self):
+        """Join all outstanding async fill futures."""
+        for idx, future in enumerate(self._fillmask_futures):
+            if future is not None:
+                try:
+                    future.result()
+                except Exception as e:
+                    logger.error(f"Exception in async fillmask future at idx {idx}: {e}", exc_info=True)
+                self._fillmask_futures[idx] = None
 
     def accept_tokens_from_prefill_node(self, idx: int):
         """accept prefill token, not future"""
@@ -204,15 +262,17 @@
                 if self.token_bitmask is None:
                     self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()
 
-                self.logits_processors[idx].fill_token_bitmask(self.token_bitmask, idx)
+                # launch async fill
+                self._async_batch_fill_token_bitmask([idx])
 
         if len(indices) == 0:
             return logits
 
+        self.join_async_fillmask()
+
         from fastdeploy.model_executor.guided_decoding.xgrammar_backend import (
             apply_token_mask,
         )
 
-        return apply_token_mask(logits, self.token_bitmask, indices=indices)
+        return apply_token_mask(logits, self.token_bitmask, indices=indices, is_cuda_platform=self.is_cuda_platform)
 
     def _accept_token(self, idx: int, token: int):
         """accept token"""