From 5fc12eddfe6939c0a66dbc0af7b913bfab061a47 Mon Sep 17 00:00:00 2001 From: Daci <15625257+ST-XX@users.noreply.github.com> Date: Fri, 14 Nov 2025 18:05:26 +0800 Subject: [PATCH] [Optimization] xgrammar async compile, multi thread, speed up (#4835) * xgrammar async compile, multi thread, speed up * fix test_sampler.py & pre-commit err * add redis version check && fix request.llm_engine_recv_req_timestamp * xgrammar prefill & decode & v0 * fix test_gpu_prompt_logprobs.py * add test_guided_decoding.py * Update fastdeploy/scheduler/splitwise_scheduler.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix torch xgrammar unittest env --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../guided_decoding/base_guided_decoding.py | 33 +- .../guided_decoding/xgrammar_backend.py | 118 +++--- .../model_executor/layers/sample/sampler.py | 264 ++++++++------ fastdeploy/scheduler/splitwise_scheduler.py | 20 ++ fastdeploy/worker/gcu_model_runner.py | 73 ++-- fastdeploy/worker/gpu_model_runner.py | 83 ++--- fastdeploy/worker/hpu_model_runner.py | 48 +-- fastdeploy/worker/metax_model_runner.py | 80 ++--- tests/layers/test_guided_decoding.py | 338 ++++++++++++++++++ tests/layers/test_sampler.py | 66 +++- tests/woker/test_gpu_prompt_logprobs.py | 60 +++- 11 files changed, 810 insertions(+), 373 deletions(-) create mode 100644 tests/layers/test_guided_decoding.py diff --git a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py index 57fccc3fe..c4e235afc 100644 --- a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py +++ b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py @@ -14,9 +14,10 @@ # limitations under the License. """ +import multiprocessing import os import traceback -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor from fastdeploy.config import ErnieArchitectures, FDConfig from fastdeploy.engine.request import Request @@ -135,9 +136,9 @@ class BackendBase: """ def __init__(self, fd_config: FDConfig): - self.cache = {} self.fd_config = fd_config - self.executor = ThreadPoolExecutor() + max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) + self.executor = ThreadPoolExecutor(max_workers=max_workers) self.max_cache_size = 2048 self.reasoning_parser = None @@ -263,7 +264,7 @@ class BackendBase: self, schemata_key: tuple[str, str], enable_thinking: bool = False, - ) -> tuple[LogitsProcessorBase, bool]: + ) -> Future[LogitsProcessorBase]: """ get logits processor by key from cache or create new one. 
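The hunks above change BackendBase so that get_logits_processor() hands back a Future[LogitsProcessorBase] from a bounded ThreadPoolExecutor instead of a (processor, cached) tuple, letting grammar compilation overlap with request scheduling and only be joined when the processor is first needed. A minimal sketch of that submit-then-join pattern, assuming a hypothetical compile_processor() standing in for _init_logits_processor:

import multiprocessing
import time
from concurrent.futures import Future, ThreadPoolExecutor


def compile_processor(schema: str) -> str:
    # Hypothetical stand-in for _init_logits_processor(): the expensive
    # xgrammar compilation step that should stay off the request hot path.
    time.sleep(0.1)
    return f"processor-for-{schema}"


# Same bounded sizing as the patched BackendBase: roughly half the cores.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
executor = ThreadPoolExecutor(max_workers=max_workers)

# submit() returns immediately; compilation runs on a worker thread while
# prefill/scheduling for the request continues.
future: Future[str] = executor.submit(compile_processor, '{"type": "object"}')

# The future is only joined at the first sampling step that needs the token
# mask; if compilation already finished, result() returns instantly.
processor = future.result()
print(processor)

This sketch is illustrative only; the real backend submits self._init_logits_processor and resolves the future inside the sampler (see the GuidedDecoding changes further down in this patch).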
@@ -275,13 +276,8 @@ class BackendBase: - LogitsProcessorBase: The logits processor instance - bool: True if processor was from cache, False if newly created """ - value = self.cache.get(schemata_key, None) - if value: - value_copy = value.copy() - value_copy.enable_reasoning = enable_thinking - return value_copy, True value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking) - return value, False + return value def _get_tokenizer_hf(self): """ @@ -303,7 +299,7 @@ class BackendBase: tokenizer = AutoTokenizer.from_pretrained( self.fd_config.model_config.model, - use_fast=False, + use_fast=True, ) if not isinstance(tokenizer, PreTrainedTokenizerFast): @@ -334,21 +330,6 @@ class BackendBase: except Exception as e: raise Exception(f"Fail to initialize hf tokenizer: {e}, {str(traceback.format_exc())}") - def add_cache(self, schemata_key: tuple[str, str], processor: LogitsProcessorBase) -> None: - """ - add logits processor to cache. - - Args: - schemata_key (tuple[str, str]): Tuple containing processor type and schema string - processor (LogitsProcessorBase): Logits processor instance to cache - - Returns: - None: No return value - """ - if len(self.cache) >= self.max_cache_size: - return - self.cache[schemata_key] = processor.copy() - class BaseChecker: """ diff --git a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py index 6681bf95f..f7c246edd 100644 --- a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py +++ b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py @@ -29,6 +29,7 @@ from fastdeploy.model_executor.guided_decoding import ( BaseChecker, LogitsProcessorBase, ) +from fastdeploy.platforms import current_platform from fastdeploy.utils import llm_logger try: @@ -86,6 +87,8 @@ class XGrammarProcessor(LogitsProcessorBase): terminate_without_stop_token=terminate_without_stop_token, override_stop_tokens=override_stop_tokens, ) + # when matcher accept eos_token_id, is_terminated = True + self.is_terminated: bool = False def allocate_token_bitmask(self) -> torch.Tensor: """ @@ -109,40 +112,6 @@ class XGrammarProcessor(LogitsProcessorBase): """ self.matcher.fill_next_token_bitmask(token_bitmask, idx) - def apply_token_mask( - self, - logits: paddle.Tensor, - token_bitmask: torch.Tensor, - indices: Optional[List[int]] = None, - ) -> paddle.Tensor: - """ - Apply the token mask to the logits, modifying probabilities of invalid tokens. - - Args: - logits (paddle.Tensor): The logits tensor to modify - token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens - indices (Optional[List[int]]): Optional list of batch indices to apply mask to - - Returns: - paddle.Tensor: The modified logits tensor - """ - origin_place = logits.place - origin_dtype = logits.dtype - logits = torch.from_numpy(logits.numpy()) - - logits = logits.float() # cpu - apply_token_bitmask_inplace( - logits=logits, - bitmask=token_bitmask.to(logits.device, non_blocking=True), - indices=indices, - ) - - return paddle.to_tensor( - logits.numpy(), - dtype=origin_dtype, - place=origin_place, - ) - def reset(self) -> None: """ Reset the grammar matcher state to initial conditions. @@ -155,23 +124,21 @@ class XGrammarProcessor(LogitsProcessorBase): def accept_token(self, token: int) -> None: """ Validate and accept a generated token against the grammar constraints. 
+ when accept eos_token, is_terminated = True Args: token (int): The token ID to validate - Raises: - AssertionError: If token is not allowed by the grammar """ - assert self.matcher.accept_token(token), f"Failed to accept token {token}" - - def is_terminated(self) -> bool: - """ - Check if the grammar matching process has terminated. - - Returns: - bool: True if matching has terminated, False otherwise - """ - return self.matcher.is_terminated() + if self.is_terminated or self.matcher.is_terminated(): + self.is_terminated = True + return False + if not self.matcher.accept_token(token): + self.matcher.reset() + return False + if self.matcher.is_terminated(): + self.is_terminated = True + return True def copy(self) -> "XGrammarProcessor": """ @@ -216,7 +183,18 @@ class XGrammarBackend(BackendBase): try: tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size) - self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info) + llm_logger.info(f"xgrammar_backend.py tokenizer_info={tokenizer_info.dump_metadata()}") + # Read configuration values, fallback to defaults if not set + xgrammar_cfg = getattr(fd_config, "xgrammar_config", {}) + max_threads = getattr(xgrammar_cfg, "max_threads", 8) + cache_enabled = getattr(xgrammar_cfg, "cache_enabled", True) + cache_limit_bytes = getattr(xgrammar_cfg, "cache_limit_bytes", 4 * 1024 * 1024) + self.grammar_compiler = GrammarCompiler( + tokenizer_info=tokenizer_info, + max_threads=max_threads, + cache_enabled=cache_enabled, + cache_limit_bytes=cache_limit_bytes, + ) except Exception as e: raise Exception(f"Failed to load XGrammar tokenizer: {e}") @@ -467,3 +445,49 @@ class XGrammarChecker(BaseChecker): else: # regex is not format return request, None + + +def apply_token_mask( + logits: paddle.Tensor, + token_bitmask: torch.Tensor, + indices: Optional[List[int]] = None, +) -> paddle.Tensor: + """ + Apply the token mask to the logits, modifying probabilities of invalid tokens. + + Args: + logits (paddle.Tensor): The logits tensor to modify + token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens + indices (Optional[List[int]]): Optional list of batch indices to apply mask to + + Returns: + paddle.Tensor: The modified logits tensor + """ + + if current_platform.is_cuda(): + dlpack = paddle.utils.dlpack.to_dlpack(logits) + t_logits = torch.from_dlpack(dlpack) + apply_token_bitmask_inplace( + logits=t_logits, + bitmask=token_bitmask.to(t_logits.device, non_blocking=True), + indices=indices, + ) + dlpack2 = torch.utils.dlpack.to_dlpack(t_logits) + return paddle.utils.dlpack.from_dlpack(dlpack2) + else: + origin_place = logits.place + origin_dtype = logits.dtype + logits = torch.from_numpy(logits.numpy()) + + logits = logits.float() # cpu + apply_token_bitmask_inplace( + logits=logits, + bitmask=token_bitmask.to(logits.device, non_blocking=True), + indices=indices, + ) + + return paddle.to_tensor( + logits.numpy(), + dtype=origin_dtype, + place=origin_place, + ) diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 19bbbf301..b03454268 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -14,13 +14,14 @@ # limitations under the License. 
""" -import threading -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Dict, List, Optional +import time +from concurrent.futures import Future +from typing import Any, List, Optional import paddle import paddle.nn.functional as F from paddle import nn +from paddleformers.utils.log import logger from fastdeploy.config import FDConfig from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase @@ -66,140 +67,187 @@ class GuidedDecoding: processor for guided decoding. """ - def __init__(self): - self.async_step = None + def __init__(self, fd_config: FDConfig): self.token_bitmask = None - self.logits_processor: Dict[int, Optional[Any]] = dict() - self.executor = ThreadPoolExecutor() - self.logits_lock = threading.Lock() + self.logits_processors: List[Any] = [None] * fd_config.scheduler_config.max_num_seqs self.reasoning_parser = None + self._prefill_done_idxs: List[bool] = [False] * fd_config.scheduler_config.max_num_seqs + # for pd + self._tokens_to_acc: List[None | List[int]] = [None] * fd_config.scheduler_config.max_num_seqs def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): self.reasoning_parser = reasoning_parser def add_logits_processor( self, - ids: int, + idx: int, future: Optional[Any] = None, prefill_tokens: List[int] = [], ): - """add logits processor to GuidedDecoding""" - with self.logits_lock: - if future is None: - if ids in self.logits_processor: - del self.logits_processor[ids] - return + """add logits processor to SamplerProcessor""" + self._prefill_done_idxs[idx] = False - if isinstance(future, LogitsProcessorBase): - self.logits_processor[ids] = future - for token in prefill_tokens: - self.logits_processor[ids].accept_token(token) - elif future.done(): - self.logits_processor[ids] = future.result() - for token in prefill_tokens: - self.logits_processor[ids].accept_token(token) - else: - self.logits_processor[ids] = [future, prefill_tokens] - - def update_vocab_mask(self, skip_idx_list: List[int] = []): - """update vocab mask. (cpu-heavy operation)""" - if len(self.logits_processor) == 0: + if future is None: + # normal request without guided_backend + self.logits_processors[idx] = None return - with self.logits_lock: - for idx, processor in self.logits_processor.items(): - if processor is None: - del self.logits_processor[idx] - continue + if len(prefill_tokens) != 0: + # first_token from prefill node + self._prefill_done_idxs[idx] = True - if not isinstance(processor, LogitsProcessorBase): - future, prefill_tokens = self.logits_processor[idx] - self.logits_processor[idx] = future.result() - for token in prefill_tokens: - self.logits_processor[idx].accept_token(token) + if future.done(): + # cached xgrammar + self.logits_processors[idx] = future.result() + for token in prefill_tokens: + self._accept_token(idx, token) + else: + # async + self.logits_processors[idx] = future + self._tokens_to_acc[idx] = prefill_tokens - available_processors = None - for processor in self.logits_processor.values(): - if processor.is_terminated(): - continue - available_processors = processor - if available_processors is None: - return + def should_fill_bitmask(self, idx: int) -> bool: + """ + Determines whether to fill a bitmask for the logits processor at the given index. 
- # allocate token bitmask - self.token_bitmask = available_processors.allocate_token_bitmask() + Args: + idx (int): The index of the logits processor to check - with self.logits_lock: - # fill token bitmask - for idx, processor in self.logits_processor.items(): - if processor.is_terminated() or idx in skip_idx_list: - continue + Returns: + bool: True if the idx request bitmask should be filled + """ + if self.reasoning_parser is not None: + if self.logits_processors[idx].enable_reasoning: # guided + return True + if not self.logits_processors[idx].reasoning_ended: + return False + return True + + def reset_processor(self, idx: int): + """reset idx""" + self._prefill_done_idxs[idx] = False + self.logits_processors[idx] = None + + def update_vocab_mask(self, prefill_done_idxs: List[int] = []): + """update vocab mask. (cpu-heavy operation)""" + for idx in prefill_done_idxs: + if self.logits_processors[idx] is None: + continue + + assert not self._prefill_done_idxs[idx] + self._prefill_done_idxs[idx] = True + if isinstance(self.logits_processors[idx], Future): + continue + + for idx, processor in enumerate(self.logits_processors): + if processor is None or not self._prefill_done_idxs[idx]: + continue + # skip, join at apply_token_mask + if isinstance(processor, Future): + continue + if processor.is_terminated: + self.reset_processor(idx) + continue + + self.accept_tokens_from_prefill_node(idx) + + if self.token_bitmask is None: + self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask() + + if self.should_fill_bitmask(idx): processor.fill_token_bitmask(self.token_bitmask, idx) - def apply_token_mask(self, logits: paddle.Tensor, skip_idx_list: List[int] = []): - """apply token mask to logits""" - if len(self.logits_processor) == 0 or self.token_bitmask is None: - return logits + def accept_tokens_from_prefill_node(self, idx: int): + """accept prefill token, not future""" + if self._tokens_to_acc[idx] is not None: + # accept token from prefill node first + for token in self._tokens_to_acc[idx]: + self._accept_token(idx, token) + self._tokens_to_acc[idx] = None - # self.async_step.result() - available_processors = None - with self.logits_lock: - for processor in self.logits_processor.values(): - if processor.is_terminated(): - continue - available_processors = processor - if available_processors is None: - return logits + def apply_token_mask(self, logits: paddle.Tensor, prefill_done_idxs: List[int] = []): + """apply token mask to logits""" indices = [] - for idx, processor in self.logits_processor.items(): - if processor is None or idx in skip_idx_list: + for idx, processor in enumerate(self.logits_processors): + if processor is None or not self._prefill_done_idxs[idx]: continue - if self.reasoning_parser is None or not processor.enable_reasoning or processor.reasoning_ended: - indices.append(idx) - return available_processors.apply_token_mask(logits, self.token_bitmask, indices=indices) + # compiled done, check idx should fill, fill_token_bitmask done in preprocess + if not isinstance(processor, Future): + if self.should_fill_bitmask(idx): + indices.append(idx) + continue + + # is Future, processor async compiled not ready, need join and wait + ts = time.time() + wait = False + if not processor.done(): + wait = True + self.logits_processors[idx] = processor.result() + if wait: + logger.debug(f"[{idx} join async compile xgrammar, time_cost:{time.time() - ts}]") + + self.accept_tokens_from_prefill_node(idx) + # Possible optimization: Extract 'think' content validation from 
logits_processors, + # allowing join operations to complete immediately after 'think' terminates. + # Furthermore, the current idx could be skipped, with compilation overhead + # estimated at only a few milliseconds. + + # check idx for fill_token_mask + if not self.should_fill_bitmask(idx): + continue + + indices.append(idx) + + if self.token_bitmask is None: + self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask() + + self.logits_processors[idx].fill_token_bitmask(self.token_bitmask, idx) + + if len(indices) == 0: + return logits + from fastdeploy.model_executor.guided_decoding.xgrammar_backend import ( + apply_token_mask, + ) + + return apply_token_mask(logits, self.token_bitmask, indices=indices) def _accept_token(self, idx: int, token: int): """accept token""" - if idx not in self.logits_processor: - raise ValueError(f"Invalid index, idx: {idx}, logit_processors.keys: {self.logits_processor.keys()}") - if self.logits_processor[idx].is_terminated(): - return + if self.reasoning_parser is not None: + if not self.logits_processors[idx].enable_reasoning: + if not self.logits_processors[idx].reasoning_ended: + reasoning_ended = self.reasoning_parser.is_reasoning_end([token]) + self.logits_processors[idx].reasoning_ended = reasoning_ended + return - if ( - self.reasoning_parser is not None - and self.logits_processor[idx].enable_reasoning - and not self.logits_processor[idx].reasoning_ended - ): - reasoning_ended = self.reasoning_parser.is_reasoning_end([token]) - self.logits_processor[idx].reasoning_ended = reasoning_ended - return + if not self.logits_processors[idx].accept_token(token) or self.logits_processors[idx].is_terminated: + self.reset_processor(idx) - self.logits_processor[idx].accept_token(token) - - def update_output_tokens(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []): + def update_output_tokens(self, next_tokens: paddle.Tensor): """update output tokens""" - if len(self.logits_processor) == 0: + if len(self.logits_processors) == 0: return token_ids = next_tokens.numpy().tolist() - with self.logits_lock: - for idx in self.logits_processor.keys(): - token = token_ids[idx][0] - if token < 0 or self.logits_processor[idx] is None or idx in skip_idx_list: - continue + for idx, processor in enumerate(self.logits_processors): + if not self._prefill_done_idxs[idx] or processor is None: + continue + if idx >= len(token_ids): + continue + token = token_ids[idx][0] + if token < 0: + self.reset_processor(idx) + continue + logger.debug(f"[{idx}]accept token{token}") + self._accept_token(idx, token) - self._accept_token(idx, token) - - def pre_process(self, skip_idx_list: List[int] = []): + def pre_process(self, prefill_done_idxs: List[int] = []): """pre process before running""" - # create async operation for guided decoding - # TODO: support async - self.update_vocab_mask(skip_idx_list) - # self.async_step = self.executor.submit(self.update_vocab_mask) + self.update_vocab_mask(prefill_done_idxs) class Sampler(nn.Layer): @@ -224,7 +272,7 @@ class Sampler(nn.Layer): else: raise NotImplementedError - self.guided_decoding = GuidedDecoding() + self.guided_decoding = GuidedDecoding(fd_config) self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode # Can only be created when fd_config.early_stopper_config.enable_early_stop = True if ( @@ -240,17 +288,19 @@ class Sampler(nn.Layer): """set reasoning parser""" self.guided_decoding.apply_reasoning_parser(reasoning_parser) - def apply_logits_processor(self, ids: int, 
future: Optional[Any] = None, prefill_tokens: List[int] = []): + def apply_logits_processor( + self, ids: int, future: Future[LogitsProcessorBase] = None, prefill_tokens: List[int] = [] + ): """apply logits processor to sampler""" self.guided_decoding.add_logits_processor(ids, future, prefill_tokens) - def pre_process(self, skip_idx_list: List[int] = []): + def pre_process(self, prefill_done_idxs: List[int] = []): """pre process before running""" - self.guided_decoding.pre_process(skip_idx_list) + self.guided_decoding.pre_process(prefill_done_idxs) - def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []): + def post_process(self, next_tokens: paddle.Tensor): """post process after running""" - self.guided_decoding.update_output_tokens(next_tokens, skip_idx_list) + self.guided_decoding.update_output_tokens(next_tokens) def compute_logprobs( self, @@ -343,10 +393,10 @@ class Sampler(nn.Layer): self, logits: paddle.Tensor, sampling_metadata: SamplingMetadata, - skip_idx_list: List[int] = [], + p_done_idxs: List[int] = [], ) -> SamplerOutput: """ """ - logits = self.guided_decoding.apply_token_mask(logits, skip_idx_list) + logits = self.guided_decoding.apply_token_mask(logits, p_done_idxs) num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None: diff --git a/fastdeploy/scheduler/splitwise_scheduler.py b/fastdeploy/scheduler/splitwise_scheduler.py index 7c404c891..94c947ea4 100644 --- a/fastdeploy/scheduler/splitwise_scheduler.py +++ b/fastdeploy/scheduler/splitwise_scheduler.py @@ -706,10 +706,30 @@ class InferScheduler: self.reqs_queue = deque() self.writers = [] + def check_redis_version(self): + # Get Redis version information + redis_info = self.client.info() + redis_version = redis_info.get("redis_version", "") + version_parts = [int(x) for x in redis_version.split(".")] + + # Redis 6.2 and above versions support RPOP with count parameter + assert ( + version_parts[0] >= 6 + ), f"Redis major version too low: {version_parts[0]}. Please upgrade to Redis 6.2+ to support batch RPOP operations." + assert ( + version_parts[1] >= 2 if version_parts[0] == 6 else True + ), f"Redis version {redis_version} too low. Please upgrade to Redis 6.2+ to support batch RPOP operations." + + logger.info(f"Redis version {redis_version} detected. 
Using native batch RPOP.") + def start(self, role, host, disaggregated): """ start backup threads """ + + # Check Redis version first + self.check_redis_version() + for i in range(self.writer_parallel): writer = ResultWriter(self.client, i, self.writer_batch_size, self.ttl) writer.start() diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index d9401a51a..3444cc7dd 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -31,9 +31,6 @@ from fastdeploy.model_executor.graph_optimization.utils import ( sot_warmup_guard, ) from fastdeploy.model_executor.guided_decoding import get_guided_backend -from fastdeploy.model_executor.guided_decoding.base_guided_decoding import ( - LogitsProcessorBase, -) from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, @@ -182,7 +179,7 @@ class GCUModelRunner(ModelRunnerBase): or request.guided_grammar is not None ): logits_info, schemata_key = self._init_logits_processor(request) - request.logits_processor, request.logits_cached = logits_info + request.logits_processor = logits_info request.schemata_key = schemata_key # Is Decode Node @@ -925,29 +922,36 @@ class GCUModelRunner(ModelRunnerBase): logger.info(f"SOT warmup the model with the batch size:{batch_size}") logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds") - def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None): + def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_running_requests: int): """ - Get the index of the request that needs to be skipped during execution. - Args: - model_forward_batch: A list of requests to be executed by this runner. - Returns: - A list of indices corresponding to the requests that need to be skipped. + Returns indices for guided decoding. + When Prefill is done, async compiled logits_processor must be joined. 
""" - skip_idx_list = [] - if not self.cache_config.enable_chunked_prefill or self.guided_backend is None: - return skip_idx_list + if self.guided_backend is None: + return [] - for task in model_forward_batch: - if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info): - continue - skip_idx_list.append(task.idx) + prefill_done_idxs = [] + for idx in range(0, num_running_requests): + if self.share_inputs["step_idx"][idx] == 0: + prefill_done_idxs.append(idx) - for task in self.restore_chunked_prefill_request.values(): - if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info): - continue - skip_idx_list.append(task.idx) + if self.cache_config.enable_chunked_prefill: + if model_forward_batch is not None: + for task in model_forward_batch: + # new Request with ChunkPrefill, unfinished, store + if task.chunk_idx < len(task.prefill_chunk_info): + if task.request_id not in self.restore_chunked_prefill_request: + self.restore_chunked_prefill_request[task.request_id] = task - return skip_idx_list + for id, task in list(self.restore_chunked_prefill_request.items()): + # unfinished, remove + if task.chunk_idx < len(task.prefill_chunk_info) and task.idx in prefill_done_idxs: + prefill_done_idxs.remove(task.idx) + # finished, add + if task.chunk_idx == len(task.prefill_chunk_info) and task.idx not in prefill_done_idxs: + prefill_done_idxs.append(task.idx) + + return prefill_done_idxs def execute_model( self, @@ -971,9 +975,9 @@ class GCUModelRunner(ModelRunnerBase): return None # 1. Prepare inputs of model and sampler. - skip_idx_list = self._get_skip_idx(model_forward_batch) + p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests) self._prepare_inputs() - self.sampler.pre_process(skip_idx_list) + self.sampler.pre_process(p_done_idxs) # 2. Padding inputs for cuda graph @@ -1009,7 +1013,7 @@ class GCUModelRunner(ModelRunnerBase): sampler_output = self.sampler( logits, self.sampling_metadata, - skip_idx_list, + p_done_idxs, ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0) @@ -1081,28 +1085,9 @@ class GCUModelRunner(ModelRunnerBase): self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED self._update_chunked_prefill(model_forward_batch) - self._add_cache(model_forward_batch) self.seq_lens_this_time_buffer.copy_(self.share_inputs["seq_lens_this_time"], False) return None - def _add_cache(self, model_forward_batch) -> None: - """ - Add cache for guided decoding. 
- """ - if self.guided_backend is None: - return - - for request in model_forward_batch: - logits_cached = request.get("logits_cached", None) - if logits_cached is None or logits_cached: - continue - - request.logits_cached = True - if isinstance(request.logits_processor, LogitsProcessorBase): - self.guided_backend.add_cache(request.schemata_key, request.logits_processor) - else: - self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result()) - def _execute_empty_input(self) -> None: """ In certain scenarios, such as during EP, diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8d88cc240..5b30babff 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -17,6 +17,7 @@ import os import queue import time +from concurrent.futures import Future from threading import Thread from typing import List, Optional, cast @@ -287,7 +288,7 @@ class GPUModelRunner(ModelRunnerBase): else: self.proposer = None - def _init_logits_processor(self, request): + def _init_logits_processor(self, request) -> tuple[Future[LogitsProcessorBase],]: """ init logits processor for guided decoding """ @@ -307,7 +308,7 @@ class GPUModelRunner(ModelRunnerBase): return ( self.guided_backend.get_logits_processor( schemata_key=schemata_key, - enable_thinking=True, + enable_thinking=False, # TODO cfg ), schemata_key, ) @@ -696,6 +697,7 @@ class GPUModelRunner(ModelRunnerBase): length = len(request.prompt_token_ids) assert length > 0, "The prompt requested must not be empty." + logits_info = None prefill_tokens = [] if ( request.guided_json is not None @@ -704,7 +706,6 @@ class GPUModelRunner(ModelRunnerBase): or request.guided_grammar is not None ): logits_info, schemata_key = self._init_logits_processor(request) - request.logits_processor, request.logits_cached = logits_info request.schemata_key = schemata_key # Is Decode Node @@ -874,7 +875,7 @@ class GPUModelRunner(ModelRunnerBase): else: self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0 - self.sampler.apply_logits_processor(idx, request.get("logits_processor"), prefill_tokens) + self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens) self.share_inputs["not_need_stop"][0] = True @@ -2006,34 +2007,36 @@ class GPUModelRunner(ModelRunnerBase): logger.info(f"SOT warmup the model with the batch size:{batch_size}") logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds") - def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None): + def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_running_requests: int): """ - Get the index of the request that needs to be skipped during execution. - Args: - model_forward_batch: A list of requests to be executed by this runner. - Returns: - A list of indices corresponding to the requests that need to be skipped. + Get indices for guided decoding. + When Prefill is done, async compiled logits_processor must be joined. 
""" - if ( - not self.cache_config.enable_chunked_prefill - or self.guided_backend is None - or model_forward_batch is None - or envs.ENABLE_V1_KVCACHE_SCHEDULER - ): + if self.guided_backend is None: return [] - skip_idx_list = [] - for task in model_forward_batch: - if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info): - continue - skip_idx_list.append(task.idx) + prefill_done_idxs = [] + for idx in range(0, num_running_requests): + if self.share_inputs["step_idx"][idx] == 0: + prefill_done_idxs.append(idx) - for task in self.restore_chunked_prefill_request.values(): - if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info): - continue - skip_idx_list.append(task.idx) + if self.cache_config.enable_chunked_prefill: + if model_forward_batch is not None: + for task in model_forward_batch: + # new Request with ChunkPrefill, unfinished, store + if task.chunk_idx < len(task.prefill_chunk_info): + if task.request_id not in self.restore_chunked_prefill_request: + self.restore_chunked_prefill_request[task.request_id] = task - return skip_idx_list + for id, task in list(self.restore_chunked_prefill_request.items()): + # unfinished, remove + if task.chunk_idx < len(task.prefill_chunk_info) and task.idx in prefill_done_idxs: + prefill_done_idxs.remove(task.idx) + # finished, add + if task.chunk_idx == len(task.prefill_chunk_info) and task.idx not in prefill_done_idxs: + prefill_done_idxs.append(task.idx) + + return prefill_done_idxs def execute_model( self, @@ -2050,9 +2053,10 @@ class GPUModelRunner(ModelRunnerBase): num_running_requests: batch_size """ # 1. Prepare inputs of model and sampler. - skip_idx_list = self._get_skip_idx(model_forward_batch) + p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests) + self._prepare_inputs() - self.sampler.pre_process(skip_idx_list) + self.sampler.pre_process(p_done_idxs) # 1.1 Update state of logits processor for proc in self.sampling_metadata.logits_processors: @@ -2157,7 +2161,7 @@ class GPUModelRunner(ModelRunnerBase): sampler_output = self.sampler( logits, self.sampling_metadata, - skip_idx_list, + p_done_idxs, ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( @@ -2244,7 +2248,7 @@ class GPUModelRunner(ModelRunnerBase): line_break_id=self.model_config.line_break_id, ) if self.guided_backend is not None and sampler_output is not None: - self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list) + self.sampler.post_process(sampler_output.sampled_token_ids) # 6. Speculative decode if self.speculative_decoding: @@ -2268,7 +2272,6 @@ class GPUModelRunner(ModelRunnerBase): ) self._update_chunked_prefill(model_forward_batch) - self._add_cache(model_forward_batch) elif self.speculative_decoding: speculate_schedule_cache( self.share_inputs["draft_tokens"], @@ -2325,24 +2328,6 @@ class GPUModelRunner(ModelRunnerBase): return pooler_output - def _add_cache(self, model_forward_batch) -> None: - """ - Add cache for guided decoding. 
- """ - if self.guided_backend is None or model_forward_batch is None: - return - - for request in model_forward_batch: - logits_cached = request.get("logits_cached", None) - if logits_cached is None or logits_cached: - continue - - request.logits_cached = True - if isinstance(request.logits_processor, LogitsProcessorBase): - self.guided_backend.add_cache(request.schemata_key, request.logits_processor) - else: - self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result()) - def _execute_empty_input(self) -> None: """ In certain scenarios, such as during EP, diff --git a/fastdeploy/worker/hpu_model_runner.py b/fastdeploy/worker/hpu_model_runner.py index e1cc1e3e7..c890e4f2c 100644 --- a/fastdeploy/worker/hpu_model_runner.py +++ b/fastdeploy/worker/hpu_model_runner.py @@ -30,9 +30,6 @@ from fastdeploy.engine.request import Request # from fastdeploy.spec_decode import MTPProposer, NgramProposer from fastdeploy.model_executor.forward_meta import HPUForwardMeta from fastdeploy.model_executor.guided_decoding import get_guided_backend -from fastdeploy.model_executor.guided_decoding.base_guided_decoding import ( - LogitsProcessorBase, -) from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, @@ -442,7 +439,7 @@ class HPUModelRunner(ModelRunnerBase): or request.guided_grammar is not None ): logits_info, schemata_key = self._init_logits_processor(request) - request.logits_processor, request.logits_cached = logits_info + request.logits_processor = logits_info request.schemata_key = schemata_key # Is Decode Node @@ -1179,30 +1176,6 @@ class HPUModelRunner(ModelRunnerBase): time_after_capture = time.perf_counter() logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds") - def _get_skip_idx(self, model_forward_batch): - """ - Get the index of the request that needs to be skipped during execution. - Args: - model_forward_batch: A list of requests to be executed by this runner. - Returns: - A list of indices corresponding to the requests that need to be skipped. - """ - skip_idx_list = [] - if not self.parallel_config.enable_chunked_prefill or self.guided_backend is None: - return skip_idx_list - - for task in model_forward_batch: - if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info): - continue - skip_idx_list.append(task.idx) - - for task in self.restore_chunked_prefill_request.values(): - if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info): - continue - skip_idx_list.append(task.idx) - - return skip_idx_list - def execute_model( self, model_forward_batch: Optional[List[Request]] = None, @@ -1332,30 +1305,11 @@ class HPUModelRunner(ModelRunnerBase): execution_time = (end_time - start_time) * 1000 hpu_model_runner_profile_logger.info(f"StepPaddle execution time(ms): {execution_time}, BT={real_bs}") self._update_chunked_prefill(model_forward_batch) - self._add_cache(model_forward_batch) if int(os.environ.get("HABANA_PROFILE", 0)) == 1: self.prof.step() return None - def _add_cache(self, model_forward_batch) -> None: - """ - Add cache for guided decoding. 
- """ - if self.guided_backend is None: - return - - for request in model_forward_batch: - logits_cached = request.get("logits_cached", None) - if logits_cached is None or logits_cached: - continue - - request.logits_cached = True - if isinstance(request.logits_processor, LogitsProcessorBase): - self.guided_backend.add_cache(request.schemata_key, request.logits_processor) - else: - self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result()) - def _execute_empty_input(self) -> None: """ In certain scenarios, such as during EP, diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index 75be437e7..b346f7be6 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -38,10 +38,7 @@ from fastdeploy.model_executor.graph_optimization.utils import ( profile_run_guard, sot_warmup_guard, ) -from fastdeploy.model_executor.guided_decoding import ( - LogitsProcessorBase, - get_guided_backend, -) +from fastdeploy.model_executor.guided_decoding import get_guided_backend from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, @@ -507,7 +504,7 @@ class MetaxModelRunner(ModelRunnerBase): or request.guided_grammar is not None ): logits_info, schemata_key = self._init_logits_processor(request) - request.logits_processor, request.logits_cached = logits_info + request.logits_processor = logits_info request.schemata_key = schemata_key # Is Decode Node @@ -1761,34 +1758,36 @@ class MetaxModelRunner(ModelRunnerBase): logger.info(f"SOT warmup the model with the batch size:{batch_size}") logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds") - def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None): + def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_running_requests: int): """ - Get the index of the request that needs to be skipped during execution. - Args: - model_forward_batch: A list of requests to be executed by this runner. - Returns: - A list of indices corresponding to the requests that need to be skipped. + Get indices for guided decoding. + When Prefill is done, async compiled logits_processor must be joined. 
""" - if ( - not self.cache_config.enable_chunked_prefill - or self.guided_backend is None - or model_forward_batch is None - or envs.ENABLE_V1_KVCACHE_SCHEDULER - ): + if self.guided_backend is None: return [] - skip_idx_list = [] - for task in model_forward_batch: - if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info): - continue - skip_idx_list.append(task.idx) + prefill_done_idxs = [] + for idx in range(0, num_running_requests): + if self.share_inputs["step_idx"][idx] == 0: + prefill_done_idxs.append(idx) - for task in self.restore_chunked_prefill_request.values(): - if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info): - continue - skip_idx_list.append(task.idx) + if self.cache_config.enable_chunked_prefill: + if model_forward_batch is not None: + for task in model_forward_batch: + # new Request with ChunkPrefill, unfinished, store + if task.chunk_idx < len(task.prefill_chunk_info): + if task.request_id not in self.restore_chunked_prefill_request: + self.restore_chunked_prefill_request[task.request_id] = task - return skip_idx_list + for id, task in list(self.restore_chunked_prefill_request.items()): + # unfinished, remove + if task.chunk_idx < len(task.prefill_chunk_info) and task.idx in prefill_done_idxs: + prefill_done_idxs.remove(task.idx) + # finished, add + if task.chunk_idx == len(task.prefill_chunk_info) and task.idx not in prefill_done_idxs: + prefill_done_idxs.append(task.idx) + + return prefill_done_idxs def execute_model( self, @@ -1805,9 +1804,9 @@ class MetaxModelRunner(ModelRunnerBase): num_running_requests: batch_size """ # 1. Prepare inputs of model and sampler. - skip_idx_list = self._get_skip_idx(model_forward_batch) + p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests) self._prepare_inputs() - self.sampler.pre_process(skip_idx_list) + self.sampler.pre_process(p_done_idxs) # NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state. # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode, @@ -1864,7 +1863,7 @@ class MetaxModelRunner(ModelRunnerBase): sampler_output = self.sampler( logits, self.sampling_metadata, - skip_idx_list, + p_done_idxs, ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( @@ -1949,7 +1948,7 @@ class MetaxModelRunner(ModelRunnerBase): line_break_id=self.model_config.line_break_id, ) if self.guided_backend is not None and sampler_output is not None: - self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list) + self.sampler.post_process(sampler_output.sampled_token_ids) # 6. Speculative decode if self.speculative_decoding: @@ -1973,7 +1972,6 @@ class MetaxModelRunner(ModelRunnerBase): ) self._update_chunked_prefill(model_forward_batch) - self._add_cache(model_forward_batch) elif self.speculative_decoding: speculate_schedule_cache( self.share_inputs["draft_tokens"], @@ -2000,24 +1998,6 @@ class MetaxModelRunner(ModelRunnerBase): ) return None - def _add_cache(self, model_forward_batch) -> None: - """ - Add cache for guided decoding. 
- """ - if self.guided_backend is None or model_forward_batch is None: - return - - for request in model_forward_batch: - logits_cached = request.get("logits_cached", None) - if logits_cached is None or logits_cached: - continue - - request.logits_cached = True - if isinstance(request.logits_processor, LogitsProcessorBase): - self.guided_backend.add_cache(request.schemata_key, request.logits_processor) - else: - self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result()) - def _execute_empty_input(self) -> None: """ In certain scenarios, such as during EP, diff --git a/tests/layers/test_guided_decoding.py b/tests/layers/test_guided_decoding.py new file mode 100644 index 000000000..1f71fcb34 --- /dev/null +++ b/tests/layers/test_guided_decoding.py @@ -0,0 +1,338 @@ +""" +测试GuidedDecoding类的单元测试 +""" + +import sys +import unittest +from concurrent.futures import Future +from unittest.mock import MagicMock, Mock, patch + +import paddle + +mock_torch = MagicMock() +mock_xgrammar = MagicMock() +sys.modules["torch"] = mock_torch +sys.modules["xgrammar"] = mock_xgrammar + +from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase +from fastdeploy.model_executor.layers.sample.sampler import GuidedDecoding +from fastdeploy.reasoning import ReasoningParser + + +class TestGuidedDecoding(unittest.TestCase): + """Test cases for GuidedDecoding class.""" + + def setUp(self): + """Setup for each test case.""" + # 创建一个基本的FDConfig对象 + self.fd_config = Mock() + self.fd_config.scheduler_config.max_num_seqs = 5 + + # 创建GuidedDecoding对象 + self.guided_decoding = GuidedDecoding(self.fd_config) + + # 创建一个模拟的LogitsProcessorBase + self.mock_processor = Mock(spec=LogitsProcessorBase) + self.mock_processor.is_terminated = False + self.mock_processor.reasoning_ended = True + self.mock_processor.enable_reasoning = False + + # 模拟allocate_token_bitmask方法返回一个假的bitmask + self.mock_processor.allocate_token_bitmask.return_value = paddle.zeros([5, 10], dtype="int32") + + # 模拟fill_token_bitmask方法 + self.mock_processor.fill_token_bitmask.return_value = None + + # 模拟accept_token方法返回True + self.mock_processor.accept_token.return_value = True + + def test_init(self): + """Test initialization.""" + self.assertIsNone(self.guided_decoding.token_bitmask) + self.assertEqual(len(self.guided_decoding.logits_processors), 5) + self.assertIsNone(self.guided_decoding.reasoning_parser) + self.assertEqual(len(self.guided_decoding._prefill_done_idxs), 5) + self.assertEqual(len(self.guided_decoding._tokens_to_acc), 5) + + def test_apply_reasoning_parser(self): + """Test apply_reasoning_parser method.""" + mock_parser = Mock(spec=ReasoningParser) + self.guided_decoding.apply_reasoning_parser(mock_parser) + self.assertEqual(self.guided_decoding.reasoning_parser, mock_parser) + + def test_add_logits_processor_no_future(self): + """Test add_logits_processor method without future.""" + self.guided_decoding.add_logits_processor(0, None, []) + self.assertFalse(self.guided_decoding._prefill_done_idxs[0]) + self.assertIsNone(self.guided_decoding.logits_processors[0]) + + def test_add_logits_processor_with_prefill_tokens(self): + """Test add_logits_processor method with prefill tokens.""" + # 创建模拟Future对象 + mock_future = Mock() + mock_future.done.return_value = True + mock_future.result.return_value = self.mock_processor + + prefill_tokens = [1, 2, 3] + self.guided_decoding.add_logits_processor(0, mock_future, prefill_tokens) + + self.assertTrue(self.guided_decoding._prefill_done_idxs[0]) + 
self.assertEqual(self.guided_decoding.logits_processors[0], self.mock_processor) + self.mock_processor.accept_token.assert_any_call(1) + self.mock_processor.accept_token.assert_any_call(2) + self.mock_processor.accept_token.assert_any_call(3) + + def test_add_logits_processor_with_async_future(self): + """Test add_logits_processor method with async future.""" + # 创建模拟Future对象 + mock_future = Mock() + mock_future.done.return_value = False + + prefill_tokens = [1, 2, 3] + self.guided_decoding.add_logits_processor(0, mock_future, prefill_tokens) + + self.assertTrue(self.guided_decoding._prefill_done_idxs[0]) + self.assertEqual(self.guided_decoding.logits_processors[0], mock_future) + self.assertEqual(self.guided_decoding._tokens_to_acc[0], prefill_tokens) + + def test_should_fill_bitmask_no_reasoning_parser(self): + """Test should_fill_bitmask method with no reasoning parser.""" + self.guided_decoding.logits_processors[0] = self.mock_processor + self.assertTrue(self.guided_decoding.should_fill_bitmask(0)) + + def test_should_fill_bitmask_with_reasoning_parser(self): + """Test should_fill_bitmask method with reasoning parser.""" + mock_parser = Mock(spec=ReasoningParser) + self.guided_decoding.reasoning_parser = mock_parser + + # 测试 enable_reasoning=True 的情况 + self.mock_processor.enable_reasoning = True + self.guided_decoding.logits_processors[0] = self.mock_processor + self.assertTrue(self.guided_decoding.should_fill_bitmask(0)) + + # 测试 enable_reasoning=False, reasoning_ended=False 的情况 + self.mock_processor.enable_reasoning = False + self.mock_processor.reasoning_ended = False + self.assertFalse(self.guided_decoding.should_fill_bitmask(0)) + + # 测试 enable_reasoning=False, reasoning_ended=True 的情况 + self.mock_processor.reasoning_ended = True + self.assertTrue(self.guided_decoding.should_fill_bitmask(0)) + + def test_reset_processor(self): + """Test reset_processor method.""" + self.guided_decoding.logits_processors[0] = self.mock_processor + self.guided_decoding._prefill_done_idxs[0] = True + + self.guided_decoding.reset_processor(0) + + self.assertFalse(self.guided_decoding._prefill_done_idxs[0]) + self.assertIsNone(self.guided_decoding.logits_processors[0]) + + def test_update_vocab_mask_with_new_prefill_done(self): + """Test update_vocab_mask method with new prefill_done_idxs.""" + # 设置索引0的处理器 + self.guided_decoding.logits_processors[0] = self.mock_processor + self.guided_decoding._prefill_done_idxs[0] = False + + # 调用update_vocab_mask并标记索引0为已完成 + self.guided_decoding.update_vocab_mask([0]) + + # 验证_prefill_done_idxs[0]已更新 + self.assertTrue(self.guided_decoding._prefill_done_idxs[0]) + + # 验证fill_token_bitmask被调用 + self.mock_processor.fill_token_bitmask.assert_called_once() + + def test_update_vocab_mask_with_future_processor(self): + """Test update_vocab_mask method with future processor.""" + # 创建模拟Future对象 + mock_future = Mock() + + # 设置索引0的处理器为Future + self.guided_decoding.logits_processors[0] = mock_future + self.guided_decoding._prefill_done_idxs[0] = True + + # 调用update_vocab_mask + self.guided_decoding.update_vocab_mask([]) + + # 验证fill_token_bitmask没有被调用(因为处理器是Future) + self.mock_processor.fill_token_bitmask.assert_not_called() + + def test_accept_tokens_from_prefill_node(self): + """Test accept_tokens_from_prefill_node method.""" + # 设置索引0的处理器和待接受的tokens + self.guided_decoding.logits_processors[0] = self.mock_processor + self.guided_decoding._tokens_to_acc[0] = [1, 2, 3] + + # 调用accept_tokens_from_prefill_node + self.guided_decoding.accept_tokens_from_prefill_node(0) + + # 
验证accept_token被调用了3次 + self.assertEqual(self.mock_processor.accept_token.call_count, 3) + self.mock_processor.accept_token.assert_any_call(1) + self.mock_processor.accept_token.assert_any_call(2) + self.mock_processor.accept_token.assert_any_call(3) + + # 验证_tokens_to_acc[0]已被重置 + self.assertIsNone(self.guided_decoding._tokens_to_acc[0]) + + @patch("fastdeploy.model_executor.guided_decoding.xgrammar_backend.apply_token_mask") + def test_apply_token_mask(self, mock_apply_token_mask): + """Test apply_token_mask method.""" + # 创建测试数据 + logits = paddle.zeros([5, 10], dtype="float32") + mock_apply_token_mask.return_value = paddle.ones([5, 10], dtype="float32") + + # 设置索引0的处理器 + self.guided_decoding.logits_processors[0] = self.mock_processor + self.guided_decoding._prefill_done_idxs[0] = True + + # 调用apply_token_mask + result = self.guided_decoding.apply_token_mask(logits, []) + + # 验证fill_token_bitmask没有被调用,非 Future + self.mock_processor.fill_token_bitmask.assert_not_called() + + # 验证apply_token_mask被调用 + mock_apply_token_mask.assert_called_once() + + # 验证返回值 + self.assertTrue((result == paddle.ones([5, 10], dtype="float32")).all()) + + def test_apply_token_mask_with_future_processor(self): + """Test apply_token_mask method with future processor.""" + # 创建测试数据 + logits = paddle.zeros([5, 10], dtype="float32") + + # 创建模拟Future对象 + mock_future = Mock(spec=Future) + mock_future.done.return_value = True + mock_future.result.return_value = self.mock_processor + + # 设置索引0的处理器为Future + self.guided_decoding.logits_processors[0] = mock_future + + self.guided_decoding._prefill_done_idxs[0] = True + self.assertTrue(self.guided_decoding._prefill_done_idxs[0]) + self.assertIsNotNone(self.guided_decoding.logits_processors[0]) + self.assertTrue(isinstance(self.guided_decoding.logits_processors[0], Future)) + self.guided_decoding._tokens_to_acc[0] = [1, 2, 3] + + # 模拟patch apply_token_mask + with patch( + "fastdeploy.model_executor.guided_decoding.xgrammar_backend.apply_token_mask" + ) as mock_apply_token_mask: + mock_apply_token_mask.return_value = paddle.ones([5, 10], dtype="float32") + + # 调用apply_token_mask + self.guided_decoding.apply_token_mask(logits, []) + + # 验证Future.result被调用 + mock_future.result.assert_called_once() + + # 验证accept_token被调用了3次 + self.assertEqual(self.mock_processor.accept_token.call_count, 3) + + # 验证_tokens_to_acc[0]已被重置 + self.assertIsNone(self.guided_decoding._tokens_to_acc[0]) + + def test_accept_token(self): + """Test _accept_token method.""" + # 设置索引0的处理器 + self.guided_decoding.logits_processors[0] = self.mock_processor + + # 调用_accept_token + self.guided_decoding._accept_token(0, 1) + + # 验证accept_token被调用 + self.mock_processor.accept_token.assert_called_once_with(1) + + def test_accept_token_with_reasoning_parser(self): + """Test _accept_token method with reasoning parser.""" + # 创建模拟ReasoningParser + mock_parser = Mock(spec=ReasoningParser) + mock_parser.is_reasoning_end.return_value = True + self.guided_decoding.reasoning_parser = mock_parser + + # 设置索引0的处理器 + self.mock_processor.enable_reasoning = False + self.mock_processor.reasoning_ended = False + self.guided_decoding.logits_processors[0] = self.mock_processor + + # 调用_accept_token + self.guided_decoding._accept_token(0, 1) + + # 验证is_reasoning_end被调用 + mock_parser.is_reasoning_end.assert_called_once_with([1]) + + # 验证reasoning_ended已更新 + self.assertTrue(self.mock_processor.reasoning_ended) + + # 验证accept_token没有被调用(因为reasoning_ended刚被设置为True) + self.mock_processor.accept_token.assert_not_called() + + def 
test_accept_token_processor_terminated(self): + """Test _accept_token method when processor is terminated.""" + # 设置索引0的处理器,并让accept_token返回False + self.mock_processor.accept_token.return_value = False + self.guided_decoding.logits_processors[0] = self.mock_processor + + # 调用_accept_token + self.guided_decoding._accept_token(0, 1) + + # 验证处理器被重置 + self.assertIsNone(self.guided_decoding.logits_processors[0]) + + def test_update_output_tokens(self): + """Test update_output_tokens method.""" + # 创建测试数据 + next_tokens = paddle.to_tensor([[1], [2], [3], [4], [5]]) + + # 设置索引0和1的处理器 + self.guided_decoding.logits_processors[0] = self.mock_processor + self.guided_decoding.logits_processors[1] = self.mock_processor + self.guided_decoding._prefill_done_idxs[0] = True + self.guided_decoding._prefill_done_idxs[1] = True + + # 调用update_output_tokens + self.guided_decoding.update_output_tokens(next_tokens) + + # 验证accept_token被调用了两次 + self.assertEqual(self.mock_processor.accept_token.call_count, 2) + self.mock_processor.accept_token.assert_any_call(1) + self.mock_processor.accept_token.assert_any_call(2) + + def test_update_output_tokens_with_negative_token(self): + """Test update_output_tokens method with negative token.""" + # 创建测试数据,包含负值 + next_tokens = paddle.to_tensor([[-1], [2]]) + + # 设置索引0和1的处理器 + self.guided_decoding.logits_processors[0] = self.mock_processor + self.guided_decoding.logits_processors[1] = self.mock_processor + self.guided_decoding._prefill_done_idxs[0] = True + self.guided_decoding._prefill_done_idxs[1] = True + + # 调用update_output_tokens + self.guided_decoding.update_output_tokens(next_tokens) + + # 验证索引0的处理器被重置 + self.assertIsNone(self.guided_decoding.logits_processors[0]) + + # 验证索引1的处理器的accept_token被调用 + self.mock_processor.accept_token.assert_called_once_with(2) + + def test_pre_process(self): + """Test pre_process method.""" + # 模拟update_vocab_mask方法 + with patch.object(self.guided_decoding, "update_vocab_mask") as mock_update_vocab_mask: + # 调用pre_process + self.guided_decoding.pre_process([0, 1]) + + # 验证update_vocab_mask被调用 + mock_update_vocab_mask.assert_called_once_with([0, 1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/layers/test_sampler.py b/tests/layers/test_sampler.py index a40f1dcfb..46f1213b4 100644 --- a/tests/layers/test_sampler.py +++ b/tests/layers/test_sampler.py @@ -14,11 +14,23 @@ # limitations under the License. 
""" +import json +import os + import paddle import paddle.nn.functional as F +from fastdeploy.config import ( + CacheConfig, + FDConfig, + GraphOptimizationConfig, + LoadConfig, + ModelConfig, + ParallelConfig, +) from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.sampler import Sampler +from fastdeploy.scheduler import SchedulerConfig def _create_fake_logits(batch_size: int, vocab_size: int) -> paddle.Tensor: @@ -67,13 +79,60 @@ def _create_default_sampling_metadata( return fake_sampling_metadata +def build_config_json() -> str: + config_dict = { + "architectures": ["Qwen3MoeForCausalLM"], + "hidden_size": 7168, + "moe_intermediate_size": 1, + "moe_num_experts": 1, + "moe_k": 1, + "hidden_act": "silu", + "num_attention_heads": 64, + "dtype": "bfloat16", + } + + tmp_dir = f"./tmpefef{paddle.distributed.get_rank()}" + os.makedirs(tmp_dir, exist_ok=True) + with open(f"./{tmp_dir}/config.json", "w") as f: + json.dump(config_dict, f) + model_name_or_path = os.path.join(os.getcwd(), tmp_dir) + print("model_name_or_path", model_name_or_path) + return model_name_or_path + + +def get_fd_config(batch_size: int): + fd_config = FDConfig( + model_config=ModelConfig( + { + "model": build_config_json(), + "max_model_len": 2048, + } + ), + parallel_config=ParallelConfig( + { + "tensor_parallel_size": 1, + "expert_parallel_size": 1, + "expert_parallel_rank": 0, + "data_parallel_size": 1, + } + ), + # quant_config=BlockWiseFP8Config(weight_block_size=[128, 128]), + scheduler_config=SchedulerConfig({"max_num_seqs": batch_size}), + cache_config=CacheConfig({}), + graph_opt_config=GraphOptimizationConfig({}), + load_config=LoadConfig({}), + ips="0.0.0.0", + ) + return fd_config + + def test_sampler(): batch_size = 32 vocab_size = 1024 min_seq_len = 1 max_seq_len = 1024 - sampler = Sampler() + sampler = Sampler(get_fd_config(batch_size)) logits = _create_fake_logits(batch_size, vocab_size) sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len) next_tokens = sampler(logits, sampling_metadata) @@ -144,7 +203,10 @@ def test_sampler_logprobs(): logits = _create_fake_logits(batch_size, vocab_size) sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0) for logprobs_mode in logprobs_mode_list: - sampler = Sampler(logprobs_mode=logprobs_mode) + fd_config = get_fd_config(batch_size) + fd_config.model_config.logprobs_mode = logprobs_mode + sampler = Sampler(logprobs_mode=logprobs_mode, fd_config=fd_config) + assert sampler.logprobs_mode == logprobs_mode sampler_output = sampler(logits.clone(), sampling_metadata) baseline_logprobs = get_baseline_logprobs( logits.clone(), sampling_metadata, logprobs_mode=logprobs_mode, token_ids=sampler_output.sampled_token_ids diff --git a/tests/woker/test_gpu_prompt_logprobs.py b/tests/woker/test_gpu_prompt_logprobs.py index 442223dd9..ba0b9d9fe 100644 --- a/tests/woker/test_gpu_prompt_logprobs.py +++ b/tests/woker/test_gpu_prompt_logprobs.py @@ -12,15 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json +import os import time import unittest import numpy as np import paddle +from fastdeploy.config import ( + CacheConfig, + FDConfig, + GraphOptimizationConfig, + LoadConfig, + ModelConfig, + ParallelConfig, +) from fastdeploy.engine.request import Request from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.model_executor.layers.sample.sampler import Sampler +from fastdeploy.scheduler import SchedulerConfig from fastdeploy.worker.gpu_model_runner import GPUModelRunner @@ -82,6 +93,53 @@ class FakeModel: return paddle.matmul(x.astype("float32"), self.weight) +def build_config_json() -> str: + config_dict = { + "architectures": ["Qwen3MoeForCausalLM"], + "hidden_size": 7168, + "moe_intermediate_size": 1, + "moe_num_experts": 1, + "moe_k": 1, + "hidden_act": "silu", + "num_attention_heads": 64, + "dtype": "bfloat16", + } + + tmp_dir = f"./tmpefef{paddle.distributed.get_rank()}" + os.makedirs(tmp_dir, exist_ok=True) + with open(f"./{tmp_dir}/config.json", "w") as f: + json.dump(config_dict, f) + model_name_or_path = os.path.join(os.getcwd(), tmp_dir) + print("model_name_or_path", model_name_or_path) + return model_name_or_path + + +def get_fd_config(batch_size: int): + fd_config = FDConfig( + model_config=ModelConfig( + { + "model": build_config_json(), + "max_model_len": 2048, + } + ), + parallel_config=ParallelConfig( + { + "tensor_parallel_size": 1, + "expert_parallel_size": 1, + "expert_parallel_rank": 0, + "data_parallel_size": 1, + } + ), + # quant_config=BlockWiseFP8Config(weight_block_size=[128, 128]), + scheduler_config=SchedulerConfig({"max_num_seqs": batch_size}), + cache_config=CacheConfig({}), + graph_opt_config=GraphOptimizationConfig({}), + load_config=LoadConfig({}), + ips="0.0.0.0", + ) + return fd_config + + class TestGPUPromptLogprobs(unittest.TestCase): def setup_model_runner(self): """Helper method to setup GPUModelRunner with different configurations""" @@ -96,7 +154,7 @@ class TestGPUPromptLogprobs(unittest.TestCase): model_runner.ori_vocab_size = cfg.model_config.ori_vocab_size model_runner.share_inputs = {} model_runner.share_inputs["cu_seqlens_q"] = paddle.to_tensor([0, 1, 2, 3], dtype="int32") - model_runner.sampler = Sampler() + model_runner.sampler = Sampler(get_fd_config(batch_size=1)) model_runner.model = FakeModel(cfg.model_config.vocab_size, cfg.model_config.hidden_size)