[Feature] ThreadPoolExecutor async fill_token_bitmask (#5083)
* ThreadPoolExecutor async fill_token_bitmask
* ThreadPoolExecutor async fill_token_bitmask logging
* fix test_guided_decoding
* Apply suggestions from code review
* add fill_bitmask_parallel_batch_size ENV
* FD_FILL_BITMASK_BATCH fastdeploy.envs

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@@ -1417,7 +1417,6 @@ class StructuredOutputsConfig:
        # disable any whitespace for guided decoding
        self.disable_any_whitespace: bool = True
        self.logits_processors: Optional[list[str]] = None

        for key, value in args.items():
            if hasattr(self, key) and value != "None":
                setattr(self, key, value)

@@ -159,6 +159,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
    "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")),
    "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")),
    "FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")),
    "FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")),
}

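The new FD_FILL_BITMASK_BATCH variable (default 4) controls how many sequence slots one async fill task handles. It is read through the same lazy-lookup table as the other entries; a minimal standalone sketch of that pattern (illustrative only, not the actual fastdeploy.envs module):

import os

# Minimal sketch of the lazy env-var table used above; the value is resolved
# when the lambda is called, so the variable can be exported before launch,
# e.g. `FD_FILL_BITMASK_BATCH=8` in the server's environment.
environment_variables = {
    "FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")),
}

FD_FILL_BITMASK_BATCH = environment_variables["FD_FILL_BITMASK_BATCH"]()
print(FD_FILL_BITMASK_BATCH)  # 4 unless overridden
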
@@ -29,7 +29,6 @@ from fastdeploy.model_executor.guided_decoding import (
    BaseChecker,
    LogitsProcessorBase,
)
from fastdeploy.platforms import current_platform
from fastdeploy.utils import llm_logger

try:

@@ -451,6 +450,7 @@ def apply_token_mask(
    logits: paddle.Tensor,
    token_bitmask: torch.Tensor,
    indices: Optional[List[int]] = None,
    is_cuda_platform: bool = True,
) -> paddle.Tensor:
    """
    Apply the token mask to the logits, modifying probabilities of invalid tokens.

@@ -463,17 +463,16 @@ def apply_token_mask(
    Returns:
        paddle.Tensor: The modified logits tensor
    """

    if current_platform.is_cuda():
    skip_out_indices = len(indices) == logits.shape[0]
    if is_cuda_platform:
        dlpack = paddle.utils.dlpack.to_dlpack(logits)
        t_logits = torch.from_dlpack(dlpack)
        apply_token_bitmask_inplace(
            logits=t_logits,
            bitmask=token_bitmask.to(t_logits.device, non_blocking=True),
            indices=indices,
            indices=indices if not skip_out_indices else None,
        )
        dlpack2 = torch.utils.dlpack.to_dlpack(t_logits)
        return paddle.utils.dlpack.from_dlpack(dlpack2)
        return logits
    else:
        origin_place = logits.place
        origin_dtype = logits.dtype

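On CUDA the masking never leaves the GPU: the Paddle logits are handed to Torch through DLPack, xgrammar's apply_token_bitmask_inplace edits that view in place, and the same Paddle tensor can be returned directly. A self-contained sketch of the DLPack round trip, assuming only paddle and torch are installed (a plain masked fill stands in for the xgrammar kernel):

import paddle
import torch

# A small Paddle tensor (GPU if available, otherwise CPU -- DLPack handles both).
logits = paddle.rand([2, 8])

# Expose the same memory to Torch through DLPack: no copy is made.
t_logits = torch.from_dlpack(paddle.utils.dlpack.to_dlpack(logits))

# An in-place edit on the Torch view (here: ban token 0 in every row) is
# immediately visible through the original Paddle tensor, so the caller can
# simply reuse `logits` without converting back.
t_logits[:, 0] = float("-inf")
print(float(logits[0, 0]))  # -inf
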
@@ -483,7 +482,7 @@ def apply_token_mask(
        apply_token_bitmask_inplace(
            logits=logits,
            bitmask=token_bitmask.to(logits.device, non_blocking=True),
            indices=indices,
            indices=indices if not skip_out_indices else None,
        )

        return paddle.to_tensor(

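skip_out_indices is a small fast path: when the index list already covers every row of the logits, the kernel is given indices=None and masks the whole batch instead of gathering rows. Roughly (values illustrative):

# Illustrative values: a batch of 4 rows, all of them under guided decoding.
indices = [0, 1, 2, 3]
num_rows = 4  # logits.shape[0]

skip_out_indices = len(indices) == num_rows
effective = None if skip_out_indices else indices
print(skip_out_indices, effective)  # True None -> mask every row, no gather
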
@@ -14,8 +14,9 @@
# limitations under the License.
"""

import multiprocessing
import time
from concurrent.futures import Future
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, List, Optional

import paddle

@@ -24,6 +25,7 @@ from paddle import nn
from paddleformers.utils.log import logger

from fastdeploy.config import FDConfig
from fastdeploy.envs import FD_FILL_BITMASK_BATCH
from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
from fastdeploy.model_executor.layers.sample.early_stopper import (
    get_early_stopper_cls_from_stragegy,

@@ -69,11 +71,26 @@ class GuidedDecoding:

    def __init__(self, fd_config: FDConfig):
        self.token_bitmask = None
        self.logits_processors: List[Any] = [None] * fd_config.scheduler_config.max_num_seqs
        self.max_num_seqs: int = int(
            fd_config.scheduler_config.max_num_seqs if fd_config.scheduler_config is not None else 1
        )
        self.logits_processors: List[Any] = [None] * self.max_num_seqs
        self.reasoning_parser = None
        self._prefill_done_idxs: List[bool] = [False] * fd_config.scheduler_config.max_num_seqs
        self._prefill_done_idxs: List[bool] = [False] * self.max_num_seqs
        # for pd
        self._tokens_to_acc: List[None | List[int]] = [None] * fd_config.scheduler_config.max_num_seqs
        self._tokens_to_acc: List[None | List[int]] = [None] * self.max_num_seqs

        self.fill_bitmask_parallel_batch_size: int = FD_FILL_BITMASK_BATCH
        max_workers = max(
            1,
            min(multiprocessing.cpu_count() // 2, int(self.max_num_seqs) / int(self.fill_bitmask_parallel_batch_size)),
        )
        self.executor_for_fillmask = ThreadPoolExecutor(max_workers=int(max_workers))
        self._fillmask_futures: List[Future] = [None] * self.max_num_seqs
        self.is_cuda_platform = current_platform.is_cuda()
        logger.info(
            f"GuidedDecoding max_num_seqs={self.max_num_seqs} fill_bitmask_parallel_batch_size={self.fill_bitmask_parallel_batch_size} is_cuda_platform={self.is_cuda_platform} max_workers={max_workers}"
        )

    def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
        self.reasoning_parser = reasoning_parser

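The pool is sized to the smaller of half the CPUs and the number of batches that can be in flight at once (max_num_seqs / fill_bitmask_parallel_batch_size), with a floor of one worker. A quick worked example with illustrative numbers:

# Illustrative numbers, not from the diff: 128 sequence slots, batch size 4,
# on a host where multiprocessing.cpu_count() reports 32.
max_num_seqs = 128
fill_bitmask_parallel_batch_size = 4
cpu_count = 32

max_workers = max(1, min(cpu_count // 2, max_num_seqs / fill_bitmask_parallel_batch_size))
print(int(max_workers))  # min(16, 32.0) -> 16 worker threads
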
@@ -140,6 +157,7 @@ class GuidedDecoding:
            if isinstance(self.logits_processors[idx], Future):
                continue

        idxs = []
        for idx, processor in enumerate(self.logits_processors):
            if processor is None or not self._prefill_done_idxs[idx]:
                continue

@@ -156,7 +174,47 @@ class GuidedDecoding:
                self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()

            if self.should_fill_bitmask(idx):
                processor.fill_token_bitmask(self.token_bitmask, idx)
                idxs.append(idx)
        self._async_batch_fill_token_bitmask(idxs)

    def batch_fill_token_bitmask(self, batch: List[int]):
        """
        Fills the token bitmask for a batch of logits processor indices.

        This method is typically called asynchronously via a thread pool executor
        to parallelize the bitmask filling operation. It is important that any
        shared data structures accessed within this method (such as
        `self.token_bitmask` and `self.logits_processors`) are thread-safe or
        properly synchronized to avoid race conditions.

        Args:
            batch (List[int]): List of indices for which to fill the token bitmask.
        """
        for idx in batch:
            self.logits_processors[idx].fill_token_bitmask(self.token_bitmask, idx)

    def _async_batch_fill_token_bitmask(self, idxs: List[int]):
        """launch async fill"""
        batch: List[int] = []
        for idx in idxs:
            batch.append(idx)
            if len(batch) == self.fill_bitmask_parallel_batch_size:
                promise = self.executor_for_fillmask.submit(self.batch_fill_token_bitmask, batch[:])
                self._fillmask_futures[idx] = promise
                batch = []
        if batch:
            promise = self.executor_for_fillmask.submit(self.batch_fill_token_bitmask, batch[:])
            self._fillmask_futures[batch[-1]] = promise

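Ready indices are chunked into groups of fill_bitmask_parallel_batch_size, each group is submitted to the pool, and the resulting Future is stored so join_async_fillmask can block on it before the bitmask is read. The same submit-then-join pattern, stripped down to a self-contained sketch (names and the work function are placeholders, not the FastDeploy API):

from concurrent.futures import Future, ThreadPoolExecutor
from typing import List

BATCH_SIZE = 4  # plays the role of fill_bitmask_parallel_batch_size
executor = ThreadPoolExecutor(max_workers=4)

def fill_batch(batch: List[int]) -> None:
    # Stand-in for batch_fill_token_bitmask: fill one bitmask row per index.
    for idx in batch:
        pass  # processor.fill_token_bitmask(token_bitmask, idx) would run here

def launch(idxs: List[int]) -> List[Future]:
    futures, batch = [], []
    for idx in idxs:
        batch.append(idx)
        if len(batch) == BATCH_SIZE:          # a full group is ready
            futures.append(executor.submit(fill_batch, batch[:]))
            batch = []
    if batch:                                 # flush the last partial group
        futures.append(executor.submit(fill_batch, batch[:]))
    return futures

def join(futures: List[Future]) -> None:
    # Block until every fill task finishes (or re-raise its exception).
    for fut in futures:
        fut.result()

join(launch(list(range(10))))  # 10 indices -> three tasks of size 4, 4 and 2
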

    def join_async_fillmask(self):
        """join all async fill futures"""
        for idx, furture in enumerate(self._fillmask_futures):
            if furture is not None:
                try:
                    furture.result()
                except Exception as e:
                    logger.error(f"Exception in async fillmask future at idx {idx}: {e}", exc_info=True)
                self._fillmask_futures[idx] = None

    def accept_tokens_from_prefill_node(self, idx: int):
        """accept prefill token, not future"""

@@ -204,15 +262,17 @@ class GuidedDecoding:
                if self.token_bitmask is None:
                    self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()

                self.logits_processors[idx].fill_token_bitmask(self.token_bitmask, idx)
                # launch async fill
                self._async_batch_fill_token_bitmask([idx])

        if len(indices) == 0:
            return logits
        self.join_async_fillmask()
        from fastdeploy.model_executor.guided_decoding.xgrammar_backend import (
            apply_token_mask,
        )

        return apply_token_mask(logits, self.token_bitmask, indices=indices)
        return apply_token_mask(logits, self.token_bitmask, indices=indices, is_cuda_platform=self.is_cuda_platform)

    def _accept_token(self, idx: int, token: int):
        """accept token"""

@@ -726,7 +726,7 @@ def parse_args():
    )
    parser.add_argument(
        "--disable_any_whitespace",
        action="store_false",
        action="store_true",
        help="Disable any whitespace for guided decoding.",
    )
    parser.add_argument(

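The parse_args change flips the flag's action from store_false to store_true: with store_false the destination defaulted to True and passing --disable_any_whitespace actually set it to False, the inverse of what the flag name suggests; with store_true the flag behaves as advertised. A small standalone comparison:

import argparse

# Illustrative only: compare the two actions for the same flag name.
old = argparse.ArgumentParser()
old.add_argument("--disable_any_whitespace", action="store_false")
print(old.parse_args([]).disable_any_whitespace)                            # True even when not passed
print(old.parse_args(["--disable_any_whitespace"]).disable_any_whitespace)  # False when passed

new = argparse.ArgumentParser()
new.add_argument("--disable_any_whitespace", action="store_true")
print(new.parse_args([]).disable_any_whitespace)                            # False by default
print(new.parse_args(["--disable_any_whitespace"]).disable_any_whitespace)  # True when passed
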