[Optimization] xgrammar async compile, multi thread, speed up (#4835)

* xgrammar async compile, multi thread, speed up

* fix test_sampler.py & pre-commit errors

* add redis version check and fix request.llm_engine_recv_req_timestamp

* xgrammar prefill & decode & v0

* fix test_gpu_prompt_logprobs.py

* add test_guided_decoding.py

* Update fastdeploy/scheduler/splitwise_scheduler.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix torch xgrammar unittest env

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Daci
2025-11-14 18:05:26 +08:00
committed by GitHub
parent b925533051
commit 5fc12eddfe
11 changed files with 810 additions and 373 deletions

View File

@@ -14,13 +14,14 @@
# limitations under the License.
"""
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional
import time
from concurrent.futures import Future
from typing import Any, List, Optional
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
@@ -66,140 +67,187 @@ class GuidedDecoding:
processor for guided decoding.
"""
def __init__(self):
self.async_step = None
def __init__(self, fd_config: FDConfig):
self.token_bitmask = None
self.logits_processor: Dict[int, Optional[Any]] = dict()
self.executor = ThreadPoolExecutor()
self.logits_lock = threading.Lock()
self.logits_processors: List[Any] = [None] * fd_config.scheduler_config.max_num_seqs
self.reasoning_parser = None
self._prefill_done_idxs: List[bool] = [False] * fd_config.scheduler_config.max_num_seqs
# for pd
self._tokens_to_acc: List[None | List[int]] = [None] * fd_config.scheduler_config.max_num_seqs
def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
self.reasoning_parser = reasoning_parser
def add_logits_processor(
self,
ids: int,
idx: int,
future: Optional[Any] = None,
prefill_tokens: List[int] = [],
):
"""add logits processor to GuidedDecoding"""
with self.logits_lock:
if future is None:
if ids in self.logits_processor:
del self.logits_processor[ids]
return
"""add logits processor to SamplerProcessor"""
self._prefill_done_idxs[idx] = False
if isinstance(future, LogitsProcessorBase):
self.logits_processor[ids] = future
for token in prefill_tokens:
self.logits_processor[ids].accept_token(token)
elif future.done():
self.logits_processor[ids] = future.result()
for token in prefill_tokens:
self.logits_processor[ids].accept_token(token)
else:
self.logits_processor[ids] = [future, prefill_tokens]
def update_vocab_mask(self, skip_idx_list: List[int] = []):
"""update vocab mask. (cpu-heavy operation)"""
if len(self.logits_processor) == 0:
if future is None:
# normal request without guided_backend
self.logits_processors[idx] = None
return
with self.logits_lock:
for idx, processor in self.logits_processor.items():
if processor is None:
del self.logits_processor[idx]
continue
if len(prefill_tokens) != 0:
# first_token from prefill node
self._prefill_done_idxs[idx] = True
if not isinstance(processor, LogitsProcessorBase):
future, prefill_tokens = self.logits_processor[idx]
self.logits_processor[idx] = future.result()
for token in prefill_tokens:
self.logits_processor[idx].accept_token(token)
if future.done():
# cached xgrammar
self.logits_processors[idx] = future.result()
for token in prefill_tokens:
self._accept_token(idx, token)
else:
# async
self.logits_processors[idx] = future
self._tokens_to_acc[idx] = prefill_tokens
available_processors = None
for processor in self.logits_processor.values():
if processor.is_terminated():
continue
available_processors = processor
if available_processors is None:
return
def should_fill_bitmask(self, idx: int) -> bool:
"""
Determines whether to fill a bitmask for the logits processor at the given index.
# allocate token bitmask
self.token_bitmask = available_processors.allocate_token_bitmask()
Args:
idx (int): The index of the logits processor to check
with self.logits_lock:
# fill token bitmask
for idx, processor in self.logits_processor.items():
if processor.is_terminated() or idx in skip_idx_list:
continue
Returns:
bool: True if the idx request bitmask should be filled
"""
if self.reasoning_parser is not None:
if self.logits_processors[idx].enable_reasoning: # <think> guided
return True
if not self.logits_processors[idx].reasoning_ended:
return False
return True
def reset_processor(self, idx: int):
"""reset idx"""
self._prefill_done_idxs[idx] = False
self.logits_processors[idx] = None
def update_vocab_mask(self, prefill_done_idxs: List[int] = []):
"""update vocab mask. (cpu-heavy operation)"""
for idx in prefill_done_idxs:
if self.logits_processors[idx] is None:
continue
assert not self._prefill_done_idxs[idx]
self._prefill_done_idxs[idx] = True
if isinstance(self.logits_processors[idx], Future):
continue
for idx, processor in enumerate(self.logits_processors):
if processor is None or not self._prefill_done_idxs[idx]:
continue
# skip, join at apply_token_mask
if isinstance(processor, Future):
continue
if processor.is_terminated:
self.reset_processor(idx)
continue
self.accept_tokens_from_prefill_node(idx)
if self.token_bitmask is None:
self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()
if self.should_fill_bitmask(idx):
processor.fill_token_bitmask(self.token_bitmask, idx)
def apply_token_mask(self, logits: paddle.Tensor, skip_idx_list: List[int] = []):
"""apply token mask to logits"""
if len(self.logits_processor) == 0 or self.token_bitmask is None:
return logits
def accept_tokens_from_prefill_node(self, idx: int):
"""accept prefill token, not future"""
if self._tokens_to_acc[idx] is not None:
# accept token from prefill node first
for token in self._tokens_to_acc[idx]:
self._accept_token(idx, token)
self._tokens_to_acc[idx] = None
# self.async_step.result()
available_processors = None
with self.logits_lock:
for processor in self.logits_processor.values():
if processor.is_terminated():
continue
available_processors = processor
if available_processors is None:
return logits
def apply_token_mask(self, logits: paddle.Tensor, prefill_done_idxs: List[int] = []):
"""apply token mask to logits"""
indices = []
for idx, processor in self.logits_processor.items():
if processor is None or idx in skip_idx_list:
for idx, processor in enumerate(self.logits_processors):
if processor is None or not self._prefill_done_idxs[idx]:
continue
if self.reasoning_parser is None or not processor.enable_reasoning or processor.reasoning_ended:
indices.append(idx)
return available_processors.apply_token_mask(logits, self.token_bitmask, indices=indices)
# compiled done, check idx should fill, fill_token_bitmask done in preprocess
if not isinstance(processor, Future):
if self.should_fill_bitmask(idx):
indices.append(idx)
continue
# is Future, processor async compiled not ready, need join and wait
ts = time.time()
wait = False
if not processor.done():
wait = True
self.logits_processors[idx] = processor.result()
if wait:
logger.debug(f"[{idx} join async compile xgrammar, time_cost:{time.time() - ts}]")
self.accept_tokens_from_prefill_node(idx)
# Possible optimization: Extract 'think' content validation from logits_processors,
# allowing join operations to complete immediately after 'think' terminates.
# Furthermore, the current idx could be skipped, with compilation overhead
# estimated at only a few milliseconds.
# check idx for fill_token_mask
if not self.should_fill_bitmask(idx):
continue
indices.append(idx)
if self.token_bitmask is None:
self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()
self.logits_processors[idx].fill_token_bitmask(self.token_bitmask, idx)
if len(indices) == 0:
return logits
from fastdeploy.model_executor.guided_decoding.xgrammar_backend import (
apply_token_mask,
)
return apply_token_mask(logits, self.token_bitmask, indices=indices)
def _accept_token(self, idx: int, token: int):
"""accept token"""
if idx not in self.logits_processor:
raise ValueError(f"Invalid index, idx: {idx}, logit_processors.keys: {self.logits_processor.keys()}")
if self.logits_processor[idx].is_terminated():
return
if self.reasoning_parser is not None:
if not self.logits_processors[idx].enable_reasoning:
if not self.logits_processors[idx].reasoning_ended:
reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
self.logits_processors[idx].reasoning_ended = reasoning_ended
return
if (
self.reasoning_parser is not None
and self.logits_processor[idx].enable_reasoning
and not self.logits_processor[idx].reasoning_ended
):
reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
self.logits_processor[idx].reasoning_ended = reasoning_ended
return
if not self.logits_processors[idx].accept_token(token) or self.logits_processors[idx].is_terminated:
self.reset_processor(idx)
self.logits_processor[idx].accept_token(token)
def update_output_tokens(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
def update_output_tokens(self, next_tokens: paddle.Tensor):
"""update output tokens"""
if len(self.logits_processor) == 0:
if len(self.logits_processors) == 0:
return
token_ids = next_tokens.numpy().tolist()
with self.logits_lock:
for idx in self.logits_processor.keys():
token = token_ids[idx][0]
if token < 0 or self.logits_processor[idx] is None or idx in skip_idx_list:
continue
for idx, processor in enumerate(self.logits_processors):
if not self._prefill_done_idxs[idx] or processor is None:
continue
if idx >= len(token_ids):
continue
token = token_ids[idx][0]
if token < 0:
self.reset_processor(idx)
continue
logger.debug(f"[{idx}]accept token{token}")
self._accept_token(idx, token)
self._accept_token(idx, token)
def pre_process(self, skip_idx_list: List[int] = []):
def pre_process(self, prefill_done_idxs: List[int] = []):
"""pre process before running"""
# create async operation for guided decoding
# TODO: support async
self.update_vocab_mask(skip_idx_list)
# self.async_step = self.executor.submit(self.update_vocab_mask)
self.update_vocab_mask(prefill_done_idxs)
class Sampler(nn.Layer):
@@ -224,7 +272,7 @@ class Sampler(nn.Layer):
else:
raise NotImplementedError
self.guided_decoding = GuidedDecoding()
self.guided_decoding = GuidedDecoding(fd_config)
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
if (
@@ -240,17 +288,19 @@ class Sampler(nn.Layer):
"""set reasoning parser"""
self.guided_decoding.apply_reasoning_parser(reasoning_parser)
def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
def apply_logits_processor(
self, ids: int, future: Future[LogitsProcessorBase] = None, prefill_tokens: List[int] = []
):
"""apply logits processor to sampler"""
self.guided_decoding.add_logits_processor(ids, future, prefill_tokens)
def pre_process(self, skip_idx_list: List[int] = []):
def pre_process(self, prefill_done_idxs: List[int] = []):
"""pre process before running"""
self.guided_decoding.pre_process(skip_idx_list)
self.guided_decoding.pre_process(prefill_done_idxs)
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
def post_process(self, next_tokens: paddle.Tensor):
"""post process after running"""
self.guided_decoding.update_output_tokens(next_tokens, skip_idx_list)
self.guided_decoding.update_output_tokens(next_tokens)
def compute_logprobs(
self,
@@ -343,10 +393,10 @@ class Sampler(nn.Layer):
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
skip_idx_list: List[int] = [],
p_done_idxs: List[int] = [],
) -> SamplerOutput:
""" """
logits = self.guided_decoding.apply_token_mask(logits, skip_idx_list)
logits = self.guided_decoding.apply_token_mask(logits, p_done_idxs)
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None: