Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[Optimization] xgrammar async compile, multi thread, speed up (#4835)
* xgrammar async compile, multi thread, speed up
* fix test_sampler.py & pre-commit err
* add redis version check && fix request.llm_engine_recv_req_timestamp
* xgrammar prefill & decode & v0
* fix test_gpu_prompt_logprobs.py
* add test_guided_decoding.py
* Update fastdeploy/scheduler/splitwise_scheduler.py
* Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
* Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
* Apply suggestions from code review
* fix torch xgrammar unittest env

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
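The pattern behind the speed-up, as a minimal hedged sketch (the names below are stand-ins for illustration, not the exact FastDeploy API): grammar compilation is submitted to a thread pool and handed back as a concurrent.futures.Future, so request admission and prefill continue while xgrammar compiles, and the sampler only blocks on result() when it actually needs the token mask.

# Minimal sketch of the async-compile pattern used by this change
# (compile_grammar is a placeholder, not the real xgrammar call).
from concurrent.futures import Future, ThreadPoolExecutor

def compile_grammar(schema: str) -> str:
    return f"compiled<{schema}>"  # stand-in for the expensive compilation

executor = ThreadPoolExecutor(max_workers=4)

def submit_grammar(schema: str) -> Future:
    # Returns immediately; compilation runs on a worker thread.
    return executor.submit(compile_grammar, schema)

def resolve(processor_or_future):
    # Join lazily: block only if compilation has not finished yet.
    if isinstance(processor_or_future, Future):
        return processor_or_future.result()
    return processor_or_future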
@@ -14,9 +14,10 @@
# limitations under the License.
"""

import multiprocessing
import os
import traceback
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor

from fastdeploy.config import ErnieArchitectures, FDConfig
from fastdeploy.engine.request import Request
@@ -135,9 +136,9 @@ class BackendBase:
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
self.cache = {}
|
||||
self.fd_config = fd_config
|
||||
self.executor = ThreadPoolExecutor()
|
||||
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.max_cache_size = 2048
|
||||
self.reasoning_parser = None
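Quick arithmetic check of the max_workers heuristic above (illustrative): with multiprocessing.cpu_count() returning 8, max(1, (8 + 1) // 2) gives 4 compile threads, and with a single CPU the outer max still keeps one worker alive.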
|
||||
|
||||
@@ -263,7 +264,7 @@ class BackendBase:
|
||||
self,
|
||||
schemata_key: tuple[str, str],
|
||||
enable_thinking: bool = False,
|
||||
) -> tuple[LogitsProcessorBase, bool]:
|
||||
) -> Future[LogitsProcessorBase]:
|
||||
"""
|
||||
get logits processor by key from cache or create new one.
|
||||
|
||||
@@ -275,13 +276,8 @@ class BackendBase:
|
||||
- LogitsProcessorBase: The logits processor instance
|
||||
- bool: True if processor was from cache, False if newly created
|
||||
"""
|
||||
value = self.cache.get(schemata_key, None)
|
||||
if value:
|
||||
value_copy = value.copy()
|
||||
value_copy.enable_reasoning = enable_thinking
|
||||
return value_copy, True
|
||||
value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking)
|
||||
return value, False
|
||||
return value
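A hedged sketch of how a caller is expected to consume the Future returned above (it mirrors the add_logits_processor logic later in this diff, with simplified names): if compilation has already finished, the processor is installed and the prefill tokens replayed immediately; otherwise the Future is stored and joined later.

# Sketch only: consuming a Future that resolves to a logits processor.
from concurrent.futures import Future

def install_processor(slots: list, idx: int, future: Future, prefill_tokens: list):
    if future.done():
        processor = future.result()   # compilation (or cache hit) already finished
        for token in prefill_tokens:
            processor.accept_token(token)
        slots[idx] = processor
    else:
        slots[idx] = future           # joined later, e.g. when the mask is applied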
|
||||
|
||||
def _get_tokenizer_hf(self):
|
||||
"""
|
||||
@@ -303,7 +299,7 @@ class BackendBase:
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.fd_config.model_config.model,
|
||||
use_fast=False,
|
||||
use_fast=True,
|
||||
)
|
||||
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
@@ -334,21 +330,6 @@ class BackendBase:
|
||||
except Exception as e:
|
||||
raise Exception(f"Fail to initialize hf tokenizer: {e}, {str(traceback.format_exc())}")
|
||||
|
||||
def add_cache(self, schemata_key: tuple[str, str], processor: LogitsProcessorBase) -> None:
|
||||
"""
|
||||
add logits processor to cache.
|
||||
|
||||
Args:
|
||||
schemata_key (tuple[str, str]): Tuple containing processor type and schema string
|
||||
processor (LogitsProcessorBase): Logits processor instance to cache
|
||||
|
||||
Returns:
|
||||
None: No return value
|
||||
"""
|
||||
if len(self.cache) >= self.max_cache_size:
|
||||
return
|
||||
self.cache[schemata_key] = processor.copy()
|
||||
|
||||
|
||||
class BaseChecker:
|
||||
"""
|
||||
|
||||
@@ -29,6 +29,7 @@ from fastdeploy.model_executor.guided_decoding import (
|
||||
BaseChecker,
|
||||
LogitsProcessorBase,
|
||||
)
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
try:
|
||||
@@ -86,6 +87,8 @@ class XGrammarProcessor(LogitsProcessorBase):
|
||||
terminate_without_stop_token=terminate_without_stop_token,
|
||||
override_stop_tokens=override_stop_tokens,
|
||||
)
|
||||
# when matcher accept eos_token_id, is_terminated = True
|
||||
self.is_terminated: bool = False
|
||||
|
||||
def allocate_token_bitmask(self) -> torch.Tensor:
|
||||
"""
|
||||
@@ -109,40 +112,6 @@ class XGrammarProcessor(LogitsProcessorBase):
|
||||
"""
|
||||
self.matcher.fill_next_token_bitmask(token_bitmask, idx)
|
||||
|
||||
def apply_token_mask(
|
||||
self,
|
||||
logits: paddle.Tensor,
|
||||
token_bitmask: torch.Tensor,
|
||||
indices: Optional[List[int]] = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the token mask to the logits, modifying probabilities of invalid tokens.
|
||||
|
||||
Args:
|
||||
logits (paddle.Tensor): The logits tensor to modify
|
||||
token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
|
||||
indices (Optional[List[int]]): Optional list of batch indices to apply mask to
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: The modified logits tensor
|
||||
"""
|
||||
origin_place = logits.place
|
||||
origin_dtype = logits.dtype
|
||||
logits = torch.from_numpy(logits.numpy())
|
||||
|
||||
logits = logits.float() # cpu
|
||||
apply_token_bitmask_inplace(
|
||||
logits=logits,
|
||||
bitmask=token_bitmask.to(logits.device, non_blocking=True),
|
||||
indices=indices,
|
||||
)
|
||||
|
||||
return paddle.to_tensor(
|
||||
logits.numpy(),
|
||||
dtype=origin_dtype,
|
||||
place=origin_place,
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset the grammar matcher state to initial conditions.
|
||||
@@ -155,23 +124,21 @@ class XGrammarProcessor(LogitsProcessorBase):
|
||||
def accept_token(self, token: int) -> None:
|
||||
"""
|
||||
Validate and accept a generated token against the grammar constraints.
|
||||
when accept eos_token, is_terminated = True
|
||||
|
||||
Args:
|
||||
token (int): The token ID to validate
|
||||
|
||||
Raises:
|
||||
AssertionError: If token is not allowed by the grammar
|
||||
"""
|
||||
assert self.matcher.accept_token(token), f"Failed to accept token {token}"
|
||||
|
||||
def is_terminated(self) -> bool:
|
||||
"""
|
||||
Check if the grammar matching process has terminated.
|
||||
|
||||
Returns:
|
||||
bool: True if matching has terminated, False otherwise
|
||||
"""
|
||||
return self.matcher.is_terminated()
|
||||
if self.is_terminated or self.matcher.is_terminated():
|
||||
self.is_terminated = True
|
||||
return False
|
||||
if not self.matcher.accept_token(token):
|
||||
self.matcher.reset()
|
||||
return False
|
||||
if self.matcher.is_terminated():
|
||||
self.is_terminated = True
|
||||
return True
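With the new return-value semantics above, accept_token no longer asserts on an invalid token: it resets the matcher and returns False, and it flips is_terminated once the matcher reaches a terminal state. A minimal caller sketch (it mirrors the sampler's _accept_token further down in this diff; illustrative only):

# Sketch: reacting to the new accept_token() contract.
def step(processor, token: int, reset_slot):
    ok = processor.accept_token(token)
    if not ok or processor.is_terminated:
        reset_slot()  # drop the processor; later tokens are no longer constrained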
|
||||
|
||||
def copy(self) -> "XGrammarProcessor":
|
||||
"""
|
||||
@@ -216,7 +183,18 @@ class XGrammarBackend(BackendBase):
|
||||
|
||||
try:
|
||||
tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
|
||||
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
|
||||
llm_logger.info(f"xgrammar_backend.py tokenizer_info={tokenizer_info.dump_metadata()}")
|
||||
# Read configuration values, fallback to defaults if not set
|
||||
xgrammar_cfg = getattr(fd_config, "xgrammar_config", {})
|
||||
max_threads = getattr(xgrammar_cfg, "max_threads", 8)
|
||||
cache_enabled = getattr(xgrammar_cfg, "cache_enabled", True)
|
||||
cache_limit_bytes = getattr(xgrammar_cfg, "cache_limit_bytes", 4 * 1024 * 1024)
|
||||
self.grammar_compiler = GrammarCompiler(
|
||||
tokenizer_info=tokenizer_info,
|
||||
max_threads=max_threads,
|
||||
cache_enabled=cache_enabled,
|
||||
cache_limit_bytes=cache_limit_bytes,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to load XGrammar tokenizer: {e}")
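The getattr chain above falls back to defaults (8 threads, cache enabled, a 4 MiB cache) when FDConfig carries no xgrammar_config. A hedged usage sketch, assuming xgrammar_config can be any attribute-style object (the attribute names come from this diff; the SimpleNamespace wrapper is only an illustration):

# Sketch: building an override for the compiler settings read above.
from types import SimpleNamespace

def make_xgrammar_config():
    # Example values; the in-code defaults are 8 / True / 4 MiB.
    return SimpleNamespace(
        max_threads=16,
        cache_enabled=True,
        cache_limit_bytes=8 * 1024 * 1024,
    )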
|
||||
|
||||
@@ -467,3 +445,49 @@ class XGrammarChecker(BaseChecker):
|
||||
else:
|
||||
# regex is not format
|
||||
return request, None
|
||||
|
||||
|
||||
def apply_token_mask(
|
||||
logits: paddle.Tensor,
|
||||
token_bitmask: torch.Tensor,
|
||||
indices: Optional[List[int]] = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the token mask to the logits, modifying probabilities of invalid tokens.
|
||||
|
||||
Args:
|
||||
logits (paddle.Tensor): The logits tensor to modify
|
||||
token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
|
||||
indices (Optional[List[int]]): Optional list of batch indices to apply mask to
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: The modified logits tensor
|
||||
"""
|
||||
|
||||
if current_platform.is_cuda():
|
||||
dlpack = paddle.utils.dlpack.to_dlpack(logits)
|
||||
t_logits = torch.from_dlpack(dlpack)
|
||||
apply_token_bitmask_inplace(
|
||||
logits=t_logits,
|
||||
bitmask=token_bitmask.to(t_logits.device, non_blocking=True),
|
||||
indices=indices,
|
||||
)
|
||||
dlpack2 = torch.utils.dlpack.to_dlpack(t_logits)
|
||||
return paddle.utils.dlpack.from_dlpack(dlpack2)
|
||||
else:
|
||||
origin_place = logits.place
|
||||
origin_dtype = logits.dtype
|
||||
logits = torch.from_numpy(logits.numpy())
|
||||
|
||||
logits = logits.float() # cpu
|
||||
apply_token_bitmask_inplace(
|
||||
logits=logits,
|
||||
bitmask=token_bitmask.to(logits.device, non_blocking=True),
|
||||
indices=indices,
|
||||
)
|
||||
|
||||
return paddle.to_tensor(
|
||||
logits.numpy(),
|
||||
dtype=origin_dtype,
|
||||
place=origin_place,
|
||||
)
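The CUDA branch above avoids a host round-trip: DLPack hands the same device buffer to torch, apply_token_bitmask_inplace writes into it, and the buffer is handed back to paddle, while the CPU branch copies through numpy instead. A small hedged sketch of the zero-copy exchange, using only the DLPack helpers already used in this function:

# Sketch: paddle <-> torch exchange via DLPack (shared storage, no copy).
import paddle
import torch

x = paddle.ones([2, 3], dtype="float32")
t = torch.from_dlpack(paddle.utils.dlpack.to_dlpack(x))   # no copy
t.mul_(0.0)                                                # writes go to the shared buffer
y = paddle.utils.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))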
|
||||
|
||||
@@ -14,13 +14,14 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Dict, List, Optional
|
||||
import time
|
||||
from concurrent.futures import Future
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from paddle import nn
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
|
||||
@@ -66,140 +67,187 @@ class GuidedDecoding:
|
||||
processor for guided decoding.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.async_step = None
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
self.token_bitmask = None
|
||||
self.logits_processor: Dict[int, Optional[Any]] = dict()
|
||||
self.executor = ThreadPoolExecutor()
|
||||
self.logits_lock = threading.Lock()
|
||||
self.logits_processors: List[Any] = [None] * fd_config.scheduler_config.max_num_seqs
|
||||
self.reasoning_parser = None
|
||||
self._prefill_done_idxs: List[bool] = [False] * fd_config.scheduler_config.max_num_seqs
|
||||
# for pd
|
||||
self._tokens_to_acc: List[None | List[int]] = [None] * fd_config.scheduler_config.max_num_seqs
|
||||
|
||||
def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
|
||||
self.reasoning_parser = reasoning_parser
|
||||
|
||||
def add_logits_processor(
|
||||
self,
|
||||
ids: int,
|
||||
idx: int,
|
||||
future: Optional[Any] = None,
|
||||
prefill_tokens: List[int] = [],
|
||||
):
|
||||
"""add logits processor to GuidedDecoding"""
|
||||
with self.logits_lock:
|
||||
if future is None:
|
||||
if ids in self.logits_processor:
|
||||
del self.logits_processor[ids]
|
||||
return
|
||||
"""add logits processor to SamplerProcessor"""
|
||||
self._prefill_done_idxs[idx] = False
|
||||
|
||||
if isinstance(future, LogitsProcessorBase):
|
||||
self.logits_processor[ids] = future
|
||||
for token in prefill_tokens:
|
||||
self.logits_processor[ids].accept_token(token)
|
||||
elif future.done():
|
||||
self.logits_processor[ids] = future.result()
|
||||
for token in prefill_tokens:
|
||||
self.logits_processor[ids].accept_token(token)
|
||||
else:
|
||||
self.logits_processor[ids] = [future, prefill_tokens]
|
||||
|
||||
def update_vocab_mask(self, skip_idx_list: List[int] = []):
|
||||
"""update vocab mask. (cpu-heavy operation)"""
|
||||
if len(self.logits_processor) == 0:
|
||||
if future is None:
|
||||
# normal request without guided_backend
|
||||
self.logits_processors[idx] = None
|
||||
return
|
||||
|
||||
with self.logits_lock:
|
||||
for idx, processor in self.logits_processor.items():
|
||||
if processor is None:
|
||||
del self.logits_processor[idx]
|
||||
continue
|
||||
if len(prefill_tokens) != 0:
|
||||
# first_token from prefill node
|
||||
self._prefill_done_idxs[idx] = True
|
||||
|
||||
if not isinstance(processor, LogitsProcessorBase):
|
||||
future, prefill_tokens = self.logits_processor[idx]
|
||||
self.logits_processor[idx] = future.result()
|
||||
for token in prefill_tokens:
|
||||
self.logits_processor[idx].accept_token(token)
|
||||
if future.done():
|
||||
# cached xgrammar
|
||||
self.logits_processors[idx] = future.result()
|
||||
for token in prefill_tokens:
|
||||
self._accept_token(idx, token)
|
||||
else:
|
||||
# async
|
||||
self.logits_processors[idx] = future
|
||||
self._tokens_to_acc[idx] = prefill_tokens
|
||||
|
||||
available_processors = None
|
||||
for processor in self.logits_processor.values():
|
||||
if processor.is_terminated():
|
||||
continue
|
||||
available_processors = processor
|
||||
if available_processors is None:
|
||||
return
|
||||
def should_fill_bitmask(self, idx: int) -> bool:
|
||||
"""
|
||||
Determines whether to fill a bitmask for the logits processor at the given index.
|
||||
|
||||
# allocate token bitmask
|
||||
self.token_bitmask = available_processors.allocate_token_bitmask()
|
||||
Args:
|
||||
idx (int): The index of the logits processor to check
|
||||
|
||||
with self.logits_lock:
|
||||
# fill token bitmask
|
||||
for idx, processor in self.logits_processor.items():
|
||||
if processor.is_terminated() or idx in skip_idx_list:
|
||||
continue
|
||||
Returns:
|
||||
bool: True if the idx request bitmask should be filled
|
||||
|
||||
"""
|
||||
if self.reasoning_parser is not None:
|
||||
if self.logits_processors[idx].enable_reasoning: # <think> guided
|
||||
return True
|
||||
if not self.logits_processors[idx].reasoning_ended:
|
||||
return False
|
||||
return True
|
||||
|
||||
def reset_processor(self, idx: int):
|
||||
"""reset idx"""
|
||||
self._prefill_done_idxs[idx] = False
|
||||
self.logits_processors[idx] = None
|
||||
|
||||
def update_vocab_mask(self, prefill_done_idxs: List[int] = []):
|
||||
"""update vocab mask. (cpu-heavy operation)"""
|
||||
for idx in prefill_done_idxs:
|
||||
if self.logits_processors[idx] is None:
|
||||
continue
|
||||
|
||||
assert not self._prefill_done_idxs[idx]
|
||||
self._prefill_done_idxs[idx] = True
|
||||
if isinstance(self.logits_processors[idx], Future):
|
||||
continue
|
||||
|
||||
for idx, processor in enumerate(self.logits_processors):
|
||||
if processor is None or not self._prefill_done_idxs[idx]:
|
||||
continue
|
||||
# skip, join at apply_token_mask
|
||||
if isinstance(processor, Future):
|
||||
continue
|
||||
if processor.is_terminated:
|
||||
self.reset_processor(idx)
|
||||
continue
|
||||
|
||||
self.accept_tokens_from_prefill_node(idx)
|
||||
|
||||
if self.token_bitmask is None:
|
||||
self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()
|
||||
|
||||
if self.should_fill_bitmask(idx):
|
||||
processor.fill_token_bitmask(self.token_bitmask, idx)
|
||||
|
||||
def apply_token_mask(self, logits: paddle.Tensor, skip_idx_list: List[int] = []):
|
||||
"""apply token mask to logits"""
|
||||
if len(self.logits_processor) == 0 or self.token_bitmask is None:
|
||||
return logits
|
||||
def accept_tokens_from_prefill_node(self, idx: int):
|
||||
"""accept prefill token, not future"""
|
||||
if self._tokens_to_acc[idx] is not None:
|
||||
# accept token from prefill node first
|
||||
for token in self._tokens_to_acc[idx]:
|
||||
self._accept_token(idx, token)
|
||||
self._tokens_to_acc[idx] = None
|
||||
|
||||
# self.async_step.result()
|
||||
available_processors = None
|
||||
with self.logits_lock:
|
||||
for processor in self.logits_processor.values():
|
||||
if processor.is_terminated():
|
||||
continue
|
||||
available_processors = processor
|
||||
if available_processors is None:
|
||||
return logits
|
||||
def apply_token_mask(self, logits: paddle.Tensor, prefill_done_idxs: List[int] = []):
|
||||
"""apply token mask to logits"""
|
||||
|
||||
indices = []
|
||||
for idx, processor in self.logits_processor.items():
|
||||
if processor is None or idx in skip_idx_list:
|
||||
for idx, processor in enumerate(self.logits_processors):
|
||||
if processor is None or not self._prefill_done_idxs[idx]:
|
||||
continue
|
||||
if self.reasoning_parser is None or not processor.enable_reasoning or processor.reasoning_ended:
|
||||
indices.append(idx)
|
||||
|
||||
return available_processors.apply_token_mask(logits, self.token_bitmask, indices=indices)
|
||||
# compiled done, check idx should fill, fill_token_bitmask done in preprocess
|
||||
if not isinstance(processor, Future):
|
||||
if self.should_fill_bitmask(idx):
|
||||
indices.append(idx)
|
||||
continue
|
||||
|
||||
# is Future, processor async compiled not ready, need join and wait
|
||||
ts = time.time()
|
||||
wait = False
|
||||
if not processor.done():
|
||||
wait = True
|
||||
self.logits_processors[idx] = processor.result()
|
||||
if wait:
|
||||
logger.debug(f"[{idx} join async compile xgrammar, time_cost:{time.time() - ts}]")
|
||||
|
||||
self.accept_tokens_from_prefill_node(idx)
|
||||
# Possible optimization: Extract 'think' content validation from logits_processors,
|
||||
# allowing join operations to complete immediately after 'think' terminates.
|
||||
# Furthermore, the current idx could be skipped, with compilation overhead
|
||||
# estimated at only a few milliseconds.
|
||||
|
||||
# check idx for fill_token_mask
|
||||
if not self.should_fill_bitmask(idx):
|
||||
continue
|
||||
|
||||
indices.append(idx)
|
||||
|
||||
if self.token_bitmask is None:
|
||||
self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()
|
||||
|
||||
self.logits_processors[idx].fill_token_bitmask(self.token_bitmask, idx)
|
||||
|
||||
if len(indices) == 0:
|
||||
return logits
|
||||
from fastdeploy.model_executor.guided_decoding.xgrammar_backend import (
|
||||
apply_token_mask,
|
||||
)
|
||||
|
||||
return apply_token_mask(logits, self.token_bitmask, indices=indices)
|
||||
|
||||
def _accept_token(self, idx: int, token: int):
|
||||
"""accept token"""
|
||||
if idx not in self.logits_processor:
|
||||
raise ValueError(f"Invalid index, idx: {idx}, logit_processors.keys: {self.logits_processor.keys()}")
|
||||
|
||||
if self.logits_processor[idx].is_terminated():
|
||||
return
|
||||
if self.reasoning_parser is not None:
|
||||
if not self.logits_processors[idx].enable_reasoning:
|
||||
if not self.logits_processors[idx].reasoning_ended:
|
||||
reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
|
||||
self.logits_processors[idx].reasoning_ended = reasoning_ended
|
||||
return
|
||||
|
||||
if (
|
||||
self.reasoning_parser is not None
|
||||
and self.logits_processor[idx].enable_reasoning
|
||||
and not self.logits_processor[idx].reasoning_ended
|
||||
):
|
||||
reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
|
||||
self.logits_processor[idx].reasoning_ended = reasoning_ended
|
||||
return
|
||||
if not self.logits_processors[idx].accept_token(token) or self.logits_processors[idx].is_terminated:
|
||||
self.reset_processor(idx)
|
||||
|
||||
self.logits_processor[idx].accept_token(token)
|
||||
|
||||
def update_output_tokens(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
|
||||
def update_output_tokens(self, next_tokens: paddle.Tensor):
|
||||
"""update output tokens"""
|
||||
if len(self.logits_processor) == 0:
|
||||
if len(self.logits_processors) == 0:
|
||||
return
|
||||
|
||||
token_ids = next_tokens.numpy().tolist()
|
||||
with self.logits_lock:
|
||||
for idx in self.logits_processor.keys():
|
||||
token = token_ids[idx][0]
|
||||
if token < 0 or self.logits_processor[idx] is None or idx in skip_idx_list:
|
||||
continue
|
||||
for idx, processor in enumerate(self.logits_processors):
|
||||
if not self._prefill_done_idxs[idx] or processor is None:
|
||||
continue
|
||||
if idx >= len(token_ids):
|
||||
continue
|
||||
token = token_ids[idx][0]
|
||||
if token < 0:
|
||||
self.reset_processor(idx)
|
||||
continue
|
||||
logger.debug(f"[{idx}]accept token{token}")
|
||||
self._accept_token(idx, token)
|
||||
|
||||
self._accept_token(idx, token)
|
||||
|
||||
def pre_process(self, skip_idx_list: List[int] = []):
|
||||
def pre_process(self, prefill_done_idxs: List[int] = []):
|
||||
"""pre process before running"""
|
||||
# create async operation for guided decoding
|
||||
# TODO: support async
|
||||
self.update_vocab_mask(skip_idx_list)
|
||||
# self.async_step = self.executor.submit(self.update_vocab_mask)
|
||||
self.update_vocab_mask(prefill_done_idxs)
|
||||
|
||||
|
||||
class Sampler(nn.Layer):
|
||||
@@ -224,7 +272,7 @@ class Sampler(nn.Layer):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.guided_decoding = GuidedDecoding()
|
||||
self.guided_decoding = GuidedDecoding(fd_config)
|
||||
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
|
||||
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
|
||||
if (
|
||||
@@ -240,17 +288,19 @@ class Sampler(nn.Layer):
|
||||
"""set reasoning parser"""
|
||||
self.guided_decoding.apply_reasoning_parser(reasoning_parser)
|
||||
|
||||
def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
|
||||
def apply_logits_processor(
|
||||
self, ids: int, future: Future[LogitsProcessorBase] = None, prefill_tokens: List[int] = []
|
||||
):
|
||||
"""apply logits processor to sampler"""
|
||||
self.guided_decoding.add_logits_processor(ids, future, prefill_tokens)
|
||||
|
||||
def pre_process(self, skip_idx_list: List[int] = []):
|
||||
def pre_process(self, prefill_done_idxs: List[int] = []):
|
||||
"""pre process before running"""
|
||||
self.guided_decoding.pre_process(skip_idx_list)
|
||||
self.guided_decoding.pre_process(prefill_done_idxs)
|
||||
|
||||
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
|
||||
def post_process(self, next_tokens: paddle.Tensor):
|
||||
"""post process after running"""
|
||||
self.guided_decoding.update_output_tokens(next_tokens, skip_idx_list)
|
||||
self.guided_decoding.update_output_tokens(next_tokens)
|
||||
|
||||
def compute_logprobs(
|
||||
self,
|
||||
@@ -343,10 +393,10 @@ class Sampler(nn.Layer):
|
||||
self,
|
||||
logits: paddle.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
skip_idx_list: List[int] = [],
|
||||
p_done_idxs: List[int] = [],
|
||||
) -> SamplerOutput:
|
||||
""" """
|
||||
logits = self.guided_decoding.apply_token_mask(logits, skip_idx_list)
|
||||
logits = self.guided_decoding.apply_token_mask(logits, p_done_idxs)
|
||||
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
if num_logprobs is not None:
|
||||
|
||||
@@ -706,10 +706,30 @@ class InferScheduler:
|
||||
self.reqs_queue = deque()
|
||||
self.writers = []
|
||||
|
||||
def check_redis_version(self):
|
||||
# Get Redis version information
|
||||
redis_info = self.client.info()
|
||||
redis_version = redis_info.get("redis_version", "")
|
||||
version_parts = [int(x) for x in redis_version.split(".")]
|
||||
|
||||
# Redis 6.2 and above versions support RPOP with count parameter
|
||||
assert (
|
||||
version_parts[0] >= 6
|
||||
), f"Redis major version too low: {version_parts[0]}. Please upgrade to Redis 6.2+ to support batch RPOP operations."
|
||||
assert (
|
||||
version_parts[1] >= 2 if version_parts[0] == 6 else True
|
||||
), f"Redis version {redis_version} too low. Please upgrade to Redis 6.2+ to support batch RPOP operations."
|
||||
|
||||
logger.info(f"Redis version {redis_version} detected. Using native batch RPOP.")
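Context for the check above: RPOP only accepts a count argument from Redis 6.2 onward, which is what lets the writer pop a whole batch in one round trip; a version string like "6.2.14" parses to [6, 2, 14] and passes, while "5.0.7" fails the first assert. A hedged sketch with redis-py, assuming a redis-py release that exposes the count parameter:

# Sketch: batch RPOP, available on Redis >= 6.2.
import redis

client = redis.Redis(host="localhost", port=6379)
items = client.rpop("results_queue", 32)  # up to 32 items, or None if the list is empty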
|
||||
|
||||
def start(self, role, host, disaggregated):
|
||||
"""
|
||||
start backup threads
|
||||
"""
|
||||
|
||||
# Check Redis version first
|
||||
self.check_redis_version()
|
||||
|
||||
for i in range(self.writer_parallel):
|
||||
writer = ResultWriter(self.client, i, self.writer_batch_size, self.ttl)
|
||||
writer.start()
|
||||
|
||||
@@ -31,9 +31,6 @@ from fastdeploy.model_executor.graph_optimization.utils import (
|
||||
sot_warmup_guard,
|
||||
)
|
||||
from fastdeploy.model_executor.guided_decoding import get_guided_backend
|
||||
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
|
||||
LogitsProcessorBase,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.attention import get_attention_backend
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend,
|
||||
@@ -182,7 +179,7 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
or request.guided_grammar is not None
|
||||
):
|
||||
logits_info, schemata_key = self._init_logits_processor(request)
|
||||
request.logits_processor, request.logits_cached = logits_info
|
||||
request.logits_processor = logits_info
|
||||
request.schemata_key = schemata_key
|
||||
|
||||
# Is Decode Node
|
||||
@@ -925,29 +922,36 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
logger.info(f"SOT warmup the model with the batch size:{batch_size}")
|
||||
logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
|
||||
|
||||
def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
|
||||
def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_running_requests: int):
|
||||
"""
|
||||
Get the index of the request that needs to be skipped during execution.
|
||||
Args:
|
||||
model_forward_batch: A list of requests to be executed by this runner.
|
||||
Returns:
|
||||
A list of indices corresponding to the requests that need to be skipped.
|
||||
Returns indices for guided decoding.
|
||||
When Prefill is done, async compiled logits_processor must be joined.
|
||||
"""
|
||||
skip_idx_list = []
|
||||
if not self.cache_config.enable_chunked_prefill or self.guided_backend is None:
|
||||
return skip_idx_list
|
||||
if self.guided_backend is None:
|
||||
return []
|
||||
|
||||
for task in model_forward_batch:
|
||||
if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info):
|
||||
continue
|
||||
skip_idx_list.append(task.idx)
|
||||
prefill_done_idxs = []
|
||||
for idx in range(0, num_running_requests):
|
||||
if self.share_inputs["step_idx"][idx] == 0:
|
||||
prefill_done_idxs.append(idx)
|
||||
|
||||
for task in self.restore_chunked_prefill_request.values():
|
||||
if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info):
|
||||
continue
|
||||
skip_idx_list.append(task.idx)
|
||||
if self.cache_config.enable_chunked_prefill:
|
||||
if model_forward_batch is not None:
|
||||
for task in model_forward_batch:
|
||||
# new Request with ChunkPrefill, unfinished, store
|
||||
if task.chunk_idx < len(task.prefill_chunk_info):
|
||||
if task.request_id not in self.restore_chunked_prefill_request:
|
||||
self.restore_chunked_prefill_request[task.request_id] = task
|
||||
|
||||
return skip_idx_list
|
||||
for id, task in list(self.restore_chunked_prefill_request.items()):
|
||||
# unfinished, remove
|
||||
if task.chunk_idx < len(task.prefill_chunk_info) and task.idx in prefill_done_idxs:
|
||||
prefill_done_idxs.remove(task.idx)
|
||||
# finished, add
|
||||
if task.chunk_idx == len(task.prefill_chunk_info) and task.idx not in prefill_done_idxs:
|
||||
prefill_done_idxs.append(task.idx)
|
||||
|
||||
return prefill_done_idxs
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -971,9 +975,9 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
return None
|
||||
|
||||
# 1. Prepare inputs of model and sampler.
|
||||
skip_idx_list = self._get_skip_idx(model_forward_batch)
|
||||
p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests)
|
||||
self._prepare_inputs()
|
||||
self.sampler.pre_process(skip_idx_list)
|
||||
self.sampler.pre_process(p_done_idxs)
|
||||
|
||||
# 2. Padding inputs for cuda graph
|
||||
|
||||
@@ -1009,7 +1013,7 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
sampler_output = self.sampler(
|
||||
logits,
|
||||
self.sampling_metadata,
|
||||
skip_idx_list,
|
||||
p_done_idxs,
|
||||
)
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
|
||||
@@ -1081,28 +1085,9 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
|
||||
|
||||
self._update_chunked_prefill(model_forward_batch)
|
||||
self._add_cache(model_forward_batch)
|
||||
self.seq_lens_this_time_buffer.copy_(self.share_inputs["seq_lens_this_time"], False)
|
||||
return None
|
||||
|
||||
def _add_cache(self, model_forward_batch) -> None:
|
||||
"""
|
||||
Add cache for guided decoding.
|
||||
"""
|
||||
if self.guided_backend is None:
|
||||
return
|
||||
|
||||
for request in model_forward_batch:
|
||||
logits_cached = request.get("logits_cached", None)
|
||||
if logits_cached is None or logits_cached:
|
||||
continue
|
||||
|
||||
request.logits_cached = True
|
||||
if isinstance(request.logits_processor, LogitsProcessorBase):
|
||||
self.guided_backend.add_cache(request.schemata_key, request.logits_processor)
|
||||
else:
|
||||
self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result())
|
||||
|
||||
def _execute_empty_input(self) -> None:
|
||||
"""
|
||||
In certain scenarios, such as during EP,
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import os
|
||||
import queue
|
||||
import time
|
||||
from concurrent.futures import Future
|
||||
from threading import Thread
|
||||
from typing import List, Optional, cast
|
||||
|
||||
@@ -287,7 +288,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
else:
|
||||
self.proposer = None
|
||||
|
||||
def _init_logits_processor(self, request):
|
||||
def _init_logits_processor(self, request) -> tuple[Future[LogitsProcessorBase],]:
|
||||
"""
|
||||
init logits processor for guided decoding
|
||||
"""
|
||||
@@ -307,7 +308,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
return (
|
||||
self.guided_backend.get_logits_processor(
|
||||
schemata_key=schemata_key,
|
||||
enable_thinking=True,
|
||||
enable_thinking=False, # TODO cfg
|
||||
),
|
||||
schemata_key,
|
||||
)
|
||||
@@ -696,6 +697,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
length = len(request.prompt_token_ids)
|
||||
assert length > 0, "The prompt requested must not be empty."
|
||||
|
||||
logits_info = None
|
||||
prefill_tokens = []
|
||||
if (
|
||||
request.guided_json is not None
|
||||
@@ -704,7 +706,6 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
or request.guided_grammar is not None
|
||||
):
|
||||
logits_info, schemata_key = self._init_logits_processor(request)
|
||||
request.logits_processor, request.logits_cached = logits_info
|
||||
request.schemata_key = schemata_key
|
||||
|
||||
# Is Decode Node
|
||||
@@ -874,7 +875,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
else:
|
||||
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
|
||||
|
||||
self.sampler.apply_logits_processor(idx, request.get("logits_processor"), prefill_tokens)
|
||||
self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens)
|
||||
|
||||
self.share_inputs["not_need_stop"][0] = True
|
||||
|
||||
@@ -2006,34 +2007,36 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
logger.info(f"SOT warmup the model with the batch size:{batch_size}")
|
||||
logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
|
||||
|
||||
def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
|
||||
def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_running_requests: int):
|
||||
"""
|
||||
Get the index of the request that needs to be skipped during execution.
|
||||
Args:
|
||||
model_forward_batch: A list of requests to be executed by this runner.
|
||||
Returns:
|
||||
A list of indices corresponding to the requests that need to be skipped.
|
||||
Get indices for guided decoding.
|
||||
When Prefill is done, async compiled logits_processor must be joined.
|
||||
"""
|
||||
if (
|
||||
not self.cache_config.enable_chunked_prefill
|
||||
or self.guided_backend is None
|
||||
or model_forward_batch is None
|
||||
or envs.ENABLE_V1_KVCACHE_SCHEDULER
|
||||
):
|
||||
if self.guided_backend is None:
|
||||
return []
|
||||
|
||||
skip_idx_list = []
|
||||
for task in model_forward_batch:
|
||||
if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info):
|
||||
continue
|
||||
skip_idx_list.append(task.idx)
|
||||
prefill_done_idxs = []
|
||||
for idx in range(0, num_running_requests):
|
||||
if self.share_inputs["step_idx"][idx] == 0:
|
||||
prefill_done_idxs.append(idx)
|
||||
|
||||
for task in self.restore_chunked_prefill_request.values():
|
||||
if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info):
|
||||
continue
|
||||
skip_idx_list.append(task.idx)
|
||||
if self.cache_config.enable_chunked_prefill:
|
||||
if model_forward_batch is not None:
|
||||
for task in model_forward_batch:
|
||||
# new Request with ChunkPrefill, unfinished, store
|
||||
if task.chunk_idx < len(task.prefill_chunk_info):
|
||||
if task.request_id not in self.restore_chunked_prefill_request:
|
||||
self.restore_chunked_prefill_request[task.request_id] = task
|
||||
|
||||
return skip_idx_list
|
||||
for id, task in list(self.restore_chunked_prefill_request.items()):
|
||||
# unfinished, remove
|
||||
if task.chunk_idx < len(task.prefill_chunk_info) and task.idx in prefill_done_idxs:
|
||||
prefill_done_idxs.remove(task.idx)
|
||||
# finished, add
|
||||
if task.chunk_idx == len(task.prefill_chunk_info) and task.idx not in prefill_done_idxs:
|
||||
prefill_done_idxs.append(task.idx)
|
||||
|
||||
return prefill_done_idxs
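A condensed sketch of the selection rule implemented above (illustrative; it treats step_idx == 0 as "has not advanced to decode yet" and excludes chunked-prefill requests until their last chunk has run):

# Sketch: picking slot indices whose prefill just finished (simplified).
def pick_prefill_done(step_idx, chunk_state):
    # step_idx: list[int], one entry per running request slot
    # chunk_state: dict idx -> (chunk_idx, num_chunks) for chunked-prefill requests
    done = [i for i, s in enumerate(step_idx) if s == 0]
    for i, (chunk_idx, num_chunks) in chunk_state.items():
        if chunk_idx < num_chunks and i in done:
            done.remove(i)   # still prefilling in chunks
        elif chunk_idx == num_chunks and i not in done:
            done.append(i)   # last chunk just completed
    return done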
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -2050,9 +2053,10 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
num_running_requests: batch_size
|
||||
"""
|
||||
# 1. Prepare inputs of model and sampler.
|
||||
skip_idx_list = self._get_skip_idx(model_forward_batch)
|
||||
p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests)
|
||||
|
||||
self._prepare_inputs()
|
||||
self.sampler.pre_process(skip_idx_list)
|
||||
self.sampler.pre_process(p_done_idxs)
|
||||
|
||||
# 1.1 Update state of logits processor
|
||||
for proc in self.sampling_metadata.logits_processors:
|
||||
@@ -2157,7 +2161,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
sampler_output = self.sampler(
|
||||
logits,
|
||||
self.sampling_metadata,
|
||||
skip_idx_list,
|
||||
p_done_idxs,
|
||||
)
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
paddle.distributed.broadcast(
|
||||
@@ -2244,7 +2248,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
line_break_id=self.model_config.line_break_id,
|
||||
)
|
||||
if self.guided_backend is not None and sampler_output is not None:
|
||||
self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)
|
||||
self.sampler.post_process(sampler_output.sampled_token_ids)
|
||||
|
||||
# 6. Speculative decode
|
||||
if self.speculative_decoding:
|
||||
@@ -2268,7 +2272,6 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
|
||||
self._update_chunked_prefill(model_forward_batch)
|
||||
self._add_cache(model_forward_batch)
|
||||
elif self.speculative_decoding:
|
||||
speculate_schedule_cache(
|
||||
self.share_inputs["draft_tokens"],
|
||||
@@ -2325,24 +2328,6 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
return pooler_output
|
||||
|
||||
def _add_cache(self, model_forward_batch) -> None:
|
||||
"""
|
||||
Add cache for guided decoding.
|
||||
"""
|
||||
if self.guided_backend is None or model_forward_batch is None:
|
||||
return
|
||||
|
||||
for request in model_forward_batch:
|
||||
logits_cached = request.get("logits_cached", None)
|
||||
if logits_cached is None or logits_cached:
|
||||
continue
|
||||
|
||||
request.logits_cached = True
|
||||
if isinstance(request.logits_processor, LogitsProcessorBase):
|
||||
self.guided_backend.add_cache(request.schemata_key, request.logits_processor)
|
||||
else:
|
||||
self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result())
|
||||
|
||||
def _execute_empty_input(self) -> None:
|
||||
"""
|
||||
In certain scenarios, such as during EP,
|
||||
|
||||
@@ -30,9 +30,6 @@ from fastdeploy.engine.request import Request
|
||||
# from fastdeploy.spec_decode import MTPProposer, NgramProposer
|
||||
from fastdeploy.model_executor.forward_meta import HPUForwardMeta
|
||||
from fastdeploy.model_executor.guided_decoding import get_guided_backend
|
||||
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
|
||||
LogitsProcessorBase,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.attention import get_attention_backend
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend,
|
||||
@@ -442,7 +439,7 @@ class HPUModelRunner(ModelRunnerBase):
|
||||
or request.guided_grammar is not None
|
||||
):
|
||||
logits_info, schemata_key = self._init_logits_processor(request)
|
||||
request.logits_processor, request.logits_cached = logits_info
|
||||
request.logits_processor = logits_info
|
||||
request.schemata_key = schemata_key
|
||||
|
||||
# Is Decode Node
|
||||
@@ -1179,30 +1176,6 @@ class HPUModelRunner(ModelRunnerBase):
|
||||
time_after_capture = time.perf_counter()
|
||||
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
|
||||
|
||||
def _get_skip_idx(self, model_forward_batch):
|
||||
"""
|
||||
Get the index of the request that needs to be skipped during execution.
|
||||
Args:
|
||||
model_forward_batch: A list of requests to be executed by this runner.
|
||||
Returns:
|
||||
A list of indices corresponding to the requests that need to be skipped.
|
||||
"""
|
||||
skip_idx_list = []
|
||||
if not self.parallel_config.enable_chunked_prefill or self.guided_backend is None:
|
||||
return skip_idx_list
|
||||
|
||||
for task in model_forward_batch:
|
||||
if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info):
|
||||
continue
|
||||
skip_idx_list.append(task.idx)
|
||||
|
||||
for task in self.restore_chunked_prefill_request.values():
|
||||
if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info):
|
||||
continue
|
||||
skip_idx_list.append(task.idx)
|
||||
|
||||
return skip_idx_list
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
model_forward_batch: Optional[List[Request]] = None,
|
||||
@@ -1332,30 +1305,11 @@ class HPUModelRunner(ModelRunnerBase):
|
||||
execution_time = (end_time - start_time) * 1000
|
||||
hpu_model_runner_profile_logger.info(f"StepPaddle execution time(ms): {execution_time}, BT={real_bs}")
|
||||
self._update_chunked_prefill(model_forward_batch)
|
||||
self._add_cache(model_forward_batch)
|
||||
|
||||
if int(os.environ.get("HABANA_PROFILE", 0)) == 1:
|
||||
self.prof.step()
|
||||
return None
|
||||
|
||||
def _add_cache(self, model_forward_batch) -> None:
|
||||
"""
|
||||
Add cache for guided decoding.
|
||||
"""
|
||||
if self.guided_backend is None:
|
||||
return
|
||||
|
||||
for request in model_forward_batch:
|
||||
logits_cached = request.get("logits_cached", None)
|
||||
if logits_cached is None or logits_cached:
|
||||
continue
|
||||
|
||||
request.logits_cached = True
|
||||
if isinstance(request.logits_processor, LogitsProcessorBase):
|
||||
self.guided_backend.add_cache(request.schemata_key, request.logits_processor)
|
||||
else:
|
||||
self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result())
|
||||
|
||||
def _execute_empty_input(self) -> None:
|
||||
"""
|
||||
In certain scenarios, such as during EP,
|
||||
|
||||
@@ -38,10 +38,7 @@ from fastdeploy.model_executor.graph_optimization.utils import (
|
||||
profile_run_guard,
|
||||
sot_warmup_guard,
|
||||
)
|
||||
from fastdeploy.model_executor.guided_decoding import (
|
||||
LogitsProcessorBase,
|
||||
get_guided_backend,
|
||||
)
|
||||
from fastdeploy.model_executor.guided_decoding import get_guided_backend
|
||||
from fastdeploy.model_executor.layers.attention import get_attention_backend
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend,
|
||||
@@ -507,7 +504,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
or request.guided_grammar is not None
|
||||
):
|
||||
logits_info, schemata_key = self._init_logits_processor(request)
|
||||
request.logits_processor, request.logits_cached = logits_info
|
||||
request.logits_processor = logits_info
|
||||
request.schemata_key = schemata_key
|
||||
|
||||
# Is Decode Node
|
||||
@@ -1761,34 +1758,36 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
logger.info(f"SOT warmup the model with the batch size:{batch_size}")
|
||||
logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
|
||||
|
||||
def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
|
||||
def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_running_requests: int):
|
||||
"""
|
||||
Get the index of the request that needs to be skipped during execution.
|
||||
Args:
|
||||
model_forward_batch: A list of requests to be executed by this runner.
|
||||
Returns:
|
||||
A list of indices corresponding to the requests that need to be skipped.
|
||||
Get indices for guided decoding.
|
||||
When Prefill is done, async compiled logits_processor must be joined.
|
||||
"""
|
||||
if (
|
||||
not self.cache_config.enable_chunked_prefill
|
||||
or self.guided_backend is None
|
||||
or model_forward_batch is None
|
||||
or envs.ENABLE_V1_KVCACHE_SCHEDULER
|
||||
):
|
||||
if self.guided_backend is None:
|
||||
return []
|
||||
|
||||
skip_idx_list = []
|
||||
for task in model_forward_batch:
|
||||
if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info):
|
||||
continue
|
||||
skip_idx_list.append(task.idx)
|
||||
prefill_done_idxs = []
|
||||
for idx in range(0, num_running_requests):
|
||||
if self.share_inputs["step_idx"][idx] == 0:
|
||||
prefill_done_idxs.append(idx)
|
||||
|
||||
for task in self.restore_chunked_prefill_request.values():
|
||||
if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info):
|
||||
continue
|
||||
skip_idx_list.append(task.idx)
|
||||
if self.cache_config.enable_chunked_prefill:
|
||||
if model_forward_batch is not None:
|
||||
for task in model_forward_batch:
|
||||
# new Request with ChunkPrefill, unfinished, store
|
||||
if task.chunk_idx < len(task.prefill_chunk_info):
|
||||
if task.request_id not in self.restore_chunked_prefill_request:
|
||||
self.restore_chunked_prefill_request[task.request_id] = task
|
||||
|
||||
return skip_idx_list
|
||||
for id, task in list(self.restore_chunked_prefill_request.items()):
|
||||
# unfinished, remove
|
||||
if task.chunk_idx < len(task.prefill_chunk_info) and task.idx in prefill_done_idxs:
|
||||
prefill_done_idxs.remove(task.idx)
|
||||
# finished, add
|
||||
if task.chunk_idx == len(task.prefill_chunk_info) and task.idx not in prefill_done_idxs:
|
||||
prefill_done_idxs.append(task.idx)
|
||||
|
||||
return prefill_done_idxs
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -1805,9 +1804,9 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
num_running_requests: batch_size
|
||||
"""
|
||||
# 1. Prepare inputs of model and sampler.
|
||||
skip_idx_list = self._get_skip_idx(model_forward_batch)
|
||||
p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests)
|
||||
self._prepare_inputs()
|
||||
self.sampler.pre_process(skip_idx_list)
|
||||
self.sampler.pre_process(p_done_idxs)
|
||||
|
||||
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
|
||||
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
|
||||
@@ -1864,7 +1863,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
sampler_output = self.sampler(
|
||||
logits,
|
||||
self.sampling_metadata,
|
||||
skip_idx_list,
|
||||
p_done_idxs,
|
||||
)
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
paddle.distributed.broadcast(
|
||||
@@ -1949,7 +1948,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
line_break_id=self.model_config.line_break_id,
|
||||
)
|
||||
if self.guided_backend is not None and sampler_output is not None:
|
||||
self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)
|
||||
self.sampler.post_process(sampler_output.sampled_token_ids)
|
||||
|
||||
# 6. Speculative decode
|
||||
if self.speculative_decoding:
|
||||
@@ -1973,7 +1972,6 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
)
|
||||
|
||||
self._update_chunked_prefill(model_forward_batch)
|
||||
self._add_cache(model_forward_batch)
|
||||
elif self.speculative_decoding:
|
||||
speculate_schedule_cache(
|
||||
self.share_inputs["draft_tokens"],
|
||||
@@ -2000,24 +1998,6 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
)
|
||||
return None
|
||||
|
||||
def _add_cache(self, model_forward_batch) -> None:
|
||||
"""
|
||||
Add cache for guided decoding.
|
||||
"""
|
||||
if self.guided_backend is None or model_forward_batch is None:
|
||||
return
|
||||
|
||||
for request in model_forward_batch:
|
||||
logits_cached = request.get("logits_cached", None)
|
||||
if logits_cached is None or logits_cached:
|
||||
continue
|
||||
|
||||
request.logits_cached = True
|
||||
if isinstance(request.logits_processor, LogitsProcessorBase):
|
||||
self.guided_backend.add_cache(request.schemata_key, request.logits_processor)
|
||||
else:
|
||||
self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result())
|
||||
|
||||
def _execute_empty_input(self) -> None:
|
||||
"""
|
||||
In certain scenarios, such as during EP,
|
||||
|
||||
tests/layers/test_guided_decoding.py (new file, 338 lines)
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Unit tests for the GuidedDecoding class.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from concurrent.futures import Future
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import paddle
|
||||
|
||||
mock_torch = MagicMock()
|
||||
mock_xgrammar = MagicMock()
|
||||
sys.modules["torch"] = mock_torch
|
||||
sys.modules["xgrammar"] = mock_xgrammar
|
||||
|
||||
from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
|
||||
from fastdeploy.model_executor.layers.sample.sampler import GuidedDecoding
|
||||
from fastdeploy.reasoning import ReasoningParser
|
||||
|
||||
|
||||
class TestGuidedDecoding(unittest.TestCase):
|
||||
"""Test cases for GuidedDecoding class."""
|
||||
|
||||
def setUp(self):
|
||||
"""Setup for each test case."""
|
||||
# Create a basic (mocked) FDConfig object
self.fd_config = Mock()
self.fd_config.scheduler_config.max_num_seqs = 5

# Create the GuidedDecoding instance under test
self.guided_decoding = GuidedDecoding(self.fd_config)

# Create a mock LogitsProcessorBase
self.mock_processor = Mock(spec=LogitsProcessorBase)
self.mock_processor.is_terminated = False
self.mock_processor.reasoning_ended = True
self.mock_processor.enable_reasoning = False

# Mock allocate_token_bitmask to return a fake bitmask
self.mock_processor.allocate_token_bitmask.return_value = paddle.zeros([5, 10], dtype="int32")

# Mock fill_token_bitmask
self.mock_processor.fill_token_bitmask.return_value = None

# Mock accept_token to return True
self.mock_processor.accept_token.return_value = True
|
||||
|
||||
def test_init(self):
|
||||
"""Test initialization."""
|
||||
self.assertIsNone(self.guided_decoding.token_bitmask)
|
||||
self.assertEqual(len(self.guided_decoding.logits_processors), 5)
|
||||
self.assertIsNone(self.guided_decoding.reasoning_parser)
|
||||
self.assertEqual(len(self.guided_decoding._prefill_done_idxs), 5)
|
||||
self.assertEqual(len(self.guided_decoding._tokens_to_acc), 5)
|
||||
|
||||
def test_apply_reasoning_parser(self):
|
||||
"""Test apply_reasoning_parser method."""
|
||||
mock_parser = Mock(spec=ReasoningParser)
|
||||
self.guided_decoding.apply_reasoning_parser(mock_parser)
|
||||
self.assertEqual(self.guided_decoding.reasoning_parser, mock_parser)
|
||||
|
||||
def test_add_logits_processor_no_future(self):
|
||||
"""Test add_logits_processor method without future."""
|
||||
self.guided_decoding.add_logits_processor(0, None, [])
|
||||
self.assertFalse(self.guided_decoding._prefill_done_idxs[0])
|
||||
self.assertIsNone(self.guided_decoding.logits_processors[0])
|
||||
|
||||
def test_add_logits_processor_with_prefill_tokens(self):
|
||||
"""Test add_logits_processor method with prefill tokens."""
|
||||
# Create a mock Future object
|
||||
mock_future = Mock()
|
||||
mock_future.done.return_value = True
|
||||
mock_future.result.return_value = self.mock_processor
|
||||
|
||||
prefill_tokens = [1, 2, 3]
|
||||
self.guided_decoding.add_logits_processor(0, mock_future, prefill_tokens)
|
||||
|
||||
self.assertTrue(self.guided_decoding._prefill_done_idxs[0])
|
||||
self.assertEqual(self.guided_decoding.logits_processors[0], self.mock_processor)
|
||||
self.mock_processor.accept_token.assert_any_call(1)
|
||||
self.mock_processor.accept_token.assert_any_call(2)
|
||||
self.mock_processor.accept_token.assert_any_call(3)
|
||||
|
||||
def test_add_logits_processor_with_async_future(self):
|
||||
"""Test add_logits_processor method with async future."""
|
||||
# Create a mock Future object
|
||||
mock_future = Mock()
|
||||
mock_future.done.return_value = False
|
||||
|
||||
prefill_tokens = [1, 2, 3]
|
||||
self.guided_decoding.add_logits_processor(0, mock_future, prefill_tokens)
|
||||
|
||||
self.assertTrue(self.guided_decoding._prefill_done_idxs[0])
|
||||
self.assertEqual(self.guided_decoding.logits_processors[0], mock_future)
|
||||
self.assertEqual(self.guided_decoding._tokens_to_acc[0], prefill_tokens)
|
||||
|
||||
def test_should_fill_bitmask_no_reasoning_parser(self):
|
||||
"""Test should_fill_bitmask method with no reasoning parser."""
|
||||
self.guided_decoding.logits_processors[0] = self.mock_processor
|
||||
self.assertTrue(self.guided_decoding.should_fill_bitmask(0))
|
||||
|
||||
def test_should_fill_bitmask_with_reasoning_parser(self):
|
||||
"""Test should_fill_bitmask method with reasoning parser."""
|
||||
mock_parser = Mock(spec=ReasoningParser)
|
||||
self.guided_decoding.reasoning_parser = mock_parser
|
||||
|
||||
# Case: enable_reasoning=True
self.mock_processor.enable_reasoning = True
self.guided_decoding.logits_processors[0] = self.mock_processor
self.assertTrue(self.guided_decoding.should_fill_bitmask(0))

# Case: enable_reasoning=False, reasoning_ended=False
self.mock_processor.enable_reasoning = False
self.mock_processor.reasoning_ended = False
self.assertFalse(self.guided_decoding.should_fill_bitmask(0))

# Case: enable_reasoning=False, reasoning_ended=True
self.mock_processor.reasoning_ended = True
self.assertTrue(self.guided_decoding.should_fill_bitmask(0))
|
||||
|
||||
def test_reset_processor(self):
|
||||
"""Test reset_processor method."""
|
||||
self.guided_decoding.logits_processors[0] = self.mock_processor
|
||||
self.guided_decoding._prefill_done_idxs[0] = True
|
||||
|
||||
self.guided_decoding.reset_processor(0)
|
||||
|
||||
self.assertFalse(self.guided_decoding._prefill_done_idxs[0])
|
||||
self.assertIsNone(self.guided_decoding.logits_processors[0])
|
||||
|
||||
def test_update_vocab_mask_with_new_prefill_done(self):
|
||||
"""Test update_vocab_mask method with new prefill_done_idxs."""
|
||||
# Set up the processor at index 0
self.guided_decoding.logits_processors[0] = self.mock_processor
self.guided_decoding._prefill_done_idxs[0] = False

# Call update_vocab_mask and mark index 0 as prefill-done
self.guided_decoding.update_vocab_mask([0])

# Verify _prefill_done_idxs[0] was updated
self.assertTrue(self.guided_decoding._prefill_done_idxs[0])

# Verify fill_token_bitmask was called
self.mock_processor.fill_token_bitmask.assert_called_once()
|
||||
|
||||
def test_update_vocab_mask_with_future_processor(self):
|
||||
"""Test update_vocab_mask method with future processor."""
|
||||
# Create a mock Future object
mock_future = Mock()

# Set the processor at index 0 to the Future
self.guided_decoding.logits_processors[0] = mock_future
self.guided_decoding._prefill_done_idxs[0] = True

# Call update_vocab_mask
self.guided_decoding.update_vocab_mask([])

# Verify fill_token_bitmask was not called (the processor is still a Future)
self.mock_processor.fill_token_bitmask.assert_not_called()
|
||||
|
||||
def test_accept_tokens_from_prefill_node(self):
|
||||
"""Test accept_tokens_from_prefill_node method."""
|
||||
# Set up the processor at index 0 and the tokens waiting to be accepted
self.guided_decoding.logits_processors[0] = self.mock_processor
self.guided_decoding._tokens_to_acc[0] = [1, 2, 3]

# Call accept_tokens_from_prefill_node
self.guided_decoding.accept_tokens_from_prefill_node(0)

# Verify accept_token was called three times
self.assertEqual(self.mock_processor.accept_token.call_count, 3)
self.mock_processor.accept_token.assert_any_call(1)
self.mock_processor.accept_token.assert_any_call(2)
self.mock_processor.accept_token.assert_any_call(3)

# Verify _tokens_to_acc[0] was reset
self.assertIsNone(self.guided_decoding._tokens_to_acc[0])
|
||||
|
||||
@patch("fastdeploy.model_executor.guided_decoding.xgrammar_backend.apply_token_mask")
|
||||
def test_apply_token_mask(self, mock_apply_token_mask):
|
||||
"""Test apply_token_mask method."""
|
||||
# Create test data
logits = paddle.zeros([5, 10], dtype="float32")
mock_apply_token_mask.return_value = paddle.ones([5, 10], dtype="float32")

# Set up the processor at index 0
self.guided_decoding.logits_processors[0] = self.mock_processor
self.guided_decoding._prefill_done_idxs[0] = True

# Call apply_token_mask
result = self.guided_decoding.apply_token_mask(logits, [])

# Verify fill_token_bitmask was not called (the processor is not a Future)
self.mock_processor.fill_token_bitmask.assert_not_called()

# Verify apply_token_mask was called
mock_apply_token_mask.assert_called_once()

# Verify the return value
self.assertTrue((result == paddle.ones([5, 10], dtype="float32")).all())
|
||||
|
||||
def test_apply_token_mask_with_future_processor(self):
|
||||
"""Test apply_token_mask method with future processor."""
|
||||
# Create test data
logits = paddle.zeros([5, 10], dtype="float32")

# Create a mock Future object
mock_future = Mock(spec=Future)
mock_future.done.return_value = True
mock_future.result.return_value = self.mock_processor

# Set the processor at index 0 to the Future
self.guided_decoding.logits_processors[0] = mock_future

self.guided_decoding._prefill_done_idxs[0] = True
self.assertTrue(self.guided_decoding._prefill_done_idxs[0])
self.assertIsNotNone(self.guided_decoding.logits_processors[0])
self.assertTrue(isinstance(self.guided_decoding.logits_processors[0], Future))
self.guided_decoding._tokens_to_acc[0] = [1, 2, 3]

# Patch apply_token_mask
with patch(
"fastdeploy.model_executor.guided_decoding.xgrammar_backend.apply_token_mask"
) as mock_apply_token_mask:
mock_apply_token_mask.return_value = paddle.ones([5, 10], dtype="float32")

# Call apply_token_mask
self.guided_decoding.apply_token_mask(logits, [])

# Verify Future.result was called
mock_future.result.assert_called_once()

# Verify accept_token was called three times
self.assertEqual(self.mock_processor.accept_token.call_count, 3)

# Verify _tokens_to_acc[0] was reset
self.assertIsNone(self.guided_decoding._tokens_to_acc[0])
|
||||
|
||||
def test_accept_token(self):
|
||||
"""Test _accept_token method."""
|
||||
# Set up the processor at index 0
self.guided_decoding.logits_processors[0] = self.mock_processor

# Call _accept_token
self.guided_decoding._accept_token(0, 1)

# Verify accept_token was called
self.mock_processor.accept_token.assert_called_once_with(1)
|
||||
|
||||
def test_accept_token_with_reasoning_parser(self):
|
||||
"""Test _accept_token method with reasoning parser."""
|
||||
# 创建模拟ReasoningParser
|
||||
mock_parser = Mock(spec=ReasoningParser)
|
||||
mock_parser.is_reasoning_end.return_value = True
|
||||
self.guided_decoding.reasoning_parser = mock_parser
|
||||
|
||||
# 设置索引0的处理器
|
||||
self.mock_processor.enable_reasoning = False
|
||||
self.mock_processor.reasoning_ended = False
|
||||
self.guided_decoding.logits_processors[0] = self.mock_processor
|
||||
|
||||
# 调用_accept_token
|
||||
self.guided_decoding._accept_token(0, 1)
|
||||
|
||||
# 验证is_reasoning_end被调用
|
||||
mock_parser.is_reasoning_end.assert_called_once_with([1])
|
||||
|
||||
# 验证reasoning_ended已更新
|
||||
self.assertTrue(self.mock_processor.reasoning_ended)
|
||||
|
||||
# 验证accept_token没有被调用(因为reasoning_ended刚被设置为True)
|
||||
self.mock_processor.accept_token.assert_not_called()
|
||||
|
||||
def test_accept_token_processor_terminated(self):
|
||||
"""Test _accept_token method when processor is terminated."""
|
||||
# 设置索引0的处理器,并让accept_token返回False
|
||||
self.mock_processor.accept_token.return_value = False
|
||||
self.guided_decoding.logits_processors[0] = self.mock_processor
|
||||
|
||||
# 调用_accept_token
|
||||
self.guided_decoding._accept_token(0, 1)
|
||||
|
||||
# 验证处理器被重置
|
||||
self.assertIsNone(self.guided_decoding.logits_processors[0])
|
||||
|
||||
def test_update_output_tokens(self):
|
||||
"""Test update_output_tokens method."""
|
||||
# 创建测试数据
|
||||
next_tokens = paddle.to_tensor([[1], [2], [3], [4], [5]])
|
||||
|
||||
# 设置索引0和1的处理器
|
||||
self.guided_decoding.logits_processors[0] = self.mock_processor
|
||||
self.guided_decoding.logits_processors[1] = self.mock_processor
|
||||
self.guided_decoding._prefill_done_idxs[0] = True
|
||||
self.guided_decoding._prefill_done_idxs[1] = True
|
||||
|
||||
# 调用update_output_tokens
|
||||
self.guided_decoding.update_output_tokens(next_tokens)
|
||||
|
||||
# 验证accept_token被调用了两次
|
||||
self.assertEqual(self.mock_processor.accept_token.call_count, 2)
|
||||
self.mock_processor.accept_token.assert_any_call(1)
|
||||
self.mock_processor.accept_token.assert_any_call(2)
|
||||
|
||||
def test_update_output_tokens_with_negative_token(self):
|
||||
"""Test update_output_tokens method with negative token."""
|
||||
# 创建测试数据,包含负值
|
||||
next_tokens = paddle.to_tensor([[-1], [2]])
|
||||
|
||||
# 设置索引0和1的处理器
|
||||
self.guided_decoding.logits_processors[0] = self.mock_processor
|
||||
self.guided_decoding.logits_processors[1] = self.mock_processor
|
||||
self.guided_decoding._prefill_done_idxs[0] = True
|
||||
self.guided_decoding._prefill_done_idxs[1] = True
|
||||
|
||||
# 调用update_output_tokens
|
||||
self.guided_decoding.update_output_tokens(next_tokens)
|
||||
|
||||
# 验证索引0的处理器被重置
|
||||
self.assertIsNone(self.guided_decoding.logits_processors[0])
|
||||
|
||||
# 验证索引1的处理器的accept_token被调用
|
||||
self.mock_processor.accept_token.assert_called_once_with(2)
|
||||
|
||||
def test_pre_process(self):
|
||||
"""Test pre_process method."""
|
||||
# 模拟update_vocab_mask方法
|
||||
with patch.object(self.guided_decoding, "update_vocab_mask") as mock_update_vocab_mask:
|
||||
# 调用pre_process
|
||||
self.guided_decoding.pre_process([0, 1])
|
||||
|
||||
# 验证update_vocab_mask被调用
|
||||
mock_update_vocab_mask.assert_called_once_with([0, 1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -14,11 +14,23 @@
# limitations under the License.
"""

import json
import os

import paddle
import paddle.nn.functional as F

from fastdeploy.config import (
    CacheConfig,
    FDConfig,
    GraphOptimizationConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.scheduler import SchedulerConfig


def _create_fake_logits(batch_size: int, vocab_size: int) -> paddle.Tensor:
@@ -67,13 +79,60 @@ def _create_default_sampling_metadata(
    return fake_sampling_metadata


def build_config_json() -> str:
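    # Test helper: dump a minimal Qwen3-MoE style config.json into a per-rank temp
    # directory so ModelConfig can be constructed without a real checkpoint.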
    config_dict = {
        "architectures": ["Qwen3MoeForCausalLM"],
        "hidden_size": 7168,
        "moe_intermediate_size": 1,
        "moe_num_experts": 1,
        "moe_k": 1,
        "hidden_act": "silu",
        "num_attention_heads": 64,
        "dtype": "bfloat16",
    }

    tmp_dir = f"./tmpefef{paddle.distributed.get_rank()}"
    os.makedirs(tmp_dir, exist_ok=True)
    with open(f"./{tmp_dir}/config.json", "w") as f:
        json.dump(config_dict, f)
    model_name_or_path = os.path.join(os.getcwd(), tmp_dir)
    print("model_name_or_path", model_name_or_path)
    return model_name_or_path


def get_fd_config(batch_size: int):
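    # Test helper: assemble a minimal single-rank FDConfig (TP/EP/DP all 1) around the
    # config above, which the updated Sampler constructor takes in these tests.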
    fd_config = FDConfig(
        model_config=ModelConfig(
            {
                "model": build_config_json(),
                "max_model_len": 2048,
            }
        ),
        parallel_config=ParallelConfig(
            {
                "tensor_parallel_size": 1,
                "expert_parallel_size": 1,
                "expert_parallel_rank": 0,
                "data_parallel_size": 1,
            }
        ),
        # quant_config=BlockWiseFP8Config(weight_block_size=[128, 128]),
        scheduler_config=SchedulerConfig({"max_num_seqs": batch_size}),
        cache_config=CacheConfig({}),
        graph_opt_config=GraphOptimizationConfig({}),
        load_config=LoadConfig({}),
        ips="0.0.0.0",
    )
    return fd_config


def test_sampler():
    batch_size = 32
    vocab_size = 1024
    min_seq_len = 1
    max_seq_len = 1024

    sampler = Sampler()
    sampler = Sampler(get_fd_config(batch_size))
    logits = _create_fake_logits(batch_size, vocab_size)
    sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
    next_tokens = sampler(logits, sampling_metadata)
@@ -144,7 +203,10 @@ def test_sampler_logprobs():
    logits = _create_fake_logits(batch_size, vocab_size)
    sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
    for logprobs_mode in logprobs_mode_list:
        sampler = Sampler(logprobs_mode=logprobs_mode)
        fd_config = get_fd_config(batch_size)
        fd_config.model_config.logprobs_mode = logprobs_mode
        sampler = Sampler(logprobs_mode=logprobs_mode, fd_config=fd_config)
        assert sampler.logprobs_mode == logprobs_mode
        sampler_output = sampler(logits.clone(), sampling_metadata)
        baseline_logprobs = get_baseline_logprobs(
            logits.clone(), sampling_metadata, logprobs_mode=logprobs_mode, token_ids=sampler_output.sampled_token_ids
@@ -12,15 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import time
import unittest

import numpy as np
import paddle

from fastdeploy.config import (
    CacheConfig,
    FDConfig,
    GraphOptimizationConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
)
from fastdeploy.engine.request import Request
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.worker.gpu_model_runner import GPUModelRunner


@@ -82,6 +93,53 @@ class FakeModel:
        return paddle.matmul(x.astype("float32"), self.weight)


def build_config_json() -> str:
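    # Same helper as in the sampler test above: write a minimal Qwen3-MoE style
    # config.json to a per-rank temp directory so ModelConfig can load it.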
    config_dict = {
        "architectures": ["Qwen3MoeForCausalLM"],
        "hidden_size": 7168,
        "moe_intermediate_size": 1,
        "moe_num_experts": 1,
        "moe_k": 1,
        "hidden_act": "silu",
        "num_attention_heads": 64,
        "dtype": "bfloat16",
    }

    tmp_dir = f"./tmpefef{paddle.distributed.get_rank()}"
    os.makedirs(tmp_dir, exist_ok=True)
    with open(f"./{tmp_dir}/config.json", "w") as f:
        json.dump(config_dict, f)
    model_name_or_path = os.path.join(os.getcwd(), tmp_dir)
    print("model_name_or_path", model_name_or_path)
    return model_name_or_path


def get_fd_config(batch_size: int):
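    # Same helper as in the sampler test: a minimal single-rank FDConfig used to
    # construct the Sampler inside GPUModelRunner for these tests.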
    fd_config = FDConfig(
        model_config=ModelConfig(
            {
                "model": build_config_json(),
                "max_model_len": 2048,
            }
        ),
        parallel_config=ParallelConfig(
            {
                "tensor_parallel_size": 1,
                "expert_parallel_size": 1,
                "expert_parallel_rank": 0,
                "data_parallel_size": 1,
            }
        ),
        # quant_config=BlockWiseFP8Config(weight_block_size=[128, 128]),
        scheduler_config=SchedulerConfig({"max_num_seqs": batch_size}),
        cache_config=CacheConfig({}),
        graph_opt_config=GraphOptimizationConfig({}),
        load_config=LoadConfig({}),
        ips="0.0.0.0",
    )
    return fd_config


class TestGPUPromptLogprobs(unittest.TestCase):
    def setup_model_runner(self):
        """Helper method to setup GPUModelRunner with different configurations"""
@@ -96,7 +154,7 @@ class TestGPUPromptLogprobs(unittest.TestCase):
        model_runner.ori_vocab_size = cfg.model_config.ori_vocab_size
        model_runner.share_inputs = {}
        model_runner.share_inputs["cu_seqlens_q"] = paddle.to_tensor([0, 1, 2, 3], dtype="int32")
        model_runner.sampler = Sampler()
        model_runner.sampler = Sampler(get_fd_config(batch_size=1))

        model_runner.model = FakeModel(cfg.model_config.vocab_size, cfg.model_config.hidden_size)