mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Optimization] xgrammar async compile, multi thread, speed up (#4835)
* xgrammar async compile, multi thread, speed up
* fix test_sampler.py & pre-commit err
* add redis version check && fix request.llm_engine_recv_req_timestamp
* xgrammar prefill & decode & v0
* fix test_gpu_prompt_logprobs.py
* add test_guided_decoding.py
* Update fastdeploy/scheduler/splitwise_scheduler.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply suggestions from code review
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* fix torch xgrammar unittest env
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -14,13 +14,14 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Dict, List, Optional
|
||||
import time
|
||||
from concurrent.futures import Future
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from paddle import nn
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
|
||||
@@ -66,140 +67,187 @@ class GuidedDecoding:
|
||||
processor for guided decoding.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.async_step = None
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
self.token_bitmask = None
|
||||
self.logits_processor: Dict[int, Optional[Any]] = dict()
|
||||
self.executor = ThreadPoolExecutor()
|
||||
self.logits_lock = threading.Lock()
|
||||
self.logits_processors: List[Any] = [None] * fd_config.scheduler_config.max_num_seqs
|
||||
self.reasoning_parser = None
|
||||
self._prefill_done_idxs: List[bool] = [False] * fd_config.scheduler_config.max_num_seqs
|
||||
# for pd
|
||||
self._tokens_to_acc: List[None | List[int]] = [None] * fd_config.scheduler_config.max_num_seqs
|
||||
|
||||
def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
    """Store the reasoning parser used by later guided-decoding steps.

    Args:
        reasoning_parser: Parser instance to keep on ``self.reasoning_parser``;
            ``None`` leaves guided decoding without a reasoning parser.
    """
    self.reasoning_parser = reasoning_parser
|
||||
|
||||
def add_logits_processor(
|
||||
self,
|
||||
ids: int,
|
||||
idx: int,
|
||||
future: Optional[Any] = None,
|
||||
prefill_tokens: List[int] = [],
|
||||
):
|
||||
"""add logits processor to GuidedDecoding"""
|
||||
with self.logits_lock:
|
||||
if future is None:
|
||||
if ids in self.logits_processor:
|
||||
del self.logits_processor[ids]
|
||||
return
|
||||
"""add logits processor to SamplerProcessor"""
|
||||
self._prefill_done_idxs[idx] = False
|
||||
|
||||
if isinstance(future, LogitsProcessorBase):
|
||||
self.logits_processor[ids] = future
|
||||
for token in prefill_tokens:
|
||||
self.logits_processor[ids].accept_token(token)
|
||||
elif future.done():
|
||||
self.logits_processor[ids] = future.result()
|
||||
for token in prefill_tokens:
|
||||
self.logits_processor[ids].accept_token(token)
|
||||
else:
|
||||
self.logits_processor[ids] = [future, prefill_tokens]
|
||||
|
||||
def update_vocab_mask(self, skip_idx_list: List[int] = []):
|
||||
"""update vocab mask. (cpu-heavy operation)"""
|
||||
if len(self.logits_processor) == 0:
|
||||
if future is None:
|
||||
# normal request without guided_backend
|
||||
self.logits_processors[idx] = None
|
||||
return
|
||||
|
||||
with self.logits_lock:
|
||||
for idx, processor in self.logits_processor.items():
|
||||
if processor is None:
|
||||
del self.logits_processor[idx]
|
||||
continue
|
||||
if len(prefill_tokens) != 0:
|
||||
# first_token from prefill node
|
||||
self._prefill_done_idxs[idx] = True
|
||||
|
||||
if not isinstance(processor, LogitsProcessorBase):
|
||||
future, prefill_tokens = self.logits_processor[idx]
|
||||
self.logits_processor[idx] = future.result()
|
||||
for token in prefill_tokens:
|
||||
self.logits_processor[idx].accept_token(token)
|
||||
if future.done():
|
||||
# cached xgrammar
|
||||
self.logits_processors[idx] = future.result()
|
||||
for token in prefill_tokens:
|
||||
self._accept_token(idx, token)
|
||||
else:
|
||||
# async
|
||||
self.logits_processors[idx] = future
|
||||
self._tokens_to_acc[idx] = prefill_tokens
|
||||
|
||||
available_processors = None
|
||||
for processor in self.logits_processor.values():
|
||||
if processor.is_terminated():
|
||||
continue
|
||||
available_processors = processor
|
||||
if available_processors is None:
|
||||
return
|
||||
def should_fill_bitmask(self, idx: int) -> bool:
|
||||
"""
|
||||
Determines whether to fill a bitmask for the logits processor at the given index.
|
||||
|
||||
# allocate token bitmask
|
||||
self.token_bitmask = available_processors.allocate_token_bitmask()
|
||||
Args:
|
||||
idx (int): The index of the logits processor to check
|
||||
|
||||
with self.logits_lock:
|
||||
# fill token bitmask
|
||||
for idx, processor in self.logits_processor.items():
|
||||
if processor.is_terminated() or idx in skip_idx_list:
|
||||
continue
|
||||
Returns:
|
||||
bool: True if the idx request bitmask should be filled
|
||||
|
||||
"""
|
||||
if self.reasoning_parser is not None:
|
||||
if self.logits_processors[idx].enable_reasoning: # <think> guided
|
||||
return True
|
||||
if not self.logits_processors[idx].reasoning_ended:
|
||||
return False
|
||||
return True
|
||||
|
||||
def reset_processor(self, idx: int):
    """Release slot ``idx``: drop its processor and clear its prefill-done flag.

    Args:
        idx: Slot index to reset.
    """
    self.logits_processors[idx] = None
    self._prefill_done_idxs[idx] = False
|
||||
|
||||
def update_vocab_mask(self, prefill_done_idxs: List[int] = []):
    """Refresh the shared token bitmask for the active slots (CPU-heavy).

    Slots listed in ``prefill_done_idxs`` that hold a processor (or a pending
    compile ``Future``) are flagged as prefill-done. Afterwards every flagged
    slot holding a compiled processor replays its queued prefill tokens and,
    when ``should_fill_bitmask`` agrees, writes its row of
    ``self.token_bitmask``. Slots still holding a ``Future`` are skipped here
    and joined in ``apply_token_mask``.

    Args:
        prefill_done_idxs: Slot indices whose prefill finished this step.
            The list is only read, never mutated.
    """
    for idx in prefill_done_idxs:
        if self.logits_processors[idx] is None:
            continue
        # A slot must not be reported as prefill-done twice.
        assert not self._prefill_done_idxs[idx]
        self._prefill_done_idxs[idx] = True

    for idx, processor in enumerate(self.logits_processors):
        if processor is None or not self._prefill_done_idxs[idx]:
            continue
        # skip, join at apply_token_mask
        if isinstance(processor, Future):
            continue
        if processor.is_terminated:
            self.reset_processor(idx)
            continue

        self.accept_tokens_from_prefill_node(idx)
        # Replaying prefill tokens can make _accept_token reset this slot to
        # None (token rejected or grammar terminated); there is nothing left
        # to mask for it, and dereferencing the slot would raise.
        if self.logits_processors[idx] is None:
            continue

        if self.token_bitmask is None:
            self.token_bitmask = processor.allocate_token_bitmask()

        if self.should_fill_bitmask(idx):
            processor.fill_token_bitmask(self.token_bitmask, idx)
|
||||
|
||||
def apply_token_mask(self, logits: paddle.Tensor, skip_idx_list: List[int] = []):
|
||||
"""apply token mask to logits"""
|
||||
if len(self.logits_processor) == 0 or self.token_bitmask is None:
|
||||
return logits
|
||||
def accept_tokens_from_prefill_node(self, idx: int):
    """Replay the tokens queued for slot ``idx`` into its compiled processor.

    Must be called only once the slot holds a compiled processor, not a
    pending future. The queue for ``idx`` is cleared afterwards.

    Args:
        idx: Slot index whose queued prefill tokens should be consumed.
    """
    pending = self._tokens_to_acc[idx]
    if pending is None:
        return
    # accept token from prefill node first
    for token in pending:
        # _accept_token resets the slot to None once the processor rejects a
        # token or terminates; feeding further tokens would dereference None.
        if self.logits_processors[idx] is None:
            break
        self._accept_token(idx, token)
    self._tokens_to_acc[idx] = None
|
||||
|
||||
# self.async_step.result()
|
||||
available_processors = None
|
||||
with self.logits_lock:
|
||||
for processor in self.logits_processor.values():
|
||||
if processor.is_terminated():
|
||||
continue
|
||||
available_processors = processor
|
||||
if available_processors is None:
|
||||
return logits
|
||||
def apply_token_mask(self, logits: paddle.Tensor, prefill_done_idxs: List[int] = []):
|
||||
"""apply token mask to logits"""
|
||||
|
||||
indices = []
|
||||
for idx, processor in self.logits_processor.items():
|
||||
if processor is None or idx in skip_idx_list:
|
||||
for idx, processor in enumerate(self.logits_processors):
|
||||
if processor is None or not self._prefill_done_idxs[idx]:
|
||||
continue
|
||||
if self.reasoning_parser is None or not processor.enable_reasoning or processor.reasoning_ended:
|
||||
indices.append(idx)
|
||||
|
||||
return available_processors.apply_token_mask(logits, self.token_bitmask, indices=indices)
|
||||
# compiled done, check idx should fill, fill_token_bitmask done in preprocess
|
||||
if not isinstance(processor, Future):
|
||||
if self.should_fill_bitmask(idx):
|
||||
indices.append(idx)
|
||||
continue
|
||||
|
||||
# is Future, processor async compiled not ready, need join and wait
|
||||
ts = time.time()
|
||||
wait = False
|
||||
if not processor.done():
|
||||
wait = True
|
||||
self.logits_processors[idx] = processor.result()
|
||||
if wait:
|
||||
logger.debug(f"[{idx} join async compile xgrammar, time_cost:{time.time() - ts}]")
|
||||
|
||||
self.accept_tokens_from_prefill_node(idx)
|
||||
# Possible optimization: Extract 'think' content validation from logits_processors,
|
||||
# allowing join operations to complete immediately after 'think' terminates.
|
||||
# Furthermore, the current idx could be skipped, with compilation overhead
|
||||
# estimated at only a few milliseconds.
|
||||
|
||||
# check idx for fill_token_mask
|
||||
if not self.should_fill_bitmask(idx):
|
||||
continue
|
||||
|
||||
indices.append(idx)
|
||||
|
||||
if self.token_bitmask is None:
|
||||
self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()
|
||||
|
||||
self.logits_processors[idx].fill_token_bitmask(self.token_bitmask, idx)
|
||||
|
||||
if len(indices) == 0:
|
||||
return logits
|
||||
from fastdeploy.model_executor.guided_decoding.xgrammar_backend import (
|
||||
apply_token_mask,
|
||||
)
|
||||
|
||||
return apply_token_mask(logits, self.token_bitmask, indices=indices)
|
||||
|
||||
def _accept_token(self, idx: int, token: int):
|
||||
"""accept token"""
|
||||
if idx not in self.logits_processor:
|
||||
raise ValueError(f"Invalid index, idx: {idx}, logit_processors.keys: {self.logits_processor.keys()}")
|
||||
|
||||
if self.logits_processor[idx].is_terminated():
|
||||
return
|
||||
if self.reasoning_parser is not None:
|
||||
if not self.logits_processors[idx].enable_reasoning:
|
||||
if not self.logits_processors[idx].reasoning_ended:
|
||||
reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
|
||||
self.logits_processors[idx].reasoning_ended = reasoning_ended
|
||||
return
|
||||
|
||||
if (
|
||||
self.reasoning_parser is not None
|
||||
and self.logits_processor[idx].enable_reasoning
|
||||
and not self.logits_processor[idx].reasoning_ended
|
||||
):
|
||||
reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
|
||||
self.logits_processor[idx].reasoning_ended = reasoning_ended
|
||||
return
|
||||
if not self.logits_processors[idx].accept_token(token) or self.logits_processors[idx].is_terminated:
|
||||
self.reset_processor(idx)
|
||||
|
||||
self.logits_processor[idx].accept_token(token)
|
||||
|
||||
def update_output_tokens(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
|
||||
def update_output_tokens(self, next_tokens: paddle.Tensor):
|
||||
"""update output tokens"""
|
||||
if len(self.logits_processor) == 0:
|
||||
if len(self.logits_processors) == 0:
|
||||
return
|
||||
|
||||
token_ids = next_tokens.numpy().tolist()
|
||||
with self.logits_lock:
|
||||
for idx in self.logits_processor.keys():
|
||||
token = token_ids[idx][0]
|
||||
if token < 0 or self.logits_processor[idx] is None or idx in skip_idx_list:
|
||||
continue
|
||||
for idx, processor in enumerate(self.logits_processors):
|
||||
if not self._prefill_done_idxs[idx] or processor is None:
|
||||
continue
|
||||
if idx >= len(token_ids):
|
||||
continue
|
||||
token = token_ids[idx][0]
|
||||
if token < 0:
|
||||
self.reset_processor(idx)
|
||||
continue
|
||||
logger.debug(f"[{idx}]accept token{token}")
|
||||
self._accept_token(idx, token)
|
||||
|
||||
self._accept_token(idx, token)
|
||||
|
||||
def pre_process(self, skip_idx_list: List[int] = []):
|
||||
def pre_process(self, prefill_done_idxs: List[int] = []):
|
||||
"""pre process before running"""
|
||||
# create async operation for guided decoding
|
||||
# TODO: support async
|
||||
self.update_vocab_mask(skip_idx_list)
|
||||
# self.async_step = self.executor.submit(self.update_vocab_mask)
|
||||
self.update_vocab_mask(prefill_done_idxs)
|
||||
|
||||
|
||||
class Sampler(nn.Layer):
|
||||
@@ -224,7 +272,7 @@ class Sampler(nn.Layer):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.guided_decoding = GuidedDecoding()
|
||||
self.guided_decoding = GuidedDecoding(fd_config)
|
||||
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
|
||||
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
|
||||
if (
|
||||
@@ -240,17 +288,19 @@ class Sampler(nn.Layer):
|
||||
"""set reasoning parser"""
|
||||
self.guided_decoding.apply_reasoning_parser(reasoning_parser)
|
||||
|
||||
def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
|
||||
def apply_logits_processor(
|
||||
self, ids: int, future: Future[LogitsProcessorBase] = None, prefill_tokens: List[int] = []
|
||||
):
|
||||
"""apply logits processor to sampler"""
|
||||
self.guided_decoding.add_logits_processor(ids, future, prefill_tokens)
|
||||
|
||||
def pre_process(self, skip_idx_list: List[int] = []):
|
||||
def pre_process(self, prefill_done_idxs: List[int] = []):
|
||||
"""pre process before running"""
|
||||
self.guided_decoding.pre_process(skip_idx_list)
|
||||
self.guided_decoding.pre_process(prefill_done_idxs)
|
||||
|
||||
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
|
||||
def post_process(self, next_tokens: paddle.Tensor):
|
||||
"""post process after running"""
|
||||
self.guided_decoding.update_output_tokens(next_tokens, skip_idx_list)
|
||||
self.guided_decoding.update_output_tokens(next_tokens)
|
||||
|
||||
def compute_logprobs(
|
||||
self,
|
||||
@@ -343,10 +393,10 @@ class Sampler(nn.Layer):
|
||||
self,
|
||||
logits: paddle.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
skip_idx_list: List[int] = [],
|
||||
p_done_idxs: List[int] = [],
|
||||
) -> SamplerOutput:
|
||||
""" """
|
||||
logits = self.guided_decoding.apply_token_mask(logits, skip_idx_list)
|
||||
logits = self.guided_decoding.apply_token_mask(logits, p_done_idxs)
|
||||
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
if num_logprobs is not None:
|
||||
|
||||
Reference in New Issue
Block a user