[Feature] ThreadPoolExecutor async fill_token_bitmask (#5083)

* ThreadPoolExecutor async fill_token_bitmask

* Add logging to the ThreadPoolExecutor-based async fill_token_bitmask

* Fix test_guided_decoding

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Add fill_bitmask_parallel_batch_size environment variable

* Define the FD_FILL_BITMASK_BATCH setting in fastdeploy.envs

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Daci
2025-11-19 10:04:16 +08:00
committed by GitHub
parent 4a7739ec0b
commit eab8384da6
6 changed files with 76 additions and 16 deletions

View File

@@ -29,7 +29,6 @@ from fastdeploy.model_executor.guided_decoding import (
BaseChecker,
LogitsProcessorBase,
)
from fastdeploy.platforms import current_platform
from fastdeploy.utils import llm_logger
try:
@@ -451,6 +450,7 @@ def apply_token_mask(
logits: paddle.Tensor,
token_bitmask: torch.Tensor,
indices: Optional[List[int]] = None,
is_cuda_platform: bool = True,
) -> paddle.Tensor:
"""
Apply the token mask to the logits, modifying probabilities of invalid tokens.
@@ -463,17 +463,16 @@ def apply_token_mask(
Returns:
paddle.Tensor: The modified logits tensor
"""
if current_platform.is_cuda():
skip_out_indices = len(indices) == logits.shape[0]
if is_cuda_platform:
dlpack = paddle.utils.dlpack.to_dlpack(logits)
t_logits = torch.from_dlpack(dlpack)
apply_token_bitmask_inplace(
logits=t_logits,
bitmask=token_bitmask.to(t_logits.device, non_blocking=True),
indices=indices,
indices=indices if not skip_out_indices else None,
)
dlpack2 = torch.utils.dlpack.to_dlpack(t_logits)
return paddle.utils.dlpack.from_dlpack(dlpack2)
return logits
else:
origin_place = logits.place
origin_dtype = logits.dtype
@@ -483,7 +482,7 @@ def apply_token_mask(
apply_token_bitmask_inplace(
logits=logits,
bitmask=token_bitmask.to(logits.device, non_blocking=True),
indices=indices,
indices=indices if not skip_out_indices else None,
)
return paddle.to_tensor(