FastDeploy/fastdeploy/model_executor/layers/sample/sampler.py
freeliuzc a7359d1c1d [Cherry-Pick][CI]Support different inferseed in speculate decoding(#5568) (#5597)
* fix mtp entropy drop in RL

* optimize usage and fix unit test

* optimize padding_sampling_params speed(vectorized)
2025-12-17 16:53:47 +08:00


"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import multiprocessing
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, List, Optional
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.envs import FD_FILL_BITMASK_BATCH
from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
from fastdeploy.model_executor.layers.sample.early_stopper import (
get_early_stopper_cls_from_stragegy,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores,
apply_speculative_penalty_multi_scores,
min_p_sampling,
speculate_get_target_logits,
speculate_insert_first_token,
top_k_top_p_sampling,
)
from fastdeploy.platforms import current_platform
from fastdeploy.reasoning import ReasoningParser
from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
def top_p_normalize_probs_paddle(
probs: paddle.Tensor,
top_ps: paddle.Tensor,
):
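    # Nucleus (top-p) renormalization:
    #   1. sort probabilities in descending order,
    #   2. zero out every token whose preceding cumulative mass already exceeds top_p,
    #   3. renormalize the surviving mass to sum to 1,
    #   4. scatter the renormalized values back to their original vocabulary positions.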
probs_idx = probs.argsort(axis=-1, descending=True)
probs_sort = paddle.take_along_axis(probs, probs_idx, axis=-1)
probs_sum = paddle.cumsum(probs_sort, axis=-1)
probs_sort = paddle.where((probs_sum - probs_sort) > top_ps, paddle.zeros_like(probs_sort), probs_sort)
probs_sort.divide_(probs_sort.sum(axis=-1, keepdim=True))
return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder):
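    """
    Expand per-request sampling parameters to per-token parameters for speculative decoding.

    Decode requests (seq_lens_encoder == 0) contribute `seq_lens_this_time` tokens each,
    while prefill requests contribute one token. The per-token seed is the request's
    infer_seed plus a position-dependent offset, so each draft position samples with a
    distinct seed (kept within MAX_INFER_SEED).
    """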
real_bsz = seq_lens_this_time.shape[0]
repeats = paddle.where(seq_lens_encoder[:real_bsz] == 0, seq_lens_this_time, paddle.ones_like(seq_lens_this_time))
top_p_padding = paddle.repeat_interleave(top_p[:real_bsz], repeats).unsqueeze(1)
top_k_padding = paddle.repeat_interleave(top_k[:real_bsz], repeats).unsqueeze(1)
topp_seed = paddle.repeat_interleave(infer_seed[:real_bsz], repeats).unsqueeze(1)
MAX_INFER_SEED = 9223372036854775806
    # `repeats` already equals the per-request token count for this step, so reuse it.
    token_lens = repeats
batch_start = (paddle.cumsum(token_lens, axis=0) - token_lens.astype("int64")).reshape(-1) # [B]
token_batch_ids = paddle.repeat_interleave(
paddle.arange(token_lens.shape[0], dtype="int64"),
token_lens,
)
token_pos = paddle.arange(topp_seed.shape[0], dtype="int64")
local_pos = token_pos - paddle.gather(batch_start, token_batch_ids)
is_decoder = paddle.gather(seq_lens_encoder[:real_bsz] == 0, token_batch_ids).reshape(-1)
offsets = paddle.where(
is_decoder,
local_pos * 4,
paddle.zeros_like(local_pos),
)
topp_seed[:, 0] = (topp_seed[:, 0] + offsets) % MAX_INFER_SEED
return top_p_padding, top_k_padding, topp_seed
class GuidedDecoding:
"""
processor for guided decoding.
"""
def __init__(self, fd_config: FDConfig):
self.token_bitmask = None
self.max_num_seqs: int = int(
fd_config.scheduler_config.max_num_seqs if fd_config.scheduler_config is not None else 1
)
self.logits_processors: List[Any] = [None] * self.max_num_seqs
self.reasoning_parser = None
self._prefill_done_idxs: List[bool] = [False] * self.max_num_seqs
        # For PD (prefill/decode) disaggregation: tokens from the prefill node waiting to be accepted
self._tokens_to_acc: List[None | List[int]] = [None] * self.max_num_seqs
self.fill_bitmask_parallel_batch_size: int = FD_FILL_BITMASK_BATCH
        max_workers = max(
            1,
            min(multiprocessing.cpu_count() // 2, int(self.max_num_seqs) // int(self.fill_bitmask_parallel_batch_size)),
        )
self.executor_for_fillmask = ThreadPoolExecutor(max_workers=int(max_workers))
        self._fillmask_futures: List[Optional[Future]] = [None] * self.max_num_seqs
self.is_cuda_platform = current_platform.is_cuda()
logger.info(
f"GuidedDecoding max_num_seqs={self.max_num_seqs} fill_bitmask_parallel_batch_size={self.fill_bitmask_parallel_batch_size} is_cuda_platform={self.is_cuda_platform} max_workers={max_workers}"
)
def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
self.reasoning_parser = reasoning_parser
def add_logits_processor(
self,
idx: int,
future: Optional[Any] = None,
prefill_tokens: List[int] = [],
):
"""add logits processor to SamplerProcessor"""
self._prefill_done_idxs[idx] = False
if future is None:
# normal request without guided_backend
self.logits_processors[idx] = None
return
if len(prefill_tokens) != 0:
# first_token from prefill node
self._prefill_done_idxs[idx] = True
if future.done():
# cached xgrammar
self.logits_processors[idx] = future.result()
for token in prefill_tokens:
self._accept_token(idx, token)
else:
# async
self.logits_processors[idx] = future
self._tokens_to_acc[idx] = prefill_tokens
def should_fill_bitmask(self, idx: int) -> bool:
"""
Determines whether to fill a bitmask for the logits processor at the given index.
Args:
idx (int): The index of the logits processor to check
Returns:
bool: True if the idx request bitmask should be filled
"""
if self.reasoning_parser is not None:
if self.logits_processors[idx].enable_reasoning: # <think> guided
return True
if not self.logits_processors[idx].reasoning_ended:
return False
return True
def reset_processor(self, idx: int):
"""reset idx"""
self._prefill_done_idxs[idx] = False
self.logits_processors[idx] = None
def update_vocab_mask(self, prefill_done_idxs: List[int] = []):
"""update vocab mask. (cpu-heavy operation)"""
for idx in prefill_done_idxs:
if self.logits_processors[idx] is None:
continue
assert not self._prefill_done_idxs[idx]
self._prefill_done_idxs[idx] = True
if isinstance(self.logits_processors[idx], Future):
continue
idxs = []
for idx, processor in enumerate(self.logits_processors):
if processor is None or not self._prefill_done_idxs[idx]:
continue
# skip, join at apply_token_mask
if isinstance(processor, Future):
continue
if processor.is_terminated:
self.reset_processor(idx)
continue
self.accept_tokens_from_prefill_node(idx)
if self.token_bitmask is None:
self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()
if self.should_fill_bitmask(idx):
idxs.append(idx)
self._async_batch_fill_token_bitmask(idxs)
def batch_fill_token_bitmask(self, batch: List[int]):
"""
Fills the token bitmask for a batch of logits processor indices.
This method is typically called asynchronously via a thread pool executor
to parallelize the bitmask filling operation. It is important that any
shared data structures accessed within this method (such as
`self.token_bitmask` and `self.logits_processors`) are thread-safe or
properly synchronized to avoid race conditions.
Args:
batch (List[int]): List of indices for which to fill the token bitmask.
"""
for idx in batch:
self.logits_processors[idx].fill_token_bitmask(self.token_bitmask, idx)
def _async_batch_fill_token_bitmask(self, idxs: List[int]):
"""launch async fill"""
batch: List[int] = []
for idx in idxs:
batch.append(idx)
if len(batch) == self.fill_bitmask_parallel_batch_size:
promise = self.executor_for_fillmask.submit(self.batch_fill_token_bitmask, batch[:])
self._fillmask_futures[idx] = promise
batch = []
if batch:
promise = self.executor_for_fillmask.submit(self.batch_fill_token_bitmask, batch[:])
self._fillmask_futures[batch[-1]] = promise
def join_async_fillmask(self):
"""join all async fill futures"""
        for idx, future in enumerate(self._fillmask_futures):
            if future is not None:
                try:
                    future.result()
except Exception as e:
logger.error(f"Exception in async fillmask future at idx {idx}: {e}", exc_info=True)
self._fillmask_futures[idx] = None
def accept_tokens_from_prefill_node(self, idx: int):
"""accept prefill token, not future"""
if self._tokens_to_acc[idx] is not None:
# accept token from prefill node first
for token in self._tokens_to_acc[idx]:
self._accept_token(idx, token)
self._tokens_to_acc[idx] = None
def apply_token_mask(self, logits: paddle.Tensor, prefill_done_idxs: List[int] = []):
"""apply token mask to logits"""
indices = []
for idx, processor in enumerate(self.logits_processors):
if processor is None or not self._prefill_done_idxs[idx]:
continue
            # Compilation already finished: just decide whether this idx needs masking;
            # its bitmask was filled during pre_process.
if not isinstance(processor, Future):
if self.should_fill_bitmask(idx):
indices.append(idx)
continue
            # Processor is still a Future: the async compile may not have finished, so join and wait.
ts = time.time()
wait = False
if not processor.done():
wait = True
self.logits_processors[idx] = processor.result()
if wait:
logger.debug(f"[{idx} join async compile xgrammar, time_cost:{time.time() - ts}]")
self.accept_tokens_from_prefill_node(idx)
            # Possible optimization: move the 'think' content validation out of the logits
            # processors so the join can return as soon as 'think' terminates; the current idx
            # could even be skipped, since the compilation overhead is only a few milliseconds.
# check idx for fill_token_mask
if not self.should_fill_bitmask(idx):
continue
indices.append(idx)
if self.token_bitmask is None:
self.token_bitmask = self.logits_processors[idx].allocate_token_bitmask()
# launch async fill
self._async_batch_fill_token_bitmask([idx])
if len(indices) == 0:
return logits
self.join_async_fillmask()
from fastdeploy.model_executor.guided_decoding.xgrammar_backend import (
apply_token_mask,
)
return apply_token_mask(logits, self.token_bitmask, indices=indices, is_cuda_platform=self.is_cuda_platform)
def _accept_token(self, idx: int, token: int):
"""accept token"""
if self.reasoning_parser is not None:
if not self.logits_processors[idx].enable_reasoning:
if not self.logits_processors[idx].reasoning_ended:
reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
self.logits_processors[idx].reasoning_ended = reasoning_ended
return
if not self.logits_processors[idx].accept_token(token) or self.logits_processors[idx].is_terminated:
self.reset_processor(idx)
def update_output_tokens(self, next_tokens: paddle.Tensor):
"""update output tokens"""
if len(self.logits_processors) == 0:
return
token_ids = next_tokens.numpy().tolist()
for idx, processor in enumerate(self.logits_processors):
if not self._prefill_done_idxs[idx] or processor is None:
continue
if idx >= len(token_ids):
continue
token = token_ids[idx][0]
if token < 0:
self.reset_processor(idx)
continue
logger.debug(f"[{idx}]accept token{token}")
self._accept_token(idx, token)
def pre_process(self, prefill_done_idxs: List[int] = []):
"""pre process before running"""
self.update_vocab_mask(prefill_done_idxs)
class Sampler(nn.Layer):
"""
Sampler for normal generation.
"""
def __init__(self, fd_config: FDConfig = None, logprobs_mode: str = "raw_logprobs"):
""" """
super().__init__()
if (
current_platform.is_cuda()
or current_platform.is_xpu()
or current_platform.is_iluvatar()
or current_platform.is_gcu()
or current_platform.is_dcu()
or current_platform.is_maca()
):
self.forward = self.forward_cuda
elif current_platform.is_intel_hpu():
self.forward = self.forward_intel_hpu
else:
raise NotImplementedError
self.guided_decoding = GuidedDecoding(fd_config)
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
        # Only created when fd_config.early_stop_config.enable_early_stop is True
if (
fd_config is not None
and fd_config.early_stop_config is not None
and fd_config.early_stop_config.enable_early_stop
):
early_stopper_cls = get_early_stopper_cls_from_stragegy(fd_config.early_stop_config.strategy)
self.early_stopper = early_stopper_cls()
self.early_stopper.initialize(fd_config.scheduler_config.max_num_seqs, fd_config.early_stop_config)
def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
"""set reasoning parser"""
self.guided_decoding.apply_reasoning_parser(reasoning_parser)
def apply_logits_processor(
        self, ids: int, future: Optional[Future[LogitsProcessorBase]] = None, prefill_tokens: List[int] = []
):
"""apply logits processor to sampler"""
self.guided_decoding.add_logits_processor(ids, future, prefill_tokens)
def pre_process(self, prefill_done_idxs: List[int] = []):
"""pre process before running"""
self.guided_decoding.pre_process(prefill_done_idxs)
def post_process(self, next_tokens: paddle.Tensor):
"""post process after running"""
self.guided_decoding.update_output_tokens(next_tokens)
def compute_logprobs(
self,
logits: paddle.Tensor,
sampling_metadata: Optional[SamplingMetadata] = None,
) -> paddle.Tensor:
""" """
if sampling_metadata is None:
return F.log_softmax(logits, axis=-1)
last_logits = logits
real_bsz = last_logits.shape[0]
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
share_inputs = sampling_metadata.share_inputs
if temp_scaled_logprobs is not None and sampling_metadata.temp_scaled_logprobs_flag:
real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
temperature = sampling_metadata.temperature[:real_bsz]
temp_temperature = paddle.where(real_bsz_temp_scaled, temperature, paddle.ones_like(temperature))
last_logits = last_logits / temp_temperature
last_logprobs = F.log_softmax(last_logits, axis=-1)
top_p_logprob = None
top_p_req_mask = None
if (
top_p_normalized_logprobs is not None
and share_inputs is not None
and sampling_metadata.top_p_normalized_logprobs_flag
):
seq_lens_this_time = share_inputs["seq_lens_this_time"].reshape([-1, 1])[:real_bsz]
seq_lens_encoder = share_inputs["seq_lens_encoder"].reshape([-1, 1])[:real_bsz]
seq_lens_decoder = share_inputs["seq_lens_decoder"].reshape([-1, 1])[:real_bsz]
seq_lens_time_sum = seq_lens_this_time + seq_lens_encoder + seq_lens_decoder
real_req_mask = seq_lens_time_sum > 0
top_p_req_mask = paddle.logical_and(top_p_normalized_logprobs[:real_bsz], real_req_mask)
real_req_top_p = sampling_metadata.top_p[:real_bsz]
# Normalize logprobs if top_p normalization is enabled
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
top_p_req_mask = paddle.logical_and(top_p_req_mask, real_req_top_p != 1.0)
if top_p_req_mask.any():
probs = F.softmax(last_logits, axis=-1)
probs = top_p_normalize_probs_paddle(probs, real_req_top_p)
top_p_logprob = paddle.log(probs)
if top_p_logprob is not None:
last_logprobs = paddle.where(top_p_req_mask, top_p_logprob, last_logprobs)
return last_logprobs
def gather_logprobs(
self,
logprobs: paddle.Tensor,
num_logprobs: int,
token_ids: paddle.Tensor,
) -> LogprobsTensors:
"""
Gather logprobs for topk and sampled/prompt token.
Args:
logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Must be int64.
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == paddle.int64
logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
        # Gather the logprob of the prompt or sampled token.
if len(token_ids.shape) < len(logprobs.shape):
token_ids = token_ids.unsqueeze(-1)
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
if num_logprobs >= 1:
# Find the topK values.
topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
indices = paddle.concat([token_ids, topk_indices], axis=1)
top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
else:
indices = token_ids
top_logprobs = token_logprobs
indices = indices.cpu()
top_logprobs = top_logprobs.cpu()
token_ranks = token_ranks.cpu()
return LogprobsTensors(indices, top_logprobs, token_ranks)
def forward_cuda(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
p_done_idxs: List[int] = [],
) -> SamplerOutput:
""" """
logits = self.guided_decoding.apply_token_mask(logits, p_done_idxs)
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
if self.logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
elif self.logprobs_mode == "raw_logits":
raw_logprobs = logits.clone()
for proc in sampling_metadata.logits_processors or []:
logits = proc.apply(logits)
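        # Apply the fused per-request penalties in one op: repetition / frequency / presence
        # penalties, temperature scaling, bad-words masking, and (typically) masking EOS until
        # min_dec_len is reached.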
logits = apply_penalty_multi_scores(
sampling_metadata.pre_token_ids,
sampling_metadata.prompt_ids,
sampling_metadata.prompt_lens,
logits,
sampling_metadata.repetition_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
)
if num_logprobs is not None:
if self.logprobs_mode == "processed_logprobs":
raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
elif self.logprobs_mode == "processed_logits":
raw_logprobs = logits.clone()
probs = F.softmax(logits)
probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list)
_, next_tokens = top_k_top_p_sampling(
probs,
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.top_k_list,
topp_seed=sampling_metadata.seed,
)
logprobs_tensors = (
None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
)
if sampling_metadata.enable_early_stop:
# will set the stop batch in stop_flags
assert sampling_metadata.stop_flags is not None, "need stop_flags for early stop"
self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags)
sampler_output = SamplerOutput(
            # The sampled tokens are expanded to a 2D tensor of shape
            # [num_requests, 1], where each row holds the one token
            # generated for that request.
sampled_token_ids=next_tokens,
logprobs_tensors=logprobs_tensors,
)
return sampler_output
def forward_intel_hpu(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
batch_ids: paddle.Tensor,
max_batch: int,
rank: int,
local_rank: int,
) -> paddle.Tensor:
if logits.dtype != paddle.float32:
logits = paddle.cast(logits, paddle.float32)
from fastdeploy.model_executor.ops.intel_hpu import fused_sampler
_, next_tokens = fused_sampler(
sampling_metadata.pre_token_ids,
sampling_metadata.prompt_ids,
sampling_metadata.seq_lens_encoder,
sampling_metadata.seq_lens_decoder,
sampling_metadata.step_idx,
sampling_metadata.stop_flags,
logits,
sampling_metadata.repetition_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
sampling_metadata.top_p,
rank,
local_rank,
)
if next_tokens.shape[0] != max_batch:
dim = next_tokens.shape[-1]
tmp_tokens = paddle.full((max_batch, dim), -1 if local_rank == 0 else 0, dtype=next_tokens.dtype)
tmp_tokens = paddle.scatter(tmp_tokens, batch_ids, next_tokens[: batch_ids.shape[0], :])
return tmp_tokens
return next_tokens
class SpeculativeSampler(nn.Layer):
"""
Sampler for speculative generation.
"""
def __init__(self, fd_config: FDConfig):
""" """
super().__init__()
if current_platform.is_cuda():
self.forward = self.forward_cuda
elif current_platform.is_xpu():
self.forward = self.forward_xpu
else:
raise NotImplementedError
self.logprobs_mode = fd_config.model_config.logprobs_mode
self.speculative_verify_window = fd_config.speculative_config.verify_window
self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
def pre_process(self, skip_idx_list: List[int] = []):
"""pre process before running"""
pass
def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
"""set reasoning parser"""
pass
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
"""post process after running"""
pass
def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
"""apply logits processor to sampler"""
pass
def compute_logprobs(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
) -> paddle.Tensor:
"""compute logprobs"""
share_inputs = sampling_metadata.share_inputs
last_logits = logits
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
batch_token_num = share_inputs["accept_num"][:real_bsz]
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
if temp_scaled_logprobs is not None:
real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
temperature = sampling_metadata.temperature[:real_bsz]
real_bsz_temp_scaled = (
real_bsz_temp_scaled.astype("int32").squeeze(1).repeat_interleave(batch_token_num).astype("bool")
)
temperature = temperature.squeeze(1).repeat_interleave(batch_token_num)
temp_temperature = paddle.where(
real_bsz_temp_scaled, temperature, paddle.ones_like(temperature)
).unsqueeze(1)
last_logits = last_logits / temp_temperature
last_logprobs = F.log_softmax(last_logits, axis=-1)
top_p_logprob = None
top_p_token_mask = None
if top_p_normalized_logprobs is not None and share_inputs is not None:
real_token_top_p = (
sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1)
)
top_p_normalized_logprobs = (
top_p_normalized_logprobs[:real_bsz]
.astype("int32")
.squeeze(1)
.repeat_interleave(batch_token_num)
.astype("bool")
.unsqueeze(1)
)
top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0)
if top_p_token_mask.any():
probs = F.softmax(last_logits, axis=-1)
probs = top_p_normalize_probs_paddle(probs, real_token_top_p)
top_p_logprob = paddle.log(probs)
if top_p_logprob is not None:
last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs)
return last_logprobs
def gather_logprobs(
self,
logprobs: paddle.Tensor,
num_logprobs: int,
token_ids: paddle.Tensor,
) -> LogprobsTensors:
"""
Gather logprobs for topk and sampled/prompt token.
Args:
logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Must be int64.
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == paddle.int64
token_ids = token_ids.unsqueeze(1)
logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
        # Gather the logprob of the prompt or sampled token.
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
if num_logprobs >= 1:
# Find the topK values.
topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
indices = paddle.concat([token_ids, topk_indices], axis=1)
top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
else:
indices = token_ids
top_logprobs = token_logprobs
return LogprobsTensors(indices, top_logprobs, token_ranks)
def forward_cuda(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
max_model_len: int,
share_inputs: List[paddle.Tensor],
accept_all_drafts: bool = False,
reject_all_drafts: bool = False,
) -> paddle.Tensor:
""" """
from fastdeploy.model_executor.ops.gpu import speculate_verify, top_p_candidates
logits = apply_speculative_penalty_multi_scores(
sampling_metadata.pre_token_ids,
logits,
sampling_metadata.repetition_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
max_model_len,
)
probs = F.softmax(logits)
top_p, top_k, topp_seed = padding_sampling_params(
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.seed,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
)
_, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
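        # Verification: `top_p_candidates` builds the top-p candidate token set (and scores) for
        # every draft position, and `speculate_verify` checks the draft tokens against them,
        # writing the accepted tokens into `accept_tokens` / `accept_num` in share_inputs.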
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
probs,
sampling_metadata.top_p,
share_inputs["output_padding_offset"],
self.speculative_max_candidate_len,
max_model_len,
)
speculate_verify(
sampled_token_ids,
share_inputs["accept_tokens"],
share_inputs["accept_num"],
share_inputs["step_idx"],
share_inputs["stop_flags"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
            share_inputs[
                "draft_tokens"
            ],  # Both input and output: the last accepted token must be written back to position 0.
share_inputs["seq_lens_this_time"],
verify_tokens,
verify_scores,
share_inputs["max_dec_len"],
sampling_metadata.eos_token_ids,
share_inputs["is_block_step"],
share_inputs["output_cum_offsets"],
actual_candidate_len,
share_inputs["actual_draft_token_num"],
sampling_metadata.top_p,
max_model_len,
self.speculative_verify_window,
True, # enable_topp
(self.speculative_benchmark_mode or reject_all_drafts),
accept_all_drafts,
)
num_logprobs = sampling_metadata.max_num_logprobs
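        # Logprobs path: count the tokens each request produced this step (1 for prefill,
        # seq_lens_this_time for decode), gather the logits of the accepted tokens into
        # `target_logits`, and compute logprobs only over those accepted positions.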
batch_token_num = None
if num_logprobs is not None:
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
batch_token_num = paddle.where(
share_inputs["seq_lens_encoder"][:real_bsz] != 0,
paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
share_inputs["seq_lens_this_time"],
).squeeze(1)
share_inputs["batch_token_num"] = batch_token_num
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
"int32"
)
cu_batch_token_offset = paddle.concat(
[paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
).astype("int32")
share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
target_logits = paddle.empty(
[share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
)
speculate_get_target_logits(
target_logits,
logits,
cu_batch_token_offset,
ori_cu_batch_token_offset,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
share_inputs["accept_num"],
)
if self.logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata)
elif self.logprobs_mode == "raw_logits":
raw_logprobs = target_logits.clone()
logprobs_tensors = None
token_ids = share_inputs["accept_tokens"]
if num_logprobs is not None:
token_ids = paddle.concat(
[share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] for i in range(real_bsz)]
)
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=logprobs_tensors,
token_num_per_batch=share_inputs["accept_num"],
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
)
return sampler_output
def forward_xpu(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
max_model_len: int,
share_inputs: List[paddle.Tensor],
accept_all_drafts: bool = False,
reject_all_drafts: bool = False,
) -> paddle.Tensor:
from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates
logits = apply_speculative_penalty_multi_scores(
sampling_metadata.pre_token_ids,
logits,
sampling_metadata.repetition_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
max_model_len,
)
probs = F.softmax(logits)
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
probs,
sampling_metadata.top_p,
share_inputs["output_padding_offset"],
self.speculative_max_candidate_len,
max_model_len,
)
speculate_verify(
share_inputs["accept_tokens"],
share_inputs["accept_num"],
share_inputs["step_idx"],
share_inputs["stop_flags"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
            share_inputs[
                "draft_tokens"
            ],  # Both input and output: the last accepted token must be written back to position 0.
share_inputs["seq_lens_this_time"],
verify_tokens,
verify_scores,
share_inputs["max_dec_len"],
sampling_metadata.eos_token_ids,
share_inputs["is_block_step"],
share_inputs["output_cum_offsets"],
actual_candidate_len,
share_inputs["actual_draft_token_num"],
sampling_metadata.top_p,
max_model_len,
self.speculative_verify_window,
True, # enable_topp
(self.speculative_benchmark_mode or reject_all_drafts),
accept_all_drafts,
)
# TODO(chenhuan09): support return logprobs
token_ids = share_inputs["accept_tokens"]
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=None,
token_num_per_batch=share_inputs["accept_num"],
cu_batch_token_offset=None,
)
return sampler_output
class MTPSampler(nn.Layer):
""" """
def __init__(self, fd_config: FDConfig):
""" """
super().__init__()
if current_platform.is_cuda():
self.forward = self.forward_cuda
elif current_platform.is_xpu():
self.forward = self.forward_xpu
else:
raise NotImplementedError
self.logprobs_mode = fd_config.model_config.logprobs_mode
def pre_process(self, skip_idx_list: List[int] = []):
"""pre process before running"""
pass
def apply_logits_processor(
self,
ids: int,
future: Optional[Any] = None,
prefill_tokens: List[int] = [],
):
"""apply logits processor to sampler"""
pass
def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
"""set reasoning parser"""
pass
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
"""post process after running"""
pass
def compute_logprobs(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
) -> paddle.Tensor:
"""compute logprobs"""
share_inputs = sampling_metadata.share_inputs
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
last_logits = logits
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
if temp_scaled_logprobs is not None:
real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
temperature = sampling_metadata.temperature[:real_bsz]
real_bsz_temp_scaled = (
real_bsz_temp_scaled.astype("int32")
.squeeze(1)
.repeat_interleave(share_inputs["batch_token_num"][:real_bsz])
.astype("bool")
)
temperature = temperature.squeeze(1).repeat_interleave(share_inputs["batch_token_num"][:real_bsz])
temp_temperature = paddle.where(
real_bsz_temp_scaled, temperature, paddle.ones_like(temperature)
).unsqueeze(1)
last_logits = last_logits / temp_temperature
last_logprobs = F.log_softmax(last_logits, axis=-1)
top_p_logprob = None
top_p_token_mask = None
if top_p_normalized_logprobs is not None and share_inputs is not None:
real_token_top_p = (
sampling_metadata.top_p[:real_bsz]
.squeeze(1)
.repeat_interleave(share_inputs["batch_token_num"][:real_bsz])
.unsqueeze(1)
)
top_p_normalized_logprobs = (
top_p_normalized_logprobs[:real_bsz]
.astype("int32")
.squeeze(1)
.repeat_interleave(share_inputs["batch_token_num"][:real_bsz])
.astype("bool")
.unsqueeze(1)
)
top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0)
if top_p_token_mask.any():
probs = F.softmax(last_logits, axis=-1)
probs = top_p_normalize_probs_paddle(probs, real_token_top_p)
top_p_logprob = paddle.log(probs)
if top_p_logprob is not None:
last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs)
return last_logprobs
def gather_logprobs(
self,
logprobs: paddle.Tensor,
num_logprobs: int,
token_ids: paddle.Tensor,
) -> LogprobsTensors:
"""
Gather logprobs for topk and sampled/prompt token.
Args:
logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Must be int64.
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == paddle.int64
token_ids = token_ids.unsqueeze(1)
logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
        # Gather the logprob of the prompt or sampled token.
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
if num_logprobs >= 1:
# Find the topK values.
topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
indices = paddle.concat([token_ids, topk_indices], axis=1)
top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
else:
indices = token_ids
top_logprobs = token_logprobs
return LogprobsTensors(indices, top_logprobs, token_ranks)
def forward_cuda(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
max_model_len: int,
share_inputs: List[paddle.Tensor],
) -> paddle.Tensor:
""" """
num_logprobs = sampling_metadata.max_num_logprobs
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
if num_logprobs is not None and share_inputs["substep"] == 0:
real_token_num = share_inputs["batch_token_num"][:real_bsz].sum()
if self.logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(
share_inputs["draft_logits"][:real_token_num, :], sampling_metadata
)
elif self.logprobs_mode == "raw_logits":
raw_logprobs = share_inputs["draft_logits"][:real_token_num, :].clone()
logits = apply_speculative_penalty_multi_scores(
sampling_metadata.pre_token_ids,
logits,
sampling_metadata.repetition_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
max_model_len,
)
probs = F.softmax(logits)
next_tokens = paddle.argmax(probs, axis=-1)
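        # The MTP draft head samples greedily (argmax over the penalized distribution). When
        # logprobs are requested on substep 0, `speculate_insert_first_token` builds `token_ids`
        # from the accepted first tokens and the freshly sampled draft tokens for logprob gathering.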
token_ids = None
logprobs_tensors = None
if num_logprobs is not None and share_inputs["substep"] == 0:
token_ids = paddle.empty(real_token_num, dtype="int64")
speculate_insert_first_token(
token_ids,
share_inputs["accept_tokens"],
next_tokens,
share_inputs["cu_next_token_offset"],
share_inputs["cu_batch_token_offset"],
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
)
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=logprobs_tensors,
token_num_per_batch=share_inputs["batch_token_num"][:real_bsz],
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
)
return next_tokens, sampler_output
def forward_xpu(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
max_model_len: int,
share_inputs: List[paddle.Tensor],
) -> paddle.Tensor:
logits = apply_speculative_penalty_multi_scores(
sampling_metadata.pre_token_ids,
logits,
sampling_metadata.repetition_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
max_model_len,
)
probs = F.softmax(logits)
_, next_tokens = top_k_top_p_sampling(
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
)
# TODO(chenhuan09): add support for logprobs
token_ids = None
logprobs_tensors = None
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=logprobs_tensors,
token_num_per_batch=None,
cu_batch_token_offset=None,
)
return next_tokens, sampler_output