[Optimization] xgrammar async compile, multi thread, speed up (#4835)

* xgrammar async compile, multi thread, speed up

* fix test_sampler.py & pre-commit err

* add redis version check && fix request.llm_engine_recv_req_timestamp

* xgrammar prefill & decode & v0

* fix test_gpu_prompt_logprobs.py

* add test_guided_decoding.py

* Update fastdeploy/scheduler/splitwise_scheduler.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix torch xgrammar unittest env

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Daci
2025-11-14 18:05:26 +08:00
committed by GitHub
parent b925533051
commit 5fc12eddfe
11 changed files with 810 additions and 373 deletions

View File

@@ -14,9 +14,10 @@
# limitations under the License.
"""
import multiprocessing
import os
import traceback
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from fastdeploy.config import ErnieArchitectures, FDConfig
from fastdeploy.engine.request import Request
@@ -135,9 +136,9 @@ class BackendBase:
"""
def __init__(self, fd_config: FDConfig):
self.cache = {}
self.fd_config = fd_config
self.executor = ThreadPoolExecutor()
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.max_cache_size = 2048
self.reasoning_parser = None
@@ -263,7 +264,7 @@ class BackendBase:
self,
schemata_key: tuple[str, str],
enable_thinking: bool = False,
) -> tuple[LogitsProcessorBase, bool]:
) -> Future[LogitsProcessorBase]:
"""
Submit asynchronous compilation of the logits processor for the given key and return a Future.
@@ -275,13 +276,8 @@ class BackendBase:
- Future[LogitsProcessorBase]: A future that resolves to the logits processor instance
"""
value = self.cache.get(schemata_key, None)
if value:
value_copy = value.copy()
value_copy.enable_reasoning = enable_thinking
return value_copy, True
value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking)
return value, False
return value
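
For context, the cache lookup above is replaced by an unconditional submit to the thread pool, so callers now receive a concurrent.futures.Future and join it only when the processor is actually needed. A minimal sketch of that pattern, assuming a stand-in compile_processor function (not the backend's real _init_logits_processor):

import multiprocessing
from concurrent.futures import Future, ThreadPoolExecutor

def compile_processor(schemata_key: tuple[str, str], enable_thinking: bool) -> str:
    # stand-in for the CPU-heavy grammar compilation
    return f"processor<{schemata_key}, thinking={enable_thinking}>"

# pool sized to roughly half the CPU count, mirroring the change above
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
executor = ThreadPoolExecutor(max_workers=max_workers)

def get_logits_processor(schemata_key: tuple[str, str], enable_thinking: bool = False) -> Future:
    # always submit; the caller decides when to block on .result()
    return executor.submit(compile_processor, schemata_key, enable_thinking)

future = get_logits_processor(("json", "{...}"))
# ... prefill can overlap with compilation here ...
processor = future.result()  # joined lazily, right before the mask is first needed
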
def _get_tokenizer_hf(self):
"""
@@ -303,7 +299,7 @@ class BackendBase:
tokenizer = AutoTokenizer.from_pretrained(
self.fd_config.model_config.model,
use_fast=False,
use_fast=True,
)
if not isinstance(tokenizer, PreTrainedTokenizerFast):
@@ -334,21 +330,6 @@ class BackendBase:
except Exception as e:
raise Exception(f"Fail to initialize hf tokenizer: {e}, {str(traceback.format_exc())}")
def add_cache(self, schemata_key: tuple[str, str], processor: LogitsProcessorBase) -> None:
"""
add logits processor to cache.
Args:
schemata_key (tuple[str, str]): Tuple containing processor type and schema string
processor (LogitsProcessorBase): Logits processor instance to cache
Returns:
None: No return value
"""
if len(self.cache) >= self.max_cache_size:
return
self.cache[schemata_key] = processor.copy()
class BaseChecker:
"""

View File

@@ -29,6 +29,7 @@ from fastdeploy.model_executor.guided_decoding import (
BaseChecker,
LogitsProcessorBase,
)
from fastdeploy.platforms import current_platform
from fastdeploy.utils import llm_logger
try:
@@ -86,6 +87,8 @@ class XGrammarProcessor(LogitsProcessorBase):
terminate_without_stop_token=terminate_without_stop_token,
override_stop_tokens=override_stop_tokens,
)
# when the matcher accepts eos_token_id, is_terminated is set to True
self.is_terminated: bool = False
def allocate_token_bitmask(self) -> torch.Tensor:
"""
@@ -109,40 +112,6 @@ class XGrammarProcessor(LogitsProcessorBase):
"""
self.matcher.fill_next_token_bitmask(token_bitmask, idx)
def apply_token_mask(
self,
logits: paddle.Tensor,
token_bitmask: torch.Tensor,
indices: Optional[List[int]] = None,
) -> paddle.Tensor:
"""
Apply the token mask to the logits, modifying probabilities of invalid tokens.
Args:
logits (paddle.Tensor): The logits tensor to modify
token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
indices (Optional[List[int]]): Optional list of batch indices to apply mask to
Returns:
paddle.Tensor: The modified logits tensor
"""
origin_place = logits.place
origin_dtype = logits.dtype
logits = torch.from_numpy(logits.numpy())
logits = logits.float() # cpu
apply_token_bitmask_inplace(
logits=logits,
bitmask=token_bitmask.to(logits.device, non_blocking=True),
indices=indices,
)
return paddle.to_tensor(
logits.numpy(),
dtype=origin_dtype,
place=origin_place,
)
def reset(self) -> None:
"""
Reset the grammar matcher state to initial conditions.
@@ -155,23 +124,21 @@ class XGrammarProcessor(LogitsProcessorBase):
def accept_token(self, token: int) -> None:
"""
Validate and accept a generated token against the grammar constraints.
When an eos_token is accepted, is_terminated is set to True.
Args:
token (int): The token ID to validate
Returns:
bool: True if the token was accepted; False if matching has already terminated
or the token was rejected (in which case the matcher is reset)
"""
assert self.matcher.accept_token(token), f"Failed to accept token {token}"
def is_terminated(self) -> bool:
"""
Check if the grammar matching process has terminated.
Returns:
bool: True if matching has terminated, False otherwise
"""
return self.matcher.is_terminated()
if self.is_terminated or self.matcher.is_terminated():
self.is_terminated = True
return False
if not self.matcher.accept_token(token):
self.matcher.reset()
return False
if self.matcher.is_terminated():
self.is_terminated = True
return True
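
Under the new contract above, accept_token returns a boolean instead of asserting, and is_terminated is an attribute rather than a method. A hedged sketch of how a caller loop might consume that contract (feed_tokens is a hypothetical helper, not part of the PR):

def feed_tokens(processor, tokens):
    # processor is assumed to expose the accept_token / is_terminated contract above
    for token in tokens:
        if not processor.accept_token(token):
            # already terminated, or the token was rejected and the matcher reset:
            # stop constraining this slot
            return None
        if processor.is_terminated:
            # grammar completed (e.g. eos accepted); no further masking is needed
            break
    return processor
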
def copy(self) -> "XGrammarProcessor":
"""
@@ -216,7 +183,18 @@ class XGrammarBackend(BackendBase):
try:
tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
llm_logger.info(f"xgrammar_backend.py tokenizer_info={tokenizer_info.dump_metadata()}")
# Read configuration values, falling back to defaults if not set
xgrammar_cfg = getattr(fd_config, "xgrammar_config", {})
max_threads = getattr(xgrammar_cfg, "max_threads", 8)
cache_enabled = getattr(xgrammar_cfg, "cache_enabled", True)
cache_limit_bytes = getattr(xgrammar_cfg, "cache_limit_bytes", 4 * 1024 * 1024)
self.grammar_compiler = GrammarCompiler(
tokenizer_info=tokenizer_info,
max_threads=max_threads,
cache_enabled=cache_enabled,
cache_limit_bytes=cache_limit_bytes,
)
except Exception as e:
raise Exception(f"Failed to load XGrammar tokenizer: {e}")
@@ -467,3 +445,49 @@ class XGrammarChecker(BaseChecker):
else:
# regex is not format
return request, None
def apply_token_mask(
logits: paddle.Tensor,
token_bitmask: torch.Tensor,
indices: Optional[List[int]] = None,
) -> paddle.Tensor:
"""
Apply the token mask to the logits, modifying probabilities of invalid tokens.
Args:
logits (paddle.Tensor): The logits tensor to modify
token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
indices (Optional[List[int]]): Optional list of batch indices to apply mask to
Returns:
paddle.Tensor: The modified logits tensor
"""
if current_platform.is_cuda():
dlpack = paddle.utils.dlpack.to_dlpack(logits)
t_logits = torch.from_dlpack(dlpack)
apply_token_bitmask_inplace(
logits=t_logits,
bitmask=token_bitmask.to(t_logits.device, non_blocking=True),
indices=indices,
)
dlpack2 = torch.utils.dlpack.to_dlpack(t_logits)
return paddle.utils.dlpack.from_dlpack(dlpack2)
else:
origin_place = logits.place
origin_dtype = logits.dtype
logits = torch.from_numpy(logits.numpy())
logits = logits.float() # cpu
apply_token_bitmask_inplace(
logits=logits,
bitmask=token_bitmask.to(logits.device, non_blocking=True),
indices=indices,
)
return paddle.to_tensor(
logits.numpy(),
dtype=origin_dtype,
place=origin_place,
)
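
The CUDA branch above uses DLPack so the torch view shares the paddle tensor's device memory, letting the bitmask be applied in place without the numpy round-trip that the CPU path still needs. A minimal sketch of that zero-copy handoff (shapes and values are arbitrary; assumes both paddle and torch are installed):

import paddle
import torch

x = paddle.rand([4, 8])  # paddle tensor; lands on GPU when CUDA is available

# paddle -> torch without copying the underlying buffer
t = torch.from_dlpack(paddle.utils.dlpack.to_dlpack(x))
t.mul_(0.0)  # in-place edit on the shared buffer

# torch -> paddle, again without a copy
y = paddle.utils.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))
assert float(y.sum()) == 0.0
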

View File

@@ -14,13 +14,14 @@
# limitations under the License.
"""
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional
import time
from concurrent.futures import Future
from typing import Any, List, Optional
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
@@ -66,140 +67,187 @@ class GuidedDecoding:
processor for guided decoding.
"""
def __init__(self):
self.async_step = None
def __init__(self, fd_config: FDConfig):
self.token_bitmask = None
self.logits_processor: Dict[int, Optional[Any]] = dict()
self.executor = ThreadPoolExecutor()
self.logits_lock = threading.Lock()
self.logits_processors: List[Any] = [None] * fd_config.scheduler_config.max_num_seqs
self.reasoning_parser = None
self._prefill_done_idxs: List[bool] = [False] * fd_config.scheduler_config.max_num_seqs
# for PD (prefill/decode disaggregation)
self._tokens_to_acc: List[None | List[int]] = [None] * fd_config.scheduler_config.max_num_seqs
def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
self.reasoning_parser = reasoning_parser
def add_logits_processor(
self,
ids: int,
idx: int,
future: Optional[Any] = None,
prefill_tokens: List[int] = [],
):
"""add logits processor to GuidedDecoding"""
with self.logits_lock:
if future is None:
if ids in self.logits_processor:
del self.logits_processor[ids]
return
"""add logits processor to SamplerProcessor"""
self._prefill_done_idxs[idx] = False
if isinstance(future, LogitsProcessorBase):
self.logits_processor[ids] = future
for token in prefill_tokens:
self.logits_processor[ids].accept_token(token)
elif future.done():
self.logits_processor[ids] = future.result()
for token in prefill_tokens:
self.logits_processor[ids].accept_token(token)
else:
self.logits_processor[ids] = [future, prefill_tokens]
def update_vocab_mask(self, skip_idx_list: List[int] = []):
"""update vocab mask. (cpu-heavy operation)"""
if len(self.logits_processor) == 0:
if future is None:
# normal request without guided_backend
self.logits_processors[idx] = None
return
with self.logits_lock:
for idx, processor in self.logits_processor.items():
if processor is None:
del self.logits_processor[idx]
continue
if len(prefill_tokens) != 0:
# first token from the prefill node
self._prefill_done_idxs[idx] = True
if not isinstance(processor, LogitsProcessorBase):
future, prefill_tokens = self.logits_processor[idx]
self.logits_processor[idx] = future.result()
for token in prefill_tokens:
self.logits_processor[idx].accept_token(token)
if future.done():
# compilation already finished (e.g. grammar cache hit)
self.logits_processors[idx] = future.result()
for token in prefill_tokens:
self._accept_token(idx, token)
else:
# still compiling asynchronously; keep the Future
self.logits_processors[idx] = future
self._tokens_to_acc[idx] = prefill_tokens
available_processors = None
for processor in self.logits_processor.values():
if processor.is_terminated():
continue
available_processors = processor
if available_processors is None:
return
def should_fill_bitmask(self, idx: int) -> bool:
"""
Determines whether to fill a bitmask for the logits processor at the given index.
# allocate token bitmask
self.token_bitmask = available_processors.allocate_token_bitmask()
Args:
idx (int): The index of the logits processor to check
with self.logits_lock:
# fill token bitmask
for idx, processor in self.logits_processor.items():
if processor.is_terminated() or idx in skip_idx_list:
continue
Returns:
bool: True if the bitmask should be filled for the request at idx
"""
if self.reasoning_parser is not None:
if self.logits_processors[idx].enable_reasoning:  # the <think> segment is guided as well
return True
if not self.logits_processors[idx].reasoning_ended:
return False
return True
def reset_processor(self, idx: int):
"""reset idx"""
self._prefill_done_idxs[idx] = False
self.logits_processors[idx] = None
def update_vocab_mask(self, prefill_done_idxs: List[int] = []):
"""update vocab mask. (cpu-heavy operation)"""
for idx in prefill_done_idxs:
if self.logits_processors[idx] is None:
continue
assert not self._prefill_done_idxs[idx]
self._prefill_done_idxs[idx] = True
if isinstance(self.logits_processors[idx], Future):
continue
for idx, processor in enumerate(self.logits_processors):
if processor is None or not self._prefill_done_idxs[idx]:
continue
# skip; the Future is joined in apply_token_mask
if isinstance(processor, Future):
continue
if processor.is_terminated:
self.reset_processor(idx)
continue
self.accept_tokens_from_prefill_node(idx)
if self.token_bitmask is None:
self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()
if self.should_fill_bitmask(idx):
processor.fill_token_bitmask(self.token_bitmask, idx)
def apply_token_mask(self, logits: paddle.Tensor, skip_idx_list: List[int] = []):
"""apply token mask to logits"""
if len(self.logits_processor) == 0 or self.token_bitmask is None:
return logits
def accept_tokens_from_prefill_node(self, idx: int):
"""accept prefill token, not future"""
if self._tokens_to_acc[idx] is not None:
# accept token from prefill node first
for token in self._tokens_to_acc[idx]:
self._accept_token(idx, token)
self._tokens_to_acc[idx] = None
# self.async_step.result()
available_processors = None
with self.logits_lock:
for processor in self.logits_processor.values():
if processor.is_terminated():
continue
available_processors = processor
if available_processors is None:
return logits
def apply_token_mask(self, logits: paddle.Tensor, prefill_done_idxs: List[int] = []):
"""apply token mask to logits"""
indices = []
for idx, processor in self.logits_processor.items():
if processor is None or idx in skip_idx_list:
for idx, processor in enumerate(self.logits_processors):
if processor is None or not self._prefill_done_idxs[idx]:
continue
if self.reasoning_parser is None or not processor.enable_reasoning or processor.reasoning_ended:
indices.append(idx)
return available_processors.apply_token_mask(logits, self.token_bitmask, indices=indices)
# compilation done; check whether this idx should be masked (fill_token_bitmask already ran in pre_process)
if not isinstance(processor, Future):
if self.should_fill_bitmask(idx):
indices.append(idx)
continue
# processor is still a Future: the async compilation may not be ready, so join and wait
ts = time.time()
wait = False
if not processor.done():
wait = True
self.logits_processors[idx] = processor.result()
if wait:
logger.debug(f"[{idx} join async compile xgrammar, time_cost:{time.time() - ts}]")
self.accept_tokens_from_prefill_node(idx)
# Possible optimization: Extract 'think' content validation from logits_processors,
# allowing join operations to complete immediately after 'think' terminates.
# Furthermore, the current idx could be skipped, with compilation overhead
# estimated at only a few milliseconds.
# check whether this idx needs fill_token_bitmask
if not self.should_fill_bitmask(idx):
continue
indices.append(idx)
if self.token_bitmask is None:
self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()
self.logits_processors[idx].fill_token_bitmask(self.token_bitmask, idx)
if len(indices) == 0:
return logits
from fastdeploy.model_executor.guided_decoding.xgrammar_backend import (
apply_token_mask,
)
return apply_token_mask(logits, self.token_bitmask, indices=indices)
def _accept_token(self, idx: int, token: int):
"""accept token"""
if idx not in self.logits_processor:
raise ValueError(f"Invalid index, idx: {idx}, logit_processors.keys: {self.logits_processor.keys()}")
if self.logits_processor[idx].is_terminated():
return
if self.reasoning_parser is not None:
if not self.logits_processors[idx].enable_reasoning:
if not self.logits_processors[idx].reasoning_ended:
reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
self.logits_processors[idx].reasoning_ended = reasoning_ended
return
if (
self.reasoning_parser is not None
and self.logits_processor[idx].enable_reasoning
and not self.logits_processor[idx].reasoning_ended
):
reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
self.logits_processor[idx].reasoning_ended = reasoning_ended
return
if not self.logits_processors[idx].accept_token(token) or self.logits_processors[idx].is_terminated:
self.reset_processor(idx)
self.logits_processor[idx].accept_token(token)
def update_output_tokens(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
def update_output_tokens(self, next_tokens: paddle.Tensor):
"""update output tokens"""
if len(self.logits_processor) == 0:
if len(self.logits_processors) == 0:
return
token_ids = next_tokens.numpy().tolist()
with self.logits_lock:
for idx in self.logits_processor.keys():
token = token_ids[idx][0]
if token < 0 or self.logits_processor[idx] is None or idx in skip_idx_list:
continue
for idx, processor in enumerate(self.logits_processors):
if not self._prefill_done_idxs[idx] or processor is None:
continue
if idx >= len(token_ids):
continue
token = token_ids[idx][0]
if token < 0:
self.reset_processor(idx)
continue
logger.debug(f"[{idx}]accept token{token}")
self._accept_token(idx, token)
self._accept_token(idx, token)
def pre_process(self, skip_idx_list: List[int] = []):
def pre_process(self, prefill_done_idxs: List[int] = []):
"""pre process before running"""
# create async operation for guided decoding
# TODO: support async
self.update_vocab_mask(skip_idx_list)
# self.async_step = self.executor.submit(self.update_vocab_mask)
self.update_vocab_mask(prefill_done_idxs)
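
Putting the pieces together, the per-slot lifecycle is: register the (possibly still compiling) processor, mark the slot once prefill finishes, fill the bitmask in pre_process, join any pending Future and apply the mask at sampling time, then feed the sampled token back via update_output_tokens. A compact, Mock-based walk-through of the happy path (mirroring the unit tests added in this PR; the config values are illustrative):

from unittest.mock import Mock

from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
from fastdeploy.model_executor.layers.sample.sampler import GuidedDecoding

fd_config = Mock()
fd_config.scheduler_config.max_num_seqs = 2
gd = GuidedDecoding(fd_config)

processor = Mock(spec=LogitsProcessorBase)
processor.is_terminated = False
processor.enable_reasoning = False
processor.reasoning_ended = True
processor.accept_token.return_value = True

ready = Mock()                      # a Future whose compilation already finished
ready.done.return_value = True
ready.result.return_value = processor

gd.add_logits_processor(0, ready)   # slot 0 gets the resolved processor
gd.pre_process([0])                 # prefill just finished: bitmask is filled for slot 0
# the sampler then calls apply_token_mask(logits, [...]) and update_output_tokens(tokens)
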
class Sampler(nn.Layer):
@@ -224,7 +272,7 @@ class Sampler(nn.Layer):
else:
raise NotImplementedError
self.guided_decoding = GuidedDecoding()
self.guided_decoding = GuidedDecoding(fd_config)
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
if (
@@ -240,17 +288,19 @@ class Sampler(nn.Layer):
"""set reasoning parser"""
self.guided_decoding.apply_reasoning_parser(reasoning_parser)
def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
def apply_logits_processor(
self, ids: int, future: Optional[Future[LogitsProcessorBase]] = None, prefill_tokens: List[int] = []
):
"""apply logits processor to sampler"""
self.guided_decoding.add_logits_processor(ids, future, prefill_tokens)
def pre_process(self, skip_idx_list: List[int] = []):
def pre_process(self, prefill_done_idxs: List[int] = []):
"""pre process before running"""
self.guided_decoding.pre_process(skip_idx_list)
self.guided_decoding.pre_process(prefill_done_idxs)
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
def post_process(self, next_tokens: paddle.Tensor):
"""post process after running"""
self.guided_decoding.update_output_tokens(next_tokens, skip_idx_list)
self.guided_decoding.update_output_tokens(next_tokens)
def compute_logprobs(
self,
@@ -343,10 +393,10 @@ class Sampler(nn.Layer):
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
skip_idx_list: List[int] = [],
p_done_idxs: List[int] = [],
) -> SamplerOutput:
""" """
logits = self.guided_decoding.apply_token_mask(logits, skip_idx_list)
logits = self.guided_decoding.apply_token_mask(logits, p_done_idxs)
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:

View File

@@ -706,10 +706,30 @@ class InferScheduler:
self.reqs_queue = deque()
self.writers = []
def check_redis_version(self):
# Get Redis version information
redis_info = self.client.info()
redis_version = redis_info.get("redis_version", "")
version_parts = [int(x) for x in redis_version.split(".")]
# Redis 6.2 and above support RPOP with a count parameter
assert (
version_parts[0] >= 6
), f"Redis major version too low: {version_parts[0]}. Please upgrade to Redis 6.2+ to support batch RPOP operations."
assert (
version_parts[1] >= 2 if version_parts[0] == 6 else True
), f"Redis version {redis_version} too low. Please upgrade to Redis 6.2+ to support batch RPOP operations."
logger.info(f"Redis version {redis_version} detected. Using native batch RPOP.")
def start(self, role, host, disaggregated):
"""
start backup threads
"""
# Check Redis version first
self.check_redis_version()
for i in range(self.writer_parallel):
writer = ResultWriter(self.client, i, self.writer_batch_size, self.ttl)
writer.start()

View File

@@ -31,9 +31,6 @@ from fastdeploy.model_executor.graph_optimization.utils import (
sot_warmup_guard,
)
from fastdeploy.model_executor.guided_decoding import get_guided_backend
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
LogitsProcessorBase,
)
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
@@ -182,7 +179,7 @@ class GCUModelRunner(ModelRunnerBase):
or request.guided_grammar is not None
):
logits_info, schemata_key = self._init_logits_processor(request)
request.logits_processor, request.logits_cached = logits_info
request.logits_processor = logits_info
request.schemata_key = schemata_key
# Is Decode Node
@@ -925,29 +922,36 @@ class GCUModelRunner(ModelRunnerBase):
logger.info(f"SOT warmup the model with the batch size:{batch_size}")
logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_running_requests: int):
"""
Get the index of the request that needs to be skipped during execution.
Args:
model_forward_batch: A list of requests to be executed by this runner.
Returns:
A list of indices corresponding to the requests that need to be skipped.
Return the indices of requests whose prefill has finished, for guided decoding.
Once prefill is done, the asynchronously compiled logits_processor must be joined.
"""
skip_idx_list = []
if not self.cache_config.enable_chunked_prefill or self.guided_backend is None:
return skip_idx_list
if self.guided_backend is None:
return []
for task in model_forward_batch:
if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info):
continue
skip_idx_list.append(task.idx)
prefill_done_idxs = []
for idx in range(0, num_running_requests):
if self.share_inputs["step_idx"][idx] == 0:
prefill_done_idxs.append(idx)
for task in self.restore_chunked_prefill_request.values():
if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info):
continue
skip_idx_list.append(task.idx)
if self.cache_config.enable_chunked_prefill:
if model_forward_batch is not None:
for task in model_forward_batch:
# new request with chunked prefill, not yet finished: store it
if task.chunk_idx < len(task.prefill_chunk_info):
if task.request_id not in self.restore_chunked_prefill_request:
self.restore_chunked_prefill_request[task.request_id] = task
return skip_idx_list
for id, task in list(self.restore_chunked_prefill_request.items()):
# prefill chunks still unfinished: remove from the done list
if task.chunk_idx < len(task.prefill_chunk_info) and task.idx in prefill_done_idxs:
prefill_done_idxs.remove(task.idx)
# all prefill chunks finished: add to the done list
if task.chunk_idx == len(task.prefill_chunk_info) and task.idx not in prefill_done_idxs:
prefill_done_idxs.append(task.idx)
return prefill_done_idxs
def execute_model(
self,
@@ -971,9 +975,9 @@ class GCUModelRunner(ModelRunnerBase):
return None
# 1. Prepare inputs of model and sampler.
skip_idx_list = self._get_skip_idx(model_forward_batch)
p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests)
self._prepare_inputs()
self.sampler.pre_process(skip_idx_list)
self.sampler.pre_process(p_done_idxs)
# 2. Padding inputs for cuda graph
@@ -1009,7 +1013,7 @@ class GCUModelRunner(ModelRunnerBase):
sampler_output = self.sampler(
logits,
self.sampling_metadata,
skip_idx_list,
p_done_idxs,
)
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
@@ -1081,28 +1085,9 @@ class GCUModelRunner(ModelRunnerBase):
self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
self._update_chunked_prefill(model_forward_batch)
self._add_cache(model_forward_batch)
self.seq_lens_this_time_buffer.copy_(self.share_inputs["seq_lens_this_time"], False)
return None
def _add_cache(self, model_forward_batch) -> None:
"""
Add cache for guided decoding.
"""
if self.guided_backend is None:
return
for request in model_forward_batch:
logits_cached = request.get("logits_cached", None)
if logits_cached is None or logits_cached:
continue
request.logits_cached = True
if isinstance(request.logits_processor, LogitsProcessorBase):
self.guided_backend.add_cache(request.schemata_key, request.logits_processor)
else:
self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result())
def _execute_empty_input(self) -> None:
"""
In certain scenarios, such as during EP,

View File

@@ -17,6 +17,7 @@
import os
import queue
import time
from concurrent.futures import Future
from threading import Thread
from typing import List, Optional, cast
@@ -287,7 +288,7 @@ class GPUModelRunner(ModelRunnerBase):
else:
self.proposer = None
def _init_logits_processor(self, request):
def _init_logits_processor(self, request) -> tuple[Future[LogitsProcessorBase], tuple[str, str]]:
"""
init logits processor for guided decoding
"""
@@ -307,7 +308,7 @@ class GPUModelRunner(ModelRunnerBase):
return (
self.guided_backend.get_logits_processor(
schemata_key=schemata_key,
enable_thinking=True,
enable_thinking=False,  # TODO: read from config
),
schemata_key,
)
@@ -696,6 +697,7 @@ class GPUModelRunner(ModelRunnerBase):
length = len(request.prompt_token_ids)
assert length > 0, "The prompt requested must not be empty."
logits_info = None
prefill_tokens = []
if (
request.guided_json is not None
@@ -704,7 +706,6 @@ class GPUModelRunner(ModelRunnerBase):
or request.guided_grammar is not None
):
logits_info, schemata_key = self._init_logits_processor(request)
request.logits_processor, request.logits_cached = logits_info
request.schemata_key = schemata_key
# Is Decode Node
@@ -874,7 +875,7 @@ class GPUModelRunner(ModelRunnerBase):
else:
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
self.sampler.apply_logits_processor(idx, request.get("logits_processor"), prefill_tokens)
self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens)
self.share_inputs["not_need_stop"][0] = True
@@ -2006,34 +2007,36 @@ class GPUModelRunner(ModelRunnerBase):
logger.info(f"SOT warmup the model with the batch size:{batch_size}")
logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_running_requests: int):
"""
Get the index of the request that needs to be skipped during execution.
Args:
model_forward_batch: A list of requests to be executed by this runner.
Returns:
A list of indices corresponding to the requests that need to be skipped.
Get the indices of requests whose prefill has finished, for guided decoding.
Once prefill is done, the asynchronously compiled logits_processor must be joined.
"""
if (
not self.cache_config.enable_chunked_prefill
or self.guided_backend is None
or model_forward_batch is None
or envs.ENABLE_V1_KVCACHE_SCHEDULER
):
if self.guided_backend is None:
return []
skip_idx_list = []
for task in model_forward_batch:
if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info):
continue
skip_idx_list.append(task.idx)
prefill_done_idxs = []
for idx in range(0, num_running_requests):
if self.share_inputs["step_idx"][idx] == 0:
prefill_done_idxs.append(idx)
for task in self.restore_chunked_prefill_request.values():
if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info):
continue
skip_idx_list.append(task.idx)
if self.cache_config.enable_chunked_prefill:
if model_forward_batch is not None:
for task in model_forward_batch:
# new request with chunked prefill, not yet finished: store it
if task.chunk_idx < len(task.prefill_chunk_info):
if task.request_id not in self.restore_chunked_prefill_request:
self.restore_chunked_prefill_request[task.request_id] = task
return skip_idx_list
for id, task in list(self.restore_chunked_prefill_request.items()):
# prefill chunks still unfinished: remove from the done list
if task.chunk_idx < len(task.prefill_chunk_info) and task.idx in prefill_done_idxs:
prefill_done_idxs.remove(task.idx)
# all prefill chunks finished: add to the done list
if task.chunk_idx == len(task.prefill_chunk_info) and task.idx not in prefill_done_idxs:
prefill_done_idxs.append(task.idx)
return prefill_done_idxs
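
A worked, standalone re-creation of the selection logic above using plain lists (the runner's real state lives in share_inputs and restore_chunked_prefill_request; the values here are illustrative): requests with step_idx == 0 are candidates, then the chunked-prefill bookkeeping removes slots still mid-chunk and adds back slots that just consumed their last chunk.

# step_idx == 0: the request has not generated its first token yet
step_idx = [0, 3, 0, 0]
# idx -> (chunk_idx, total_chunks) for requests doing chunked prefill
restore = {2: (1, 4), 3: (4, 4)}

p_done = [i for i, s in enumerate(step_idx) if s == 0]      # [0, 2, 3]
for idx, (chunk_idx, total) in restore.items():
    if chunk_idx < total and idx in p_done:
        p_done.remove(idx)     # chunks remain: not ready for guided decoding yet
    elif chunk_idx == total and idx not in p_done:
        p_done.append(idx)     # last chunk just finished: include it
assert p_done == [0, 3]
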
def execute_model(
self,
@@ -2050,9 +2053,10 @@ class GPUModelRunner(ModelRunnerBase):
num_running_requests: batch_size
"""
# 1. Prepare inputs of model and sampler.
skip_idx_list = self._get_skip_idx(model_forward_batch)
p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests)
self._prepare_inputs()
self.sampler.pre_process(skip_idx_list)
self.sampler.pre_process(p_done_idxs)
# 1.1 Update state of logits processor
for proc in self.sampling_metadata.logits_processors:
@@ -2157,7 +2161,7 @@ class GPUModelRunner(ModelRunnerBase):
sampler_output = self.sampler(
logits,
self.sampling_metadata,
skip_idx_list,
p_done_idxs,
)
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(
@@ -2244,7 +2248,7 @@ class GPUModelRunner(ModelRunnerBase):
line_break_id=self.model_config.line_break_id,
)
if self.guided_backend is not None and sampler_output is not None:
self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)
self.sampler.post_process(sampler_output.sampled_token_ids)
# 6. Speculative decode
if self.speculative_decoding:
@@ -2268,7 +2272,6 @@ class GPUModelRunner(ModelRunnerBase):
)
self._update_chunked_prefill(model_forward_batch)
self._add_cache(model_forward_batch)
elif self.speculative_decoding:
speculate_schedule_cache(
self.share_inputs["draft_tokens"],
@@ -2325,24 +2328,6 @@ class GPUModelRunner(ModelRunnerBase):
return pooler_output
def _add_cache(self, model_forward_batch) -> None:
"""
Add cache for guided decoding.
"""
if self.guided_backend is None or model_forward_batch is None:
return
for request in model_forward_batch:
logits_cached = request.get("logits_cached", None)
if logits_cached is None or logits_cached:
continue
request.logits_cached = True
if isinstance(request.logits_processor, LogitsProcessorBase):
self.guided_backend.add_cache(request.schemata_key, request.logits_processor)
else:
self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result())
def _execute_empty_input(self) -> None:
"""
In certain scenarios, such as during EP,

View File

@@ -30,9 +30,6 @@ from fastdeploy.engine.request import Request
# from fastdeploy.spec_decode import MTPProposer, NgramProposer
from fastdeploy.model_executor.forward_meta import HPUForwardMeta
from fastdeploy.model_executor.guided_decoding import get_guided_backend
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
LogitsProcessorBase,
)
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
@@ -442,7 +439,7 @@ class HPUModelRunner(ModelRunnerBase):
or request.guided_grammar is not None
):
logits_info, schemata_key = self._init_logits_processor(request)
request.logits_processor, request.logits_cached = logits_info
request.logits_processor = logits_info
request.schemata_key = schemata_key
# Is Decode Node
@@ -1179,30 +1176,6 @@ class HPUModelRunner(ModelRunnerBase):
time_after_capture = time.perf_counter()
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
def _get_skip_idx(self, model_forward_batch):
"""
Get the index of the request that needs to be skipped during execution.
Args:
model_forward_batch: A list of requests to be executed by this runner.
Returns:
A list of indices corresponding to the requests that need to be skipped.
"""
skip_idx_list = []
if not self.parallel_config.enable_chunked_prefill or self.guided_backend is None:
return skip_idx_list
for task in model_forward_batch:
if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info):
continue
skip_idx_list.append(task.idx)
for task in self.restore_chunked_prefill_request.values():
if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info):
continue
skip_idx_list.append(task.idx)
return skip_idx_list
def execute_model(
self,
model_forward_batch: Optional[List[Request]] = None,
@@ -1332,30 +1305,11 @@ class HPUModelRunner(ModelRunnerBase):
execution_time = (end_time - start_time) * 1000
hpu_model_runner_profile_logger.info(f"StepPaddle execution time(ms): {execution_time}, BT={real_bs}")
self._update_chunked_prefill(model_forward_batch)
self._add_cache(model_forward_batch)
if int(os.environ.get("HABANA_PROFILE", 0)) == 1:
self.prof.step()
return None
def _add_cache(self, model_forward_batch) -> None:
"""
Add cache for guided decoding.
"""
if self.guided_backend is None:
return
for request in model_forward_batch:
logits_cached = request.get("logits_cached", None)
if logits_cached is None or logits_cached:
continue
request.logits_cached = True
if isinstance(request.logits_processor, LogitsProcessorBase):
self.guided_backend.add_cache(request.schemata_key, request.logits_processor)
else:
self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result())
def _execute_empty_input(self) -> None:
"""
In certain scenarios, such as during EP,

View File

@@ -38,10 +38,7 @@ from fastdeploy.model_executor.graph_optimization.utils import (
profile_run_guard,
sot_warmup_guard,
)
from fastdeploy.model_executor.guided_decoding import (
LogitsProcessorBase,
get_guided_backend,
)
from fastdeploy.model_executor.guided_decoding import get_guided_backend
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
@@ -507,7 +504,7 @@ class MetaxModelRunner(ModelRunnerBase):
or request.guided_grammar is not None
):
logits_info, schemata_key = self._init_logits_processor(request)
request.logits_processor, request.logits_cached = logits_info
request.logits_processor = logits_info
request.schemata_key = schemata_key
# Is Decode Node
@@ -1761,34 +1758,36 @@ class MetaxModelRunner(ModelRunnerBase):
logger.info(f"SOT warmup the model with the batch size:{batch_size}")
logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_running_requests: int):
"""
Get the index of the request that needs to be skipped during execution.
Args:
model_forward_batch: A list of requests to be executed by this runner.
Returns:
A list of indices corresponding to the requests that need to be skipped.
Get the indices of requests whose prefill has finished, for guided decoding.
Once prefill is done, the asynchronously compiled logits_processor must be joined.
"""
if (
not self.cache_config.enable_chunked_prefill
or self.guided_backend is None
or model_forward_batch is None
or envs.ENABLE_V1_KVCACHE_SCHEDULER
):
if self.guided_backend is None:
return []
skip_idx_list = []
for task in model_forward_batch:
if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info):
continue
skip_idx_list.append(task.idx)
prefill_done_idxs = []
for idx in range(0, num_running_requests):
if self.share_inputs["step_idx"][idx] == 0:
prefill_done_idxs.append(idx)
for task in self.restore_chunked_prefill_request.values():
if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info):
continue
skip_idx_list.append(task.idx)
if self.cache_config.enable_chunked_prefill:
if model_forward_batch is not None:
for task in model_forward_batch:
# new request with chunked prefill, not yet finished: store it
if task.chunk_idx < len(task.prefill_chunk_info):
if task.request_id not in self.restore_chunked_prefill_request:
self.restore_chunked_prefill_request[task.request_id] = task
return skip_idx_list
for id, task in list(self.restore_chunked_prefill_request.items()):
# prefill chunks still unfinished: remove from the done list
if task.chunk_idx < len(task.prefill_chunk_info) and task.idx in prefill_done_idxs:
prefill_done_idxs.remove(task.idx)
# all prefill chunks finished: add to the done list
if task.chunk_idx == len(task.prefill_chunk_info) and task.idx not in prefill_done_idxs:
prefill_done_idxs.append(task.idx)
return prefill_done_idxs
def execute_model(
self,
@@ -1805,9 +1804,9 @@ class MetaxModelRunner(ModelRunnerBase):
num_running_requests: batch_size
"""
# 1. Prepare inputs of model and sampler.
skip_idx_list = self._get_skip_idx(model_forward_batch)
p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests)
self._prepare_inputs()
self.sampler.pre_process(skip_idx_list)
self.sampler.pre_process(p_done_idxs)
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
@@ -1864,7 +1863,7 @@ class MetaxModelRunner(ModelRunnerBase):
sampler_output = self.sampler(
logits,
self.sampling_metadata,
skip_idx_list,
p_done_idxs,
)
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(
@@ -1949,7 +1948,7 @@ class MetaxModelRunner(ModelRunnerBase):
line_break_id=self.model_config.line_break_id,
)
if self.guided_backend is not None and sampler_output is not None:
self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)
self.sampler.post_process(sampler_output.sampled_token_ids)
# 6. Speculative decode
if self.speculative_decoding:
@@ -1973,7 +1972,6 @@ class MetaxModelRunner(ModelRunnerBase):
)
self._update_chunked_prefill(model_forward_batch)
self._add_cache(model_forward_batch)
elif self.speculative_decoding:
speculate_schedule_cache(
self.share_inputs["draft_tokens"],
@@ -2000,24 +1998,6 @@ class MetaxModelRunner(ModelRunnerBase):
)
return None
def _add_cache(self, model_forward_batch) -> None:
"""
Add cache for guided decoding.
"""
if self.guided_backend is None or model_forward_batch is None:
return
for request in model_forward_batch:
logits_cached = request.get("logits_cached", None)
if logits_cached is None or logits_cached:
continue
request.logits_cached = True
if isinstance(request.logits_processor, LogitsProcessorBase):
self.guided_backend.add_cache(request.schemata_key, request.logits_processor)
else:
self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result())
def _execute_empty_input(self) -> None:
"""
In certain scenarios, such as during EP,

View File

@@ -0,0 +1,338 @@
"""
Unit tests for the GuidedDecoding class.
"""
import sys
import unittest
from concurrent.futures import Future
from unittest.mock import MagicMock, Mock, patch
import paddle
mock_torch = MagicMock()
mock_xgrammar = MagicMock()
sys.modules["torch"] = mock_torch
sys.modules["xgrammar"] = mock_xgrammar
from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
from fastdeploy.model_executor.layers.sample.sampler import GuidedDecoding
from fastdeploy.reasoning import ReasoningParser
class TestGuidedDecoding(unittest.TestCase):
"""Test cases for GuidedDecoding class."""
def setUp(self):
"""Setup for each test case."""
# Create a basic FDConfig mock
self.fd_config = Mock()
self.fd_config.scheduler_config.max_num_seqs = 5
# Create the GuidedDecoding object
self.guided_decoding = GuidedDecoding(self.fd_config)
# Create a mock LogitsProcessorBase
self.mock_processor = Mock(spec=LogitsProcessorBase)
self.mock_processor.is_terminated = False
self.mock_processor.reasoning_ended = True
self.mock_processor.enable_reasoning = False
# Mock allocate_token_bitmask to return a fake bitmask
self.mock_processor.allocate_token_bitmask.return_value = paddle.zeros([5, 10], dtype="int32")
# Mock the fill_token_bitmask method
self.mock_processor.fill_token_bitmask.return_value = None
# Mock accept_token to return True
self.mock_processor.accept_token.return_value = True
def test_init(self):
"""Test initialization."""
self.assertIsNone(self.guided_decoding.token_bitmask)
self.assertEqual(len(self.guided_decoding.logits_processors), 5)
self.assertIsNone(self.guided_decoding.reasoning_parser)
self.assertEqual(len(self.guided_decoding._prefill_done_idxs), 5)
self.assertEqual(len(self.guided_decoding._tokens_to_acc), 5)
def test_apply_reasoning_parser(self):
"""Test apply_reasoning_parser method."""
mock_parser = Mock(spec=ReasoningParser)
self.guided_decoding.apply_reasoning_parser(mock_parser)
self.assertEqual(self.guided_decoding.reasoning_parser, mock_parser)
def test_add_logits_processor_no_future(self):
"""Test add_logits_processor method without future."""
self.guided_decoding.add_logits_processor(0, None, [])
self.assertFalse(self.guided_decoding._prefill_done_idxs[0])
self.assertIsNone(self.guided_decoding.logits_processors[0])
def test_add_logits_processor_with_prefill_tokens(self):
"""Test add_logits_processor method with prefill tokens."""
# Create a mock Future object
mock_future = Mock()
mock_future.done.return_value = True
mock_future.result.return_value = self.mock_processor
prefill_tokens = [1, 2, 3]
self.guided_decoding.add_logits_processor(0, mock_future, prefill_tokens)
self.assertTrue(self.guided_decoding._prefill_done_idxs[0])
self.assertEqual(self.guided_decoding.logits_processors[0], self.mock_processor)
self.mock_processor.accept_token.assert_any_call(1)
self.mock_processor.accept_token.assert_any_call(2)
self.mock_processor.accept_token.assert_any_call(3)
def test_add_logits_processor_with_async_future(self):
"""Test add_logits_processor method with async future."""
# Create a mock Future object
mock_future = Mock()
mock_future.done.return_value = False
prefill_tokens = [1, 2, 3]
self.guided_decoding.add_logits_processor(0, mock_future, prefill_tokens)
self.assertTrue(self.guided_decoding._prefill_done_idxs[0])
self.assertEqual(self.guided_decoding.logits_processors[0], mock_future)
self.assertEqual(self.guided_decoding._tokens_to_acc[0], prefill_tokens)
def test_should_fill_bitmask_no_reasoning_parser(self):
"""Test should_fill_bitmask method with no reasoning parser."""
self.guided_decoding.logits_processors[0] = self.mock_processor
self.assertTrue(self.guided_decoding.should_fill_bitmask(0))
def test_should_fill_bitmask_with_reasoning_parser(self):
"""Test should_fill_bitmask method with reasoning parser."""
mock_parser = Mock(spec=ReasoningParser)
self.guided_decoding.reasoning_parser = mock_parser
# Case: enable_reasoning=True
self.mock_processor.enable_reasoning = True
self.guided_decoding.logits_processors[0] = self.mock_processor
self.assertTrue(self.guided_decoding.should_fill_bitmask(0))
# Case: enable_reasoning=False, reasoning_ended=False
self.mock_processor.enable_reasoning = False
self.mock_processor.reasoning_ended = False
self.assertFalse(self.guided_decoding.should_fill_bitmask(0))
# Case: enable_reasoning=False, reasoning_ended=True
self.mock_processor.reasoning_ended = True
self.assertTrue(self.guided_decoding.should_fill_bitmask(0))
def test_reset_processor(self):
"""Test reset_processor method."""
self.guided_decoding.logits_processors[0] = self.mock_processor
self.guided_decoding._prefill_done_idxs[0] = True
self.guided_decoding.reset_processor(0)
self.assertFalse(self.guided_decoding._prefill_done_idxs[0])
self.assertIsNone(self.guided_decoding.logits_processors[0])
def test_update_vocab_mask_with_new_prefill_done(self):
"""Test update_vocab_mask method with new prefill_done_idxs."""
# Set the processor at index 0
self.guided_decoding.logits_processors[0] = self.mock_processor
self.guided_decoding._prefill_done_idxs[0] = False
# Call update_vocab_mask and mark index 0 as prefill-done
self.guided_decoding.update_vocab_mask([0])
# Verify that _prefill_done_idxs[0] was updated
self.assertTrue(self.guided_decoding._prefill_done_idxs[0])
# Verify that fill_token_bitmask was called
self.mock_processor.fill_token_bitmask.assert_called_once()
def test_update_vocab_mask_with_future_processor(self):
"""Test update_vocab_mask method with future processor."""
# Create a mock Future object
mock_future = Mock()
# Set the processor at index 0 to a Future
self.guided_decoding.logits_processors[0] = mock_future
self.guided_decoding._prefill_done_idxs[0] = True
# Call update_vocab_mask
self.guided_decoding.update_vocab_mask([])
# Verify that fill_token_bitmask was not called because the processor is a Future
self.mock_processor.fill_token_bitmask.assert_not_called()
def test_accept_tokens_from_prefill_node(self):
"""Test accept_tokens_from_prefill_node method."""
# Set the processor at index 0 and the tokens waiting to be accepted
self.guided_decoding.logits_processors[0] = self.mock_processor
self.guided_decoding._tokens_to_acc[0] = [1, 2, 3]
# Call accept_tokens_from_prefill_node
self.guided_decoding.accept_tokens_from_prefill_node(0)
# Verify that accept_token was called 3 times
self.assertEqual(self.mock_processor.accept_token.call_count, 3)
self.mock_processor.accept_token.assert_any_call(1)
self.mock_processor.accept_token.assert_any_call(2)
self.mock_processor.accept_token.assert_any_call(3)
# Verify that _tokens_to_acc[0] was reset
self.assertIsNone(self.guided_decoding._tokens_to_acc[0])
@patch("fastdeploy.model_executor.guided_decoding.xgrammar_backend.apply_token_mask")
def test_apply_token_mask(self, mock_apply_token_mask):
"""Test apply_token_mask method."""
# Create test data
logits = paddle.zeros([5, 10], dtype="float32")
mock_apply_token_mask.return_value = paddle.ones([5, 10], dtype="float32")
# Set the processor at index 0
self.guided_decoding.logits_processors[0] = self.mock_processor
self.guided_decoding._prefill_done_idxs[0] = True
# Call apply_token_mask
result = self.guided_decoding.apply_token_mask(logits, [])
# Verify that fill_token_bitmask was not called (the processor is not a Future)
self.mock_processor.fill_token_bitmask.assert_not_called()
# Verify that apply_token_mask was called
mock_apply_token_mask.assert_called_once()
# Verify the return value
self.assertTrue((result == paddle.ones([5, 10], dtype="float32")).all())
def test_apply_token_mask_with_future_processor(self):
"""Test apply_token_mask method with future processor."""
# Create test data
logits = paddle.zeros([5, 10], dtype="float32")
# Create a mock Future object
mock_future = Mock(spec=Future)
mock_future.done.return_value = True
mock_future.result.return_value = self.mock_processor
# Set the processor at index 0 to a Future
self.guided_decoding.logits_processors[0] = mock_future
self.guided_decoding._prefill_done_idxs[0] = True
self.assertTrue(self.guided_decoding._prefill_done_idxs[0])
self.assertIsNotNone(self.guided_decoding.logits_processors[0])
self.assertTrue(isinstance(self.guided_decoding.logits_processors[0], Future))
self.guided_decoding._tokens_to_acc[0] = [1, 2, 3]
# Patch apply_token_mask
with patch(
"fastdeploy.model_executor.guided_decoding.xgrammar_backend.apply_token_mask"
) as mock_apply_token_mask:
mock_apply_token_mask.return_value = paddle.ones([5, 10], dtype="float32")
# Call apply_token_mask
self.guided_decoding.apply_token_mask(logits, [])
# Verify that Future.result was called
mock_future.result.assert_called_once()
# Verify that accept_token was called 3 times
self.assertEqual(self.mock_processor.accept_token.call_count, 3)
# Verify that _tokens_to_acc[0] was reset
self.assertIsNone(self.guided_decoding._tokens_to_acc[0])
def test_accept_token(self):
"""Test _accept_token method."""
# Set the processor at index 0
self.guided_decoding.logits_processors[0] = self.mock_processor
# Call _accept_token
self.guided_decoding._accept_token(0, 1)
# Verify that accept_token was called
self.mock_processor.accept_token.assert_called_once_with(1)
def test_accept_token_with_reasoning_parser(self):
"""Test _accept_token method with reasoning parser."""
# Create a mock ReasoningParser
mock_parser = Mock(spec=ReasoningParser)
mock_parser.is_reasoning_end.return_value = True
self.guided_decoding.reasoning_parser = mock_parser
# Set the processor at index 0
self.mock_processor.enable_reasoning = False
self.mock_processor.reasoning_ended = False
self.guided_decoding.logits_processors[0] = self.mock_processor
# Call _accept_token
self.guided_decoding._accept_token(0, 1)
# Verify that is_reasoning_end was called
mock_parser.is_reasoning_end.assert_called_once_with([1])
# Verify that reasoning_ended was updated
self.assertTrue(self.mock_processor.reasoning_ended)
# Verify that accept_token was not called because reasoning_ended was just set to True
self.mock_processor.accept_token.assert_not_called()
def test_accept_token_processor_terminated(self):
"""Test _accept_token method when processor is terminated."""
# Set the processor at index 0 and make accept_token return False
self.mock_processor.accept_token.return_value = False
self.guided_decoding.logits_processors[0] = self.mock_processor
# Call _accept_token
self.guided_decoding._accept_token(0, 1)
# Verify that the processor was reset
self.assertIsNone(self.guided_decoding.logits_processors[0])
def test_update_output_tokens(self):
"""Test update_output_tokens method."""
# Create test data
next_tokens = paddle.to_tensor([[1], [2], [3], [4], [5]])
# Set the processors at indices 0 and 1
self.guided_decoding.logits_processors[0] = self.mock_processor
self.guided_decoding.logits_processors[1] = self.mock_processor
self.guided_decoding._prefill_done_idxs[0] = True
self.guided_decoding._prefill_done_idxs[1] = True
# Call update_output_tokens
self.guided_decoding.update_output_tokens(next_tokens)
# Verify that accept_token was called twice
self.assertEqual(self.mock_processor.accept_token.call_count, 2)
self.mock_processor.accept_token.assert_any_call(1)
self.mock_processor.accept_token.assert_any_call(2)
def test_update_output_tokens_with_negative_token(self):
"""Test update_output_tokens method with negative token."""
# Create test data containing a negative token
next_tokens = paddle.to_tensor([[-1], [2]])
# Set the processors at indices 0 and 1
self.guided_decoding.logits_processors[0] = self.mock_processor
self.guided_decoding.logits_processors[1] = self.mock_processor
self.guided_decoding._prefill_done_idxs[0] = True
self.guided_decoding._prefill_done_idxs[1] = True
# Call update_output_tokens
self.guided_decoding.update_output_tokens(next_tokens)
# Verify that the processor at index 0 was reset
self.assertIsNone(self.guided_decoding.logits_processors[0])
# Verify that accept_token was called for the processor at index 1
self.mock_processor.accept_token.assert_called_once_with(2)
def test_pre_process(self):
"""Test pre_process method."""
# Mock the update_vocab_mask method
with patch.object(self.guided_decoding, "update_vocab_mask") as mock_update_vocab_mask:
# Call pre_process
self.guided_decoding.pre_process([0, 1])
# Verify that update_vocab_mask was called
mock_update_vocab_mask.assert_called_once_with([0, 1])
if __name__ == "__main__":
unittest.main()

View File

@@ -14,11 +14,23 @@
# limitations under the License.
"""
import json
import os
import paddle
import paddle.nn.functional as F
from fastdeploy.config import (
CacheConfig,
FDConfig,
GraphOptimizationConfig,
LoadConfig,
ModelConfig,
ParallelConfig,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.scheduler import SchedulerConfig
def _create_fake_logits(batch_size: int, vocab_size: int) -> paddle.Tensor:
@@ -67,13 +79,60 @@ def _create_default_sampling_metadata(
return fake_sampling_metadata
def build_config_json() -> str:
config_dict = {
"architectures": ["Qwen3MoeForCausalLM"],
"hidden_size": 7168,
"moe_intermediate_size": 1,
"moe_num_experts": 1,
"moe_k": 1,
"hidden_act": "silu",
"num_attention_heads": 64,
"dtype": "bfloat16",
}
tmp_dir = f"./tmpefef{paddle.distributed.get_rank()}"
os.makedirs(tmp_dir, exist_ok=True)
with open(f"./{tmp_dir}/config.json", "w") as f:
json.dump(config_dict, f)
model_name_or_path = os.path.join(os.getcwd(), tmp_dir)
print("model_name_or_path", model_name_or_path)
return model_name_or_path
def get_fd_config(batch_size: int):
fd_config = FDConfig(
model_config=ModelConfig(
{
"model": build_config_json(),
"max_model_len": 2048,
}
),
parallel_config=ParallelConfig(
{
"tensor_parallel_size": 1,
"expert_parallel_size": 1,
"expert_parallel_rank": 0,
"data_parallel_size": 1,
}
),
# quant_config=BlockWiseFP8Config(weight_block_size=[128, 128]),
scheduler_config=SchedulerConfig({"max_num_seqs": batch_size}),
cache_config=CacheConfig({}),
graph_opt_config=GraphOptimizationConfig({}),
load_config=LoadConfig({}),
ips="0.0.0.0",
)
return fd_config
def test_sampler():
batch_size = 32
vocab_size = 1024
min_seq_len = 1
max_seq_len = 1024
sampler = Sampler()
sampler = Sampler(get_fd_config(batch_size))
logits = _create_fake_logits(batch_size, vocab_size)
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
next_tokens = sampler(logits, sampling_metadata)
@@ -144,7 +203,10 @@ def test_sampler_logprobs():
logits = _create_fake_logits(batch_size, vocab_size)
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
for logprobs_mode in logprobs_mode_list:
sampler = Sampler(logprobs_mode=logprobs_mode)
fd_config = get_fd_config(batch_size)
fd_config.model_config.logprobs_mode = logprobs_mode
sampler = Sampler(logprobs_mode=logprobs_mode, fd_config=fd_config)
assert sampler.logprobs_mode == logprobs_mode
sampler_output = sampler(logits.clone(), sampling_metadata)
baseline_logprobs = get_baseline_logprobs(
logits.clone(), sampling_metadata, logprobs_mode=logprobs_mode, token_ids=sampler_output.sampled_token_ids

View File

@@ -12,15 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import time
import unittest
import numpy as np
import paddle
from fastdeploy.config import (
CacheConfig,
FDConfig,
GraphOptimizationConfig,
LoadConfig,
ModelConfig,
ParallelConfig,
)
from fastdeploy.engine.request import Request
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.worker.gpu_model_runner import GPUModelRunner
@@ -82,6 +93,53 @@ class FakeModel:
return paddle.matmul(x.astype("float32"), self.weight)
def build_config_json() -> str:
config_dict = {
"architectures": ["Qwen3MoeForCausalLM"],
"hidden_size": 7168,
"moe_intermediate_size": 1,
"moe_num_experts": 1,
"moe_k": 1,
"hidden_act": "silu",
"num_attention_heads": 64,
"dtype": "bfloat16",
}
tmp_dir = f"./tmpefef{paddle.distributed.get_rank()}"
os.makedirs(tmp_dir, exist_ok=True)
with open(f"./{tmp_dir}/config.json", "w") as f:
json.dump(config_dict, f)
model_name_or_path = os.path.join(os.getcwd(), tmp_dir)
print("model_name_or_path", model_name_or_path)
return model_name_or_path
def get_fd_config(batch_size: int):
fd_config = FDConfig(
model_config=ModelConfig(
{
"model": build_config_json(),
"max_model_len": 2048,
}
),
parallel_config=ParallelConfig(
{
"tensor_parallel_size": 1,
"expert_parallel_size": 1,
"expert_parallel_rank": 0,
"data_parallel_size": 1,
}
),
# quant_config=BlockWiseFP8Config(weight_block_size=[128, 128]),
scheduler_config=SchedulerConfig({"max_num_seqs": batch_size}),
cache_config=CacheConfig({}),
graph_opt_config=GraphOptimizationConfig({}),
load_config=LoadConfig({}),
ips="0.0.0.0",
)
return fd_config
class TestGPUPromptLogprobs(unittest.TestCase):
def setup_model_runner(self):
"""Helper method to setup GPUModelRunner with different configurations"""
@@ -96,7 +154,7 @@ class TestGPUPromptLogprobs(unittest.TestCase):
model_runner.ori_vocab_size = cfg.model_config.ori_vocab_size
model_runner.share_inputs = {}
model_runner.share_inputs["cu_seqlens_q"] = paddle.to_tensor([0, 1, 2, 3], dtype="int32")
model_runner.sampler = Sampler()
model_runner.sampler = Sampler(get_fd_config(batch_size=1))
model_runner.model = FakeModel(cfg.model_config.vocab_size, cfg.model_config.hidden_size)